回溯算法及其优化

什么是回溯

回溯可以看作在树状的解空间中搜索,一旦发现不可行,就回退到历史状态,然后选择另一个分叉搜索。由于要遍历尝试各个分支,所以也是一种暴力算法。

回溯的标准流程

一般过程:

  • 存档

  • 执行动作

  • 回档

其实就和操作系统的调度、中断啥的一样,保存现场恢复现场。

有的情况下,比如递归实现的后序遍历,可以利用调用栈存档,这样就不用我们显式地存档。

回溯的优化

回溯没有动态规划那么容易进行空间优化,因为 DP 往往能够利用历史状态进行缓存。回溯常用的优化有:

  • 交换代替插入删除

  • 利用数学性质,无重复优化

  • 通过增加条件限制,剪枝优化

全排列

【例子】46. 全排列

对于集合 {1,2,3} 全排列,相当于将 1,2,3 填到三个格子里。

那么对于第一个格子,我们选择填入 ?

  • 1

  • 2

  • 3

这里就产生了三个分支。剩下的两个格子,就是剩余元素的全排列,比如第一个格子填入 3,那么剩下的格子就是 {1,2} 的全排列。

  • 历史状态:已经填入的格子。

  • 候选集合:剩下能填的数字。

  • 对于每个候选元素,都将会产生一个新的候选分支。

参考代码如下:

 1class Solution {
 2  void rec(vector<int> &available, vector<int> &history,
 3           vector<vector<int>> &ret) {
 4    if (available.size() == 0) {
 5      ret.push_back(history);
 6    }
 7    // == 尝试各个可用分支 ==
 8    for (int i = 0; i < available.size(); i++) {      
 9      // 存档
10      int num = available[i];
11      history.push_back(num);
12      // 尝试
13      available.erase(available.begin() + i);
14      rec(available, history, ret);
15      // 回档
16      available.insert(available.begin() + i, num);
17      history.pop_back();
18    }
19  }
20
21 public:
22  vector<vector<int>> permute(vector<int> &nums) {
23    vector<vector<int>> ret;
24    vector<int> history;
25    rec(nums, history, ret);
26    return ret;
27  }
28};

你可以注意到,核心代码就这么几句:

1      history.push_back(num);
2      // 尝试
3      available.erase(available.begin() + i);
4      rec(available, history, ret);
5      // 回档
6      available.insert(available.begin() + i, num);
7      history.pop_back();

很有对称美。总之,原来啥样,还原之后就是啥样。

增删性能优化

在时间复杂度上已经很难优化了,因为必须列出所有结果。但是从计算机设计的角度,我们可以减少数组的动态操作。实际上,可以用 swap 代替频繁的数组增删。下面是 LC 提供的题解:

 1class Solution {
 2public:
 3    void backtrack(vector<vector<int>>& res, vector<int>& output, int first, int len){
 4        // 所有数都填完了
 5        if (first == len) {
 6            res.emplace_back(output);
 7            return;
 8        }
 9        for (int i = first; i < len; ++i) {
10            // 动态维护数组
11            swap(output[i], output[first]);
12            // 继续递归填下一个数
13            backtrack(res, output, first + 1, len);
14            // 撤销操作
15            swap(output[i], output[first]);
16        }
17    }
18    vector<vector<int>> permute(vector<int>& nums) {
19        vector<vector<int> > res;
20        backtrack(res, nums, 0, (int)nums.size());
21        return res;
22    }
23};

非递归实现

可以通过反复调用 Next Permutation 实现。(参考资料(2))

子集

子集

这题的朴素思路就是:

  1. 取出集合中的每个元素,求剩余元素的子集,各个推入结果列表

  2. 再把整体作为一个子集推入结果列表

代码:

 1class Solution {
 2 public:
 3  void rec(vector<int> &nums, vector<int>& history, vector<vector<int>>& ret) {
 4    if (nums.size() == 0) {
 5      return;
 6    }
 7    // 注意推入的时机
 8    ret.push_back(history);
 9    // 下面都是常规的保存现场恢复现场
10    for (int i = 0; i < nums.size(); i++) {
11      int tmp = nums[i];            
12      history.push_back(tmp);
13      nums.erase(nums.begin() + i);
14      rec(nums, history, ret);
15      nums.insert(nums.begin() + i, tmp);
16      history.pop_back();
17    }
18  }
19  vector<vector<int>> subsets(vector<int>& nums) {
20    vector<vector<int>> ret;
21    vector<int> history;
22    rec(nums, history, ret);
23    ret.push_back(nums);
24    return ret;
25  }
26};

输出:

 1{
 2  {}
 3  {1}
 4  {1, 2}
 5  {1, 3}
 6  {2}
 7  {2, 1}
 8  {2, 3}
 9  {3}
10  {3, 1}
11  {3, 2}
12  {1, 2, 3}
13}

可以看到出现了重复。最简单的解决方法就是改用 set 结构。但是这样性能不佳。

无重复优化

发生重复的根源是什么?观察输出:

 1{
 2  {}
 3  {1}
 4  {1, 2}
 5  {1, 3}
 6  {2}
 7  {2, 1} // 重复
 8  {2, 3}
 9  {3}
10  {3, 1} // 重复
11  {3, 2} // 重复
12  {1, 2, 3}
13}

注意到重复的原因在于回溯追加的元素小于首次选择的元素。比如 {2, 1} 中 $1 < 2$. 所以我们可以优化遍历时的起点:

 1class Solution {
 2 public:
 3  void rec(vector<int> &nums, vector<int>& history, vector<vector<int>>& ret, int startIndex = 0) {
 4    if (nums.size() == 0) {
 5      return;
 6    }
 7    ret.push_back(history);
 8
 9    for (int i = startIndex; i < nums.size(); i++) { // 注意这里
10      int tmp = nums[i];
11      nums.erase(nums.begin() + i);
12      history.push_back(tmp);
13      rec(nums, history, ret, i);
14      history.pop_back();
15      nums.insert(nums.begin() + i, tmp);
16    }
17  }
18  vector<vector<int>> subsets(vector<int>& nums) {
19    vector<vector<int>> ret;
20    vector<int> history;
21    rec(nums, history, ret);
22    ret.push_back(nums);
23    return ret;
24  }
25};

通过引入 startIndex,直接跳过了重复项。

对比

对比全排列和子集的回溯穷举算法,可以发现在回溯途中推送解,就是子集的算法,在回溯的末端推送解,就是全排列的算法。

组合

组合

我们从 1,2,3,4 中选 3 个,则相当于:

  • 1 + 从2,3,4 中选 2

  • 2 + 从1,3,4 中选 2

  • ……

代码如下(已经进行了无重复优化):

 1class Solution {
 2 public:
 3  void rec(vector<int>& nums, int k, vector<int>& history,
 4           vector<vector<int>>& ret, int startIndex = 0) {
 5    if (k == 0) {
 6      ret.push_back(history);
 7      return;
 8    }
 9    for (int i = startIndex; i < nums.size(); i++) {
10      int tmp = nums[i];
11      nums.erase(nums.begin() + i);
12      history.push_back(tmp);
13      rec(nums, k - 1, history, ret, i);
14      history.pop_back();
15      nums.insert(nums.begin() + i, tmp);
16    }
17  }
18  vector<vector<int>> combine(int n, int k) {
19    std::vector<int> nums(n);
20    for (int i = 0; i < n; i++) {
21      nums[i] = i + 1;
22    }
23    vector<vector<int>> ret;
24    vector<int> history;
25    rec(nums, k, history, ret);
26    return ret;
27  }
28};

性能很烂,怎么回事呢?

执行用时:76 ms, 在所有 C++ 提交中击败了7.87%的用户

内存消耗:8.8 MB, 在所有 C++ 提交中击败了96.70%的用户

剪枝优化

如果 n = 7, k = 4,从 5 开始搜索就已经没有意义了,这是因为:即使把 5 选上,后面的数只有 6 和 7,一共就 3 个候选数,凑不出 4 个数的组合。(参考

根据上面这句话,假设终止条件是 $x$,则有 $n - x + 1 < k$,即 $x = n - k + 1$。

也就是说,如果 v[startIndex] > n - k + 1,则可以直接 return.

v[startIndex] = 1 + startIndex(因为题给条件说组合所用数为 $1\cdots n$)

所以 startIndex > n - k 可以直接退出。我们要限定 i <= n - k

 1class Solution {
 2 public:
 3  void rec(vector<int>& nums, int k, vector<int>& history,
 4           vector<vector<int>>& ret, int n, int startIndex = 0) {
 5    if (k == 0) {
 6      ret.push_back(history);
 7      return;
 8    }
 9    for (int i = startIndex; i < nums.size() && i <= n - k; i++) {
10      int tmp = nums[i];
11      nums.erase(nums.begin() + i);
12      history.push_back(tmp);
13      rec(nums, k - 1, history, ret, n, i);
14      history.pop_back();
15      nums.insert(nums.begin() + i, tmp);
16    }
17  }
18  vector<vector<int>> combine(int n, int k) {
19    std::vector<int> nums(n);
20    for (int i = 0; i < n; i++) {
21      nums[i] = i + 1;
22    }
23    vector<vector<int>> ret;
24    vector<int> history;
25    rec(nums, k, history, ret, n, 0);
26    return ret;
27  }
28};

执行用时:60 ms, 在所有 C++ 提交中击败了8.58%的用户

内存消耗:8.7 MB, 在所有 C++ 提交中击败了97.13%的用户

空间优化

问题在哪儿?其实我们完全没必要维护一个 nums 数组,因为 nums 可以通过 i + 1 得出:

 1class Solution {
 2 public:
 3  void rec(int k, vector<int>& history,
 4           vector<vector<int>>& ret, int n, int startIndex = 0) {
 5    if (k == 0) {
 6      ret.push_back(history);
 7      return;
 8    }
 9    for (int i = startIndex; i <= n - k; i++) {
10      history.push_back(i + 1);
11      rec(k - 1, history, ret, n, i + 1);
12      history.pop_back();
13    }
14  }
15  vector<vector<int>> combine(int n, int k) {
16    vector<vector<int>> ret;
17    vector<int> history;
18    rec(k, history, ret, n, 0);
19    return ret;
20  }
21};

执行用时:4 ms, 在所有 C++ 提交中击败了99.14%的用户

内存消耗:8.9 MB, 在所有 C++ 提交中击败了89.49%的用户

这样,节省了空间,也减少了操作步骤,使得计算速度提高了。

组合总和

组合总和

给定一个无重复元素的正整数数组 candidates 和一个正整数 target ,找出 candidates 中所有可以使数字和为目标数 target 的唯一组合。

candidates 中的数字可以无限制重复被选取。如果至少一个所选数字数量不同,则两种组合是唯一的。

对于给定的输入,保证和为 target 的唯一组合数少于 150 个。

示例 1:

输入: candidates = [2,3,6,7], target = 7
输出: [[7],[2,2,3]]

思路分析

分析:

如果采用排列穷举验证,相当于不剪枝,那么难点在于怎么处理重复元素。

不妨换个思路:可以利用 target - candidates[i] 缩小问题规模:

  • [2,3,6,7], target = 7

    • -2, target = 5

      • -2, target = 3

      • -3, target = 2

      • -6, target = -1

      • -7 target = -2

    • -3, target = 4

    • -6, target = 1

    • -7 target = 0

    一旦 target = 0 就将搜索路径推送到答案列表。

    一旦 target < 0 就停止搜索。

    而可选数一直是 [2, 3, 6, 7]

代码及无重复优化

参照这个例子写出代码,并进行无重复优化:

 1class Solution {
 2 private:
 3  void backtrace(vector<int>& cand, vector<int>& path, vector<vector<int>>& ret,
 4                 int target, int startIndex = 0) {
 5    if (target == 0) {
 6      ret.push_back(path);
 7      return;
 8    }
 9    if (target < 0) {
10      return;
11    }
12
13    for (int i = startIndex; i < cand.size(); i++) {
14      path.push_back(cand[i]);
15      backtrace(cand, path, ret, target - cand[i], i);
16      path.pop_back();
17    }
18  }
19
20 public:
21  vector<vector<int>> combinationSum(vector<int>& cand, int target) {
22    vector<vector<int>> ret;
23    vector<int> path;
24    backtrace(cand, path, ret, target);
25    return ret;
26  }
27};

执行用时:0 ms, 在所有 C++ 提交中击败了100.00%的用户

内存消耗:10.3 MB, 在所有 C++ 提交中击败了98.78%的用户

看起来不错。

N 皇后

N 皇后

N 皇后问题将回溯代入了二维世界(二次元)。但思路依旧是相同的 。

我们可以尝试各个初始位置,并锁定不能防止的单元:

image-20211124233215165

基本思路

斜向判断

基本代码

解法:

  1#include <debug.h>
  2
  3class Solution {
  4 private:
  5  void printState(int n, map<int, bool> &history) {
  6    cout << "state:" << endl;
  7    for (int i = 0; i < n; i++) {
  8      for (int j = 0; j < n; j++) {
  9        cout << history[i * n + j] << " ";
 10      }
 11      cout << endl;
 12    }
 13  }
 14  vector<string> historyToStrings(int n, map<int, bool> &history) {
 15    vector<string> ret;
 16    for (int i = 0; i < n; i++) {
 17      string s(n, '.');
 18      for (int j = 0; j < n; j++) {
 19        if (history[i * n + j]) {
 20          s[j] = 'Q';
 21        }
 22      }
 23      ret.push_back(s);
 24    }
 25    return ret;
 26  }
 27  // 检查棋盘 i,j 位置是否允许落子
 28  bool available(int n, int i, int j, map<int, bool> &history) {
 29    // printState(n, history);
 30    if (history[i * n + j]) {
 31      return false;
 32    }
 33    // row,col 为当前检测的起点
 34    for (int row = 0; row < n; row++) {
 35      for (int col = 0; col < n; col++) {
 36        // 一旦 row, col 处落子,则同行同列禁止落子
 37        if (history[row * n + col]) {
 38          if (row == i || col == j) {
 39            return false;
 40          }
 41          // == 斜向检测,利用和/差为定值 ==
 42          // p,q 为以 row,col 为起点的斜向元素的坐标
 43
 44          // 左斜向检测
 45          auto coordSum = row + col;
 46          // p 是临时 row
 47          // q 是临时 col
 48          // col = coordSum - row >= 0
 49          int p = 0, q = coordSum - p;
 50          while (q >= 0) {
 51            if (p == i && q == j) {
 52              return false;
 53            }
 54            p++;
 55            q = (coordSum - p);
 56          }
 57
 58          // 右斜向检测
 59          //               p     q
 60          auto coordDiff = row - col;
 61          // col = row  -  coordDiff >= 0
 62          p = 0, q = p - coordDiff;
 63          while (p < n && q < n) {
 64            if (p == i && q == j) {
 65              return false;
 66            }
 67            p++;
 68            q = p - coordDiff;
 69          }
 70        }
 71      }
 72    }
 73    return true;
 74  }
 75  int placedCount(map<int, bool> &history) {
 76    auto itr = history.begin();
 77    int counter = 0;
 78    while (itr != history.end()) {
 79      if ((*itr).second) {
 80        counter++;
 81      }
 82      itr++;
 83    }
 84    return counter;
 85  }
 86  void backtrace(int n, map<int, bool> &history, vector<vector<string>> &ret) {
 87    if (placedCount(history) == n) {
 88      auto state = historyToStrings(n, history);
 89      // 重复则不添加
 90      for (size_t i = 0; i < ret.size(); i++) {
 91        if (state == ret[i]) {
 92          return;
 93        }
 94      }
 95      ret.push_back(state);
 96
 97      return;
 98    }
 99    bool anyAvaliable = false;
100    for (int i = 0; i < n; i++) {
101      string s(n, '.');
102      for (int j = 0; j < n; j++) {
103        if (available(n, i, j, history)) {
104          // printf("i,j = %d,%d placed \n", i, j);
105          anyAvaliable = true;
106          history[i * n + j] = true;
107          backtrace(n, history, ret);
108          history[i * n + j] = false;
109        }
110      }
111    }
112    if (!anyAvaliable) {
113      return;
114    }
115  }
116
117 public:
118  vector<vector<string>> solveNQueens(int n) {
119    // key: idx n*i+j, value: availability
120    map<int, bool> history;
121    for (int i = 0; i < n * n; i++) {
122      history[i] = false;
123    }
124
125    vector<vector<string>> ret;
126    backtrace(n, history, ret);
127    return ret;
128  }
129};
130int main(int argc, char const *argv[]) {
131  Solution s;
132  auto ret = s.solveNQueens(5);
133  print_vec_2d(ret, 0, true);
134  return 0;
135}

输出:

1.Q.., 
2...Q, 
3Q..., 
4..Q.
5
6..Q., 
7Q..., 
8...Q, 
9.Q..

算法是对的,但是超时

无重复优化

  1
  2class Solution {
  3 private:
  4  void printState(int n, map<int, bool> &history) {
  5    cout << "state:" << endl;
  6    for (int i = 0; i < n; i++) {
  7      for (int j = 0; j < n; j++) {
  8        cout << history[i * n + j] << " ";
  9      }
 10      cout << endl;
 11    }
 12  }
 13  vector<string> historyToStrings(int n, map<int, bool> &history) {
 14    vector<string> ret;
 15    for (int i = 0; i < n; i++) {
 16      string s(n, '.');
 17      for (int j = 0; j < n; j++) {
 18        if (history[i * n + j]) {
 19          s[j] = 'Q';
 20        }
 21      }
 22      ret.push_back(s);
 23    }
 24    return ret;
 25  }
 26  // 检查棋盘 i,j 位置是否允许落子
 27  bool available(int n, int i, int j, map<int, bool> &history) {
 28    // printState(n, history);
 29    if (history[i * n + j]) {
 30      return false;
 31    }
 32    // row,col 为当前检测的起点
 33    for (int row = 0; row < n; row++) {
 34      for (int col = 0; col < n; col++) {
 35        // 一旦 row, col 处落子,则同行同列禁止落子
 36        if (history[row * n + col]) {
 37          if (row == i || col == j) {
 38            return false;
 39          }
 40          // == 斜向检测,利用和/差为定值 ==
 41          // p,q 为以 row,col 为起点的斜向元素的坐标
 42
 43          // 左斜向检测
 44          auto coordSum = row + col;
 45          // p 是临时 row
 46          // q 是临时 col
 47          // col = coordSum - row >= 0
 48          int p = 0, q = coordSum - p;
 49          while (q >= 0) {
 50            if (p == i && q == j) {
 51              return false;
 52            }
 53            p++;
 54            q = (coordSum - p);
 55          }
 56
 57          // 右斜向检测
 58          //               p     q
 59          auto coordDiff = row - col;
 60          // col = row  -  coordDiff >= 0
 61          p = 0, q = p - coordDiff;
 62          while (p < n && q < n) {
 63            if (p == i && q == j) {
 64              return false;
 65            }
 66            p++;
 67            q = p - coordDiff;
 68          }
 69        }
 70      }
 71    }
 72    return true;
 73  }
 74  int placedCount(map<int, bool> &history) {
 75    auto itr = history.begin();
 76    int counter = 0;
 77    while (itr != history.end()) {
 78      if ((*itr).second) {
 79        counter++;
 80      }
 81      itr++;
 82    }
 83    return counter;
 84  }
 85  void backtrace(int n, map<int, bool> &history, vector<vector<string>> &ret, int iStart = 0) {
 86    if (placedCount(history) == n) {
 87      auto state = historyToStrings(n, history);
 88      ret.push_back(state);
 89      return;
 90    }
 91    bool anyAvaliable = false;
 92    for (int i = iStart; i < n; i++) {
 93      string s(n, '.');
 94      for (int j = 0; j < n; j++) {
 95        if (available(n, i, j, history)) {
 96          // printf("i,j = %d,%d placed \n", i, j);
 97          anyAvaliable = true;
 98          history[i * n + j] = true;
 99          backtrace(n, history, ret, i + 1);
100          history[i * n + j] = false;
101        }
102      }
103    }
104    if (!anyAvaliable) {
105      return;
106    }
107  }
108
109 public:
110  vector<vector<string>> solveNQueens(int n) {
111    // key: idx n*i+j, value: availability
112    map<int, bool> history;
113    for (int i = 0; i < n * n; i++) {
114      history[i] = false;
115    }
116
117    vector<vector<string>> ret;
118    backtrace(n, history, ret);
119    return ret;
120  }
121};

搜索优化

上面的代码依然超时。原因在于我们判断可行区域时的效率太低。优化的方法是采用专门的结构,记录斜向是否可行。

 1class Solution {
 2 private:
 3  vector<string> historyToStrings(int n, map<int, bool> &history) {
 4    vector<string> ret;
 5    for (int i = 0; i < n; i++) {
 6      string s(n, '.');
 7      for (int j = 0; j < n; j++) {
 8        if (history[i * n + j]) {
 9          s[j] = 'Q';
10        }
11      }
12      ret.push_back(s);
13    }
14    return ret;
15  }
16  void backtrace(int n, map<int, bool> &history, vector<bool> &curRow,
17                 vector<bool> &diag1, vector<bool> &diag2,
18                 vector<vector<string>> &ret, int iStart = 0) {
19    if (placedCount(history) == n) {
20      auto state = historyToStrings(n, history);
21      ret.push_back(state);
22      return;
23    }
24    bool anyAvaliable = false;
25    for (int i = iStart; i < n; i++) {
26      for (int j = 0; j < n; j++) {
27        if (curRow[j] || diag1[j + i] || diag2[j + n - 1 - i]) {
28          continue;
29        }
30        // printf("i,j = %d,%d placed \n", i, j);
31        anyAvaliable = true;
32
33        history[i * n + j] = true;
34        curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = true;
35        backtrace(n, history, curRow, diag1, diag2, ret, i + 1);
36        history[i * n + j] = false;
37        curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = false;
38      }
39    }
40    if (!anyAvaliable) {
41      return;
42    }
43  }
44
45 public:
46  vector<vector<string>> solveNQueens(int n) {
47    // key: idx n*i+j, value: availability
48    map<int, bool> history;
49    for (int i = 0; i < n * n; i++) {
50      history[i] = false;
51    }
52    vector<bool> curRow(n);
53    vector<bool> diag1(2 * n - 1);
54    vector<bool> diag2(2 * n - 1);
55    vector<vector<string>> ret;
56    backtrace(n, history, curRow, diag1, diag2, ret);
57    return ret;
58  }
59};

执行用时:636 ms, 在所有 C++ 提交中击败了5.15%的用户

内存消耗:7.8 MB, 在所有 C++ 提交中击败了32.43%的用户

无效解优化

我们的代码还有优化空间,如果棋盘第一行(或者列)没有放置过,它还会尝试第二行。但既然已经有一行(或者列)没有放置过,那么必然无法放满 N 个。可以通过一个标识来跳过这种情况:

 1
 2class Solution {
 3 private:
 4  vector<string> historyToStrings(int n, map<int, bool> &history) {
 5    vector<string> ret;
 6    for (int i = 0; i < n; i++) {
 7      string s(n, '.');
 8      for (int j = 0; j < n; j++) {
 9        if (history[i * n + j]) {
10          s[j] = 'Q';
11        }
12      }
13      ret.push_back(s);
14    }
15    return ret;
16  }
17  void backtrace(int n, map<int, bool> &history, vector<bool> &curRow,
18                 vector<bool> &diag1, vector<bool> &diag2,
19                 vector<vector<string>> &ret, int iStart = 0) {
20    if (iStart == n) {
21      auto state = historyToStrings(n, history);
22      ret.push_back(state);
23      return;
24    }
25    bool rowPlaced = false;
26    for (int i = iStart; i < n; i++) {
27      for (int j = 0; j < n; j++) {
28        if (curRow[j] || diag1[j + i] || diag2[j + n - 1 - i]) {
29          continue;
30        }
31        // printf("i,j = %d,%d placed \n", i, j);
32        rowPlaced = true;
33        history[i * n + j] = true;
34        curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = true;
35        backtrace(n, history, curRow, diag1, diag2, ret, i + 1);
36
37        rowPlaced = false;
38        history[i * n + j] = false;
39        curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = false;
40      }
41      if (!rowPlaced) {
42        return;
43      }
44    }
45
46    return;
47  }
48
49 public:
50  vector<vector<string>> solveNQueens(int n) {
51    // key: idx n*i+j, value: availability
52    map<int, bool> history;
53    for (int i = 0; i < n * n; i++) {
54      history[i] = false;
55    }
56    vector<bool> curRow(n);
57    vector<bool> diag1(2 * n - 1);
58    vector<bool> diag2(2 * n - 1);
59    vector<vector<string>> ret;
60    backtrace(n, history, curRow, diag1, diag2, ret);
61    return ret;
62  }
63};

执行用时:8 ms, 在所有 C++ 提交中击败了57.32%的用户

内存消耗:7.8 MB, 在所有 C++ 提交中击败了31.88%的用户

这次执行时间足足提高了上百倍。

返回值优化

由于我们上面为了代码的结构性,流水式处理,history 状态和状态的展现采用的是不同的方式,后者通过前者经过 historyToStrings 函数转换。这样会增加调用次数。

下面我们采用 history 直接作为返回状态:

 1
 2class Solution {
 3 private:
 4  void backtrace(int n, vector<string> &history, vector<bool> &curRow,
 5                 vector<bool> &diag1, vector<bool> &diag2,
 6                 vector<vector<string>> &ret, int iStart = 0) {
 7    if (iStart == n) {
 8      ret.push_back(history);
 9      return;
10    }
11    bool rowPlaced = false;
12    for (int i = iStart; i < n; i++) {
13      for (int j = 0; j < n; j++) {
14        if (curRow[j] || diag1[j + i] || diag2[j + n - 1 - i]) {
15          continue;
16        }
17        // printf("i,j = %d,%d placed \n", i, j);
18        rowPlaced = true;
19        history[i][j] = 'Q';
20        curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = true;
21        backtrace(n, history, curRow, diag1, diag2, ret, i + 1);
22
23        rowPlaced = false;
24        history[i][j] = '.';
25        curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = false;
26      }
27      if (!rowPlaced) {
28        return;
29      }
30    }
31  }
32
33 public:
34  vector<vector<string>> solveNQueens(int n) {
35    // key: idx n*i+j, value: availability
36    vector<string> history(n);
37    for (int i = 0; i < n; i++) {
38      history[i] = string(n, '.');
39    }
40    vector<bool> curRow(n);
41    vector<bool> diag1(2 * n - 1);
42    vector<bool> diag2(2 * n - 1);
43    vector<vector<string>> ret;
44    backtrace(n, history, curRow, diag1, diag2, ret);
45    return ret;
46  }
47};

执行用时:4 ms, 在所有 C++ 提交中击败了95.27%的用户

内存消耗:7 MB, 在所有 C++ 提交中击败了90.23%的用户

执行时间降低了已经比较令人满意了。

回溯实现深度优先搜索

给定一棵树,要求搜索某个节点,并返回其路径。参考代码:

 1  void FindPathImpl(stack<TreeNode *> &history, TreeNode *root,
 2                    TreeNode *target, bool &over) {
 3    if (over) {
 4      return;
 5    }
 6    history.push(root);
 7    if (root == nullptr) {
 8      return;
 9    }
10    if (root == target) {
11      over = true;
12      return;
13    }
14    FindPathImpl(history, root->left, target, over);
15    if (over) {
16      return;
17    } else {
18
19      history.pop();
20    }
21    FindPathImpl(history, root->right, target, over);
22    if (over) {
23      return;
24    } else {
25
26      history.pop();
27    }
28  }
29  deque<TreeNode *> FindPath(TreeNode *root, TreeNode *target) {
30    stack<TreeNode *> history;
31    bool found = false;
32    FindPathImpl(history, root, target, found);
33    // reverse
34    deque<TreeNode *> ret;
35    while (!history.empty()) {
36      auto top = history.top();
37      history.pop();
38      ret.push_back(top);
39    }
40    return ret;
41  }

参考

(1)【算法】回溯法四步走 - Nemo& - 博客园 (cnblogs.com):比较通俗易懂,推荐。

(2)Next lexicographical permutation algorithm (nayuki.io):“下一个全排列”算法,很厉害。

(3)回溯算法入门级详解 + 练习(持续更新) - 全排列 - 力扣(LeetCode) (leetcode-cn.com)