模拟栈的几个方法

任何递归操作都是通过操作系统提供的函数栈实现的,所以我们如果自己模拟一个栈,同样可以将递归操作转换为栈上的非递归操作。

计算机函数调用栈的原理

SS:栈段寄存器(stack-segment register),用于指示栈所在的内存块

SP:栈顶寄存器(stack pointer register),指向栈顶位置。

注意:X86 的栈顶是向下增长的,如果向栈中 PUSH 数据,则 SP 指向的内存地址是更低位置

以下都是以 X86 指令集为例。

请先阅读此文:当我们调用一个函数的时候,发生了什么?

以中序遍历为例

递归的代码如下:

private:
    void _process(TreeNode *root, vector<int> &ret)
    {
        if (root == nullptr)
        {
            return;
        }
        _process(root->left, ret);
        ret.push_back(root->val);        
        _process(root->right, ret);
    }

public:
    vector<int> inorderTraversal(TreeNode *root)
    {
        vector<int> ret;
        _process(root, ret);
        return ret;
    }

要改成非递归模式,我们需要模拟出一个调用栈。由于调用栈是随时可以访问的,所以应该将其设置为共有的,所以我们将它作为一个参数传入。

-    void _process(TreeNode *root, vector<int> &ret)
+    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
        if (root == nullptr)
        {
            return;
        }
        _process(root->left, ret);
        ret.push_back(root->val);        
        _process(root->right, ret);
    }

此外由于是改成非递归调用,需要将所有调用本函数的形式写成一样的。下面就是全部写成了:

_process(root, ret, callstack);

而为了保证形式的一致,我们加上了诸如 root = root->left; 的语句:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
        if (root == nullptr)
        {
            return;
        }
-       _process(root->left, ret);
+		root = root->left;
+       _process(root, ret, callstack);
        ret.push_back(root->val);        
-       _process(root->right, ret);
+		root = root->right;
+       _process(root, ret, callstack);
    }

但是这样又会导致原本的 root 丢失。怎么办?很简单,保存上下文!将其压入栈。所以开头加上一句

callstack.push(root);

而执行完还要恢复上下问,所以还得加上对应的 pop 语句:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
        if (root == nullptr)
        {
            return;
        }
+       callstack.push(root);
        root = root->left;
        _process(root, ret, callstack);
+       root = callstack.top();
+       callstack.pop();
        ret.push_back(root->val);
        root = root->right;
        _process(root, ret, callstack);
	}

pop 自然要考虑为空的情况。如果调用栈为空就可以退出程序了:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
        if (root == nullptr)
        {
            return;
        }
        callstack.push(root);
        root = root->left;
        _process(root, ret, callstack);
+       if(callstack.empty()){
+        return;
+       }
        root = callstack.top();
        callstack.pop();
        ret.push_back(root->val);
        root = root->right;
        _process(root, ret, callstack);
	}

另外,原本的 return 条件应该改成恢复执行:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
        if (root == nullptr)
        {
+            goto resume;
        }
        callstack.push(root);
        root = root->left;
        _process(root, ret, callstack);
+ resume:
        if(callstack.empty()){ 
          return;
        }
        root = callstack.top();
        callstack.pop();
        ret.push_back(root->val);
        root = root->right;
        _process(root, ret, callstack);
	}

两个 _process 调用也可以直接跳转到函数头:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
+ proc:
        if (root == nullptr)
        {
            goto resume;
        }
        callstack.push(root);
        root = root->left;
+       goto proc;
resume:
        if(callstack.empty()){ 
          return;
        }
        root = callstack.top();
        callstack.pop();
        ret.push_back(root->val);
        root = root->right;
+       goto proc;
	}

至此,改造基本完毕。运行程序也能得到正确的答案。只不过代码的 goto 语句太多,造成逻辑上的混乱。我们一个一个消掉:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
    proc:
        if (root != nullptr)
        {
            callstack.push(root);
            root = root->left;
            goto proc;
        }
        if (callstack.empty())
        {
            return;
        }
        root = callstack.top();
        callstack.pop();
        ret.push_back(root->val);
        root = root->right;
        goto proc;
    }

再整理:

    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
    proc:
        if (root != nullptr)
        {
            callstack.push(root);
            root = root->left;
            goto proc;
        }
        if (!callstack.empty())
        {

            root = callstack.top();
            callstack.pop();
            ret.push_back(root->val);
            root = root->right;
            goto proc;
        }
    }
    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
    proc:
        if (root != nullptr)
        {
            callstack.push(root);
            root = root->left;
            goto proc;
        }
        else if (!callstack.empty())
        {

            root = callstack.top();
            callstack.pop();
            ret.push_back(root->val);
            root = root->right;
            goto proc;
        }
        else
        {
            return;
        }
    }
    void _process(TreeNode *root, vector<int> &ret, stack<TreeNode *> &callstack)
    {
    proc:
        if (root != nullptr)
        {
            callstack.push(root);
            root = root->left;
        }
        else if (!callstack.empty())
        {
            root = callstack.top();
            callstack.pop();
            ret.push_back(root->val);
            root = root->right;
        }
        else
        {
            return;
        }
        goto proc;
    }

最终答案:

    vector<int> inorderTraversal(TreeNode *root)
    {
        vector<int> ret;
        stack<TreeNode *> stack;
        while (!(root == nullptr && stack.empty()))
        {
            if (root != nullptr)
            {
                stack.push(root);
                root = root->left;
            }
            else if (!stack.empty())
            {
                root = stack.top();
                stack.pop();
                ret.push_back(root->val);
                root = root->right;
            }
        }
        return ret;
    }

以等差数列求和为例

上面的比较简单,因为不涉及返回值。试试这个:

剑指 Offer 64. 求 1+2+…+n - 力扣(LeetCode) (leetcode-cn.com)

#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

class Solution
{
public:
    int sumNums(int n)
    {
        if(n == 0) return 0;
        return n + sumNums(n - 1);
    }
};

int main(int argc, char const *argv[])
{
    Solution s;
    printf("%d\n", s.sumNums(100));
    return 0;
}

实际上这种类型是尾递归。所有尾递归都可以转换为这样的形式:

var stack = [];
stack.push(first);

while (!stack.empty()) {
    o = stack.pop();
    // ...
}

因此高斯求和问题可以转换为:

#include <cstdio>
#include <stack>

using namespace std;

stack<int> s;

void push(stack<int> &s, int v)
{
    s.push(v);
}

int pop(stack<int> &s)
{
    int v = s.top();
    s.pop();
    return v;
}

int sum(int n)
{
    push(s, n);
    int eax, o;
    while (!s.empty())
    {
        o = pop(s);
        eax = eax + o;
        if (o == 0)
            break;
        else
            push(s, o - 1);
    }
    return eax;
}

int main()
{
    int a = sum(100);
    printf("%d\n", a);
    return 0;
}

以斐波那契数列为例

    int fib(int n)
    {
        if(n == 2 || n == 1) return 1;
        if(n <= 0) return 0;        
        return fib(n - 2) + fib(n - 1);
    }

将其转化为栈将变成这样:

#include <cstdio>
#include <stack>

using namespace std;

stack<int> s;

void push(stack<int> &s, int v)
{
    //printf("push %d\n", v);
    s.push(v);
}

int pop(stack<int> &s)
{
    int v = s.top();
    //printf("pop %d\n", v);
    s.pop();
    return v;
}
int fib0(int n)
{
    if (n == 2 || n == 1)
        return 1;
    if (n <= 0)
        return 0;
    return fib0(n - 2) + fib0(n - 1);
}

int fib(int n)
{
    int eax = 0, o;
    push(s, n);
    while (!s.empty())
    {
        o = pop(s);
        if (o == 2 || o == 1)
        {
            eax += 1;
        }
        else if (o <= 0)
        {
            eax += 0;
        }
        else
        {
            push(s, o - 1);
            push(s, o - 2);
        }
    }
    return eax;
}

int main()
{
    int n = 15;
    int a = fib(n);
    int b = fib0(n);
    printf("%d, expect %d\n", a, b);
    return 0;
}

不过实际上你会发现,自己用栈实现反而性能会大幅度下降。不要迷信 “非递归”