用c++递归lambda重写二叉树的生成与遍历完整程序

 
Category: DSA

写在前面

之前写过一篇有关二叉树的生成与遍历的C++版本, 但是当时的递归用的是两个函数, 写法有点臃肿, 这次重写一下代码, 然后对树的结构部分加以改进, 不需要getroot()方法, 直接在函数内部传入根节点, 类的声明与实现分离.

完整代码见dsa/Binary_Tree.cpp at main · Apocaly-pse/dsa (github.com);

声明部分

#ifndef BTREE
#define BTREE

#include <iostream>
#include <queue>
#include <stack>
#include <vector>
#include <functional> // 递归lambda


using namespace std;

// 重载<<操作符输出数组
template <typename T>
ostream& operator<<(ostream& os, const vector<T>& v) {
    os << "[";
    for (auto it = v.begin(); it != v.end(); it++) {
        cout << *it << (it != v.end() - 1 ? ", " : "] ");
    }
    return os;
}

// 二叉树节点的结构体实现, 参考LeetCode
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;

    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode* left, TreeNode* right)
        : val(x), left(left), right(right) {}
};


class BinaryTree {
public:
    BinaryTree() : root(nullptr) {}
    ~BinaryTree() {}
    // 树的生成
    void add_iter(int item); // 迭代方式生成树(广度遍历逆操作)
    void add_recur1();       // 递归捕获用户输入方式
    void add_recur2(vector<int>& item); // 直接遍历数组实现
    // 遍历部分
    void breadth_travel(); // bfs
    void pre_order();      //前序(递归)
    void pre_order1();     //迭代
    void in_order();       //中序(递归
    void in_order1();      //迭代(指针)
    void in_order2();      //迭代(通解)
    void post_order();	   //后序(递归)
    void post_order1();    //从前序的迭代逆序得到
    void post_order2();    //通解

private:
    TreeNode* root;
};
#endif

树的生成

这里重点说一下树的生成, 这里给出了三种方法:

  1. 第一种就是经典的广度优先搜索方法, 通过队列实现节点的添加.
    void build(int item) {
        TreeNode* node = new TreeNode(item);
         
        if (!root) {
            root = node;
            return;
        }
        // 用队列模拟二叉树的层序遍历
        queue<TreeNode*> que;
        que.push(root);
        while (!que.empty()) {
            TreeNode* cur_node = que.front();
            que.pop();
         
            if (!cur_node->left) {
                cur_node->left = node;
                return;
            } else
                que.push(cur_node->left);
         
            // 接下来同样判断右边
            if (!cur_node->right) {
                cur_node->right = node;
                return;
            } else
                que.push(cur_node->right);
        }
    }
    
  2. 第二种是比较直观的深搜递归写法, 通过读取用户输入的方式给出, 递归的终止条件是取空节点全为0, 这里算是一个弊端, 因为不能捕获任意数据.

    // 递归版本生成二叉树(需要捕获用户输入)
    void BinaryTree::add_recur1() {
        function<TreeNode*(void)> f = [&](void) {
            int data;
            cin >> data;
            if (data == 0) { return (TreeNode*)nullptr; }
            TreeNode* node = new TreeNode(data);
            node->val = data;
            node->left = f();
            node->right = f();
            return node;
        };
        root = f();
    }
    

    代码方面, 主要注意的是递归跳出的返回值, 需要加上类型转换, 这里其实应该使用C++ style的static_cast, 为方便就直接用C强制类型转换来做了.

  3. 第三种跟上面的很相似, 但是不需要捕获用户的输入, 将待读取的节点信息存为数组, 同样, 空节点用0表示, 省去了输入的麻烦.
    // 递归生成二叉树: 直接遍历外部数组实现
    void BinaryTree::add_recur2(vector<int>& item) {
        function<TreeNode*(void)> f = [&](void) {
            if (item.empty() || item.front() == 0) {
                item.erase(item.begin());
                return (TreeNode*)nullptr;
            }
            TreeNode* node = new TreeNode(item.front());
            node->val = item.front();
            item.erase(item.begin());
            node->left = f();
            node->right = f();
            return node;
        };
        root = f();
    }
    

    代码方面, 这里直接对输入的数组进行删减操作了, 其实应该设置常量然后重新拷贝为内部数组, 但是二叉树类都在内部实现, 可以忽略一些数据修改问题.

树的遍历

这里我就着重说递归的lambda写法以及深搜遍历时候的迭代写法了.

递归lambda

上面树的生成部分我采用function<>类模板, 这样做的好处是能在lambda内调用自身, 但是缺点是要写两遍参数列表的类型, 下面的auto很好地解决了这个问题, 不过这种写法相当于把函数自身的地址传入了函数, 所以写的时候需要加上一个参数.

void BinaryTree::pre_order() {
    vector<int> ret;
    auto recur_0 = [&ret](auto&& self, TreeNode* node) {
        // 递归终止条件
        if (!node) return;
        ret.push_back(node->val);
        self(self, node->left);
        self(self, node->right);
    };
    recur_0(recur_0, root);
    cout << ret << endl;
}

对于中序和后序递归格式同样.

前序遍历的迭代写法

这个算是比较简单的一种了, 需要理解前序遍历的栈模拟, 下面举一个例子:

	  1
    /   \
   2     3
  / \   / \
 4   5 6   7
前序遍历(中左右): 1->2->4->5->3->6->7

先读取到的是根结点, 然后是左子树的根节点, 直到没有根节点, 然后去找左节点, 右节点……

从递归的情况可以得到, 需要先将遍历到的根节点入栈, 只要栈不空, 就一直重复此过程. 需要注意入栈的顺序, 将根节点的右节点,左节点依次入栈, 才能在弹栈的时候先读取到左节点, 栈的先入后出特性. 下面给出代码:

void BinaryTree::pre_order1() {
    vector<int> ret;
    stack<TreeNode*> st;
    if (root) st.push(root);
    while (!st.empty()) {
        auto node = st.top();
        st.pop();
        ret.push_back(node->val);
        if (node->right) st.push(node->right);
        if (node->left) st.push(node->left);
    }
    cout << ret << endl;
}

单步执行时栈的情况为:

栈的变化 -> 结果数组
1      -> [1]
3      -> [1]
3 2    -> [1]
3      -> [1,2]
3 5    -> [1,2]
3 5 4  -> [1,2]
3 5    -> [1,2,4]
3      -> [1,2,4,5]
       -> [1,2,4,5,3]
7      -> [1,2,4,5,3]
7 6    -> [1,2,4,5,3]
7      -> [1,2,4,5,3,6]
       -> [1,2,4,5,3,6,7]

后序遍历的迭代写法

这里有一个小技巧, 可以直接从前序遍历得到的数组逆序得到后序遍历的结果. 因为后序遍历是左右中, 恰好与前序的相反(这里的相反大家可以自行验证, 由于前序的中节点在最前面,也即最先出栈,而后序遍历的中节点在最后, 这才导致了二者的遍历结果顺序相反)

void BinaryTree::post_order1() {
    stack<TreeNode*> st;
    vector<int> ret;
    if (root) { st.push(root); }
    while (!st.empty()) {
        auto node = st.top();
        st.pop();
        ret.push_back(node->val);
        if (node->left) st.push(node->left);
        if (node->right) st.push(node->right);
    }
    // 进行逆序操作: 
    // vector<int> ret1;
    // for (auto it = ret.rbegin(); it != ret.rend(); it++)
    // ret1.push_back(*it);
    // 或者采用逆序算法
    reverse(ret.begin(), ret.end());
    cout << ret << endl;
}

下面主要说一下后序遍历迭代形式的一般写法:

void BinaryTree::post_order2() {
    stack<TreeNode*> st;
    vector<int> ret;
    auto node = root;
    while (!st.empty() || node) {
        while (node) {
            st.push(node);
            // 遍历直到结点不含有左右结点(到达叶节点)
            node = node->left ? node->left : node->right;
        }
        node = st.top();
        st.pop();
        // 此时从叶节点开始弹栈存结果, 将左右结点都存了之后最后只剩下中节点
        ret.push_back(node->val);
        // 更新当前节点,如果是左节点则更新为右节点,否则该子树已遍历完,置空
        node =
            !st.empty() && st.top()->left == node ? st.top()->right : nullptr;
    }
    cout << ret << endl;
}

这里面的基本思想就是先到达叶节点, 然后弹栈存值, 用node指向当前节点, 将子树两叶结点都存入之后再开始放中节点.

中序遍历的迭代写法

这里先给出一种代码比较简洁的思路(不过初学者不太容易想到), 递归的本质就是采用函数的栈帧存放数据, 这里可以模拟一个栈, 然后根据中序遍历的顺序入栈, 即先遍历所有的左子结点, 然后在弹栈过程中存中节点和右节点.

void BinaryTree::in_order1() {
    vector<int> ret;
    stack<TreeNode*> st;
    auto cur = root;
    while (!st.empty() || cur) {
        if (cur) {
            st.push(cur);
            // 每次都先遍历到最左边的叶节点
            cur = cur->left;
        } else {
            cur = st.top();
            st.pop();
            // 然后先存最左边的叶节点(如果有)
            ret.push_back(cur->val);
            cur = cur->right;
        }
    }
    cout << ret << endl;
}

这里还有一种通用的解法:

比较符合中序遍历的逆序, 因为栈先入后出, 中序遍历为左中右, 所以先存右节点,然后是中节点和左节点

第11行的空节点是必要的, 因为入栈的顺序是右中左, 这里有可能会出现没有右节点或者左节点的情况, 这时候需要一个标志节点表示前一个入栈的结点为中节点. 当遍历到叶节点时候, 依然需要这个空节点占位, 否则最后判断到空节点时候不会进行结果数组的保存了.

当栈顶为空的时候, 就到了叶节点, 因为入栈顺序是先右后左的, 所以这里出栈就满足了中序的要求.

void BinaryTree::in_order2() {
    vector<int> ret;
    stack<TreeNode*> st;
    if (root) { st.push(root); }
    while (!st.empty()) {
        auto node = st.top();
        st.pop();
        if (node) {
            if (node->right) { st.push(node->right); }
            st.push(node);
            st.push(nullptr);
            if (node->left) { st.push(node->left); }
        } else {
            node = st.top();
            st.pop();
            ret.push_back(node->val);
        }
    }
    cout << ret << endl;
}

最后一种是我在力扣上看到的方法1, 非常直观, 但是时间消耗较长(因为需要两次遍历)

void BinaryTree::in_order3() {
    vector<int> ret{};
    stack<pair<int, TreeNode*>> st;
    if (root) st.push(make_pair(0, root));
    int color;
    TreeNode* node;
    while (!st.empty()) {
        auto [color, node] = st.top();
        st.pop();
        if (node == nullptr) continue;
        if (color == 0) {
            st.push(make_pair(0, node->right));
            st.push(make_pair(1, node));
            st.push(make_pair(0, node->left));
        } else {
            ret.push_back(node->val);
        }
    }
    cout << ret << endl;
}

通过标记来判断是否遍历到, 然后根据遍历逆序入栈即可. 这个方法的一点好处就是容易理解, 并且前序和后序都可以通过这个方法写出来.

ref