什么是回溯
回溯可以看作在树状的解空间中搜索,一旦发现不可行,就回退到历史状态,然后选择另一个分叉搜索。由于要遍历尝试各个分支,所以也是一种暴力算法。
回溯的标准流程
一般过程:
- 存档
- 执行动作
- 回档
其实就和操作系统的调度、中断啥的一样,保存现场恢复现场。
有的情况下,比如递归实现的后序遍历,可以利用调用栈存档,这样就不用我们显式地存档。
回溯的优化
回溯没有动态规划那么容易进行空间优化,因为 DP 往往能够利用历史状态进行缓存。回溯常用的优化有:
- 交换代替插入删除
- 利用数学性质,无重复优化
- 通过增加条件限制,剪枝优化
全排列
【例子】46. 全排列
对于集合 {1,2,3}
全排列,相当于将 1,2,3
填到三个格子里。
那么对于第一个格子,我们选择填入 ?
- 1
- 2
- 3
这里就产生了三个分支。剩下的两个格子,就是剩余元素的全排列,比如第一个格子填入 3
,那么剩下的格子就是 {1,2}
的全排列。
- 历史状态:已经填入的格子。
- 候选集合:剩下能填的数字。
- 对于每个候选元素,都将会产生一个新的候选分支。
参考代码如下:
class Solution {
void rec(vector<int> &available, vector<int> &history,
vector<vector<int>> &ret) {
if (available.size() == 0) {
ret.push_back(history);
}
// == 尝试各个可用分支 ==
for (int i = 0; i < available.size(); i++) {
// 存档
int num = available[i];
history.push_back(num);
// 尝试
available.erase(available.begin() + i);
rec(available, history, ret);
// 回档
available.insert(available.begin() + i, num);
history.pop_back();
}
}
public:
vector<vector<int>> permute(vector<int> &nums) {
vector<vector<int>> ret;
vector<int> history;
rec(nums, history, ret);
return ret;
}
};
你可以注意到,核心代码就这么几句:
history.push_back(num);
// 尝试
available.erase(available.begin() + i);
rec(available, history, ret);
// 回档
available.insert(available.begin() + i, num);
history.pop_back();
很有对称美。总之,原来啥样,还原之后就是啥样。
增删性能优化
在时间复杂度上已经很难优化了,因为必须列出所有结果。但是从计算机设计的角度,我们可以减少数组的动态操作。实际上,可以用 swap 代替频繁的数组增删。下面是 LC 提供的题解:
class Solution {
public:
void backtrack(vector<vector<int>>& res, vector<int>& output, int first, int len){
// 所有数都填完了
if (first == len) {
res.emplace_back(output);
return;
}
for (int i = first; i < len; ++i) {
// 动态维护数组
swap(output[i], output[first]);
// 继续递归填下一个数
backtrack(res, output, first + 1, len);
// 撤销操作
swap(output[i], output[first]);
}
}
vector<vector<int>> permute(vector<int>& nums) {
vector<vector<int> > res;
backtrack(res, nums, 0, (int)nums.size());
return res;
}
};
非递归实现
可以通过反复调用 Next Permutation 实现。(参考资料(2))
子集
这题的朴素思路就是:
- 取出集合中的每个元素,求剩余元素的子集,各个推入结果列表
- 再把整体作为一个子集推入结果列表
代码:
class Solution {
public:
void rec(vector<int> &nums, vector<int>& history, vector<vector<int>>& ret) {
if (nums.size() == 0) {
return;
}
// 注意推入的时机
ret.push_back(history);
// 下面都是常规的保存现场恢复现场
for (int i = 0; i < nums.size(); i++) {
int tmp = nums[i];
history.push_back(tmp);
nums.erase(nums.begin() + i);
rec(nums, history, ret);
nums.insert(nums.begin() + i, tmp);
history.pop_back();
}
}
vector<vector<int>> subsets(vector<int>& nums) {
vector<vector<int>> ret;
vector<int> history;
rec(nums, history, ret);
ret.push_back(nums);
return ret;
}
};
输出:
{
{}
{1}
{1, 2}
{1, 3}
{2}
{2, 1}
{2, 3}
{3}
{3, 1}
{3, 2}
{1, 2, 3}
}
可以看到出现了重复。最简单的解决方法就是改用 set
结构。但是这样性能不佳。
无重复优化
发生重复的根源是什么?观察输出:
{
{}
{1}
{1, 2}
{1, 3}
{2}
{2, 1} // 重复
{2, 3}
{3}
{3, 1} // 重复
{3, 2} // 重复
{1, 2, 3}
}
注意到重复的原因在于回溯追加的元素小于首次选择的元素。比如 {2, 1}
中 $1 < 2$. 所以我们可以优化遍历时的起点:
class Solution {
public:
void rec(vector<int> &nums, vector<int>& history, vector<vector<int>>& ret, int startIndex = 0) {
if (nums.size() == 0) {
return;
}
ret.push_back(history);
for (int i = startIndex; i < nums.size (); i++) { // 注意这里
int tmp = nums[i];
nums.erase(nums.begin() + i);
history.push_back(tmp);
rec(nums, history, ret, i);
history.pop_back();
nums.insert(nums.begin() + i, tmp);
}
}
vector<vector<int>> subsets(vector<int>& nums) {
vector<vector<int>> ret;
vector<int> history;
rec(nums, history, ret);
ret.push_back(nums);
return ret;
}
};
通过引入 startIndex
,直接跳过了重复项。
对比
对比全排列和子集的回溯穷举算法,可以发现在回溯途中推送解,就是子集的算法,在回溯的末端推送解,就是全排列的算法。
组合
我们从 1,2,3,4
中选 3
个,则相当于:
1
+ 从2,3,4
中选2
个2
+ 从1,3,4
中选2
个- ……
代码如下(已经进行了无重复优化):
class Solution {
public:
void rec(vector<int>& nums, int k, vector<int>& history,
vector<vector<int>>& ret, int startIndex = 0) {
if (k == 0) {
ret.push_back(history);
return;
}
for (int i = startIndex; i < nums.size(); i++) {
int tmp = nums[i];
nums.erase(nums.begin() + i);
history.push_back(tmp);
rec(nums, k - 1, history, ret, i);
history.pop_back();
nums.insert(nums.begin() + i, tmp);
}
}
vector<vector<int>> combine(int n, int k) {
std::vector<int> nums(n);
for (int i = 0; i < n; i++) {
nums[i] = i + 1;
}
vector<vector<int>> ret;
vector<int> history;
rec(nums, k, history, ret);
return ret;
}
};
性能很烂,怎么回事呢?
执行用时:76 ms, 在所有 C++ 提交中击败了 7.87%的用户
内存消耗:8.8 MB, 在所有 C++ 提交中击败了 96.70%的用户
剪枝优化
如果 n = 7, k = 4,从 5 开始搜索就已经没有意义了,这是因为:即使把 5 选上,后面的数只有 6 和 7,一共就 3 个候选数,凑不出 4 个数的组合。(参考)
根据上面这句话,假设终止条件是 $x$,则有 $n - x + 1 < k$,即 $x = n - k + 1$。
也就是说,如果 v [startIndex] > n - k + 1
,则可以直接 return.
而 v [startIndex] = 1 + startIndex
(因为题给条件说组合所用数为 $1\cdots n$)
所以 startIndex > n - k
可以直接退出。我们要限定 i <= n - k
class Solution {
public:
void rec(vector<int>& nums, int k, vector<int>& history,
vector<vector<int>>& ret, int n, int startIndex = 0) {
if (k == 0) {
ret.push_back(history);
return;
}
for (int i = startIndex; i < nums.size() && i <= n - k; i++) {
int tmp = nums[i];
nums.erase(nums.begin() + i);
history.push_back(tmp);
rec(nums, k - 1, history, ret, n, i);
history.pop_back();
nums.insert(nums.begin() + i, tmp);
}
}
vector<vector<int>> combine(int n, int k) {
std::vector<int> nums(n);
for (int i = 0; i < n; i++) {
nums[i] = i + 1;
}
vector<vector<int>> ret;
vector<int> history;
rec(nums, k, history, ret, n, 0);
return ret;
}
};
执行用时:60 ms, 在所有 C++ 提交中击败了 8.58%的用户
内存消耗:8.7 MB, 在所有 C++ 提交中击败了 97.13%的用户
空间优化
问题在哪儿?其实我们完全没必要维护一个 nums 数组,因为 nums 可以通过 i + 1 得出:
class Solution {
public:
void rec(int k, vector<int>& history,
vector<vector<int>>& ret, int n, int startIndex = 0) {
if (k == 0) {
ret.push_back(history);
return;
}
for (int i = startIndex; i <= n - k; i++) {
history.push_back(i + 1);
rec(k - 1, history, ret, n, i + 1);
history.pop_back();
}
}
vector<vector<int>> combine(int n, int k) {
vector<vector<int>> ret;
vector<int> history;
rec(k, history, ret, n, 0);
return ret;
}
};
执行用时:4 ms, 在所有 C++ 提交中击败了 99.14%的用户
内存消耗:8.9 MB, 在所有 C++ 提交中击败了 89.49%的用户
这样,节省了空间,也减少了操作步骤,使得计算速度提高了。
组合总和
给定一个无重复元素的正整数数组 candidates
和一个正整数 target
,找出 candidates
中所有可以使数字和为目标数 target
的唯一组合。
candidates
中的数字可以无限制重复被选取。如果至少一个所选数字数量不同,则两种组合是唯一的。
对于给定的输入,保证和为 target
的唯一组合数少于 150
个。
示例 1:
输入: candidates = [2,3,6,7], target = 7
输出: [[7],[2,2,3]]
思路分析
分析:
如果采用排列穷举验证,相当于不剪枝,那么难点在于怎么处理重复元素。
不妨换个思路:可以利用 target - candidates [i]
缩小问题规模:
-
[2,3,6,7], target = 7
-2, target = 5
-2, target = 3
-3, target = 2
-6, target = -1
-7 target = -2
-3, target = 4
-6, target = 1
-7 target = 0
一旦 target = 0 就将搜索路径推送到答案列表。
一旦 target < 0 就停止搜索。
而可选数一直是
[2, 3, 6, 7]
代码及无重复优化
参照这个例子写出代码,并进行无重复优化:
class Solution {
private:
void backtrace(vector<int>& cand, vector<int>& path, vector<vector<int>>& ret,
int target, int startIndex = 0) {
if (target == 0) {
ret.push_back(path);
return;
}
if (target < 0) {
return;
}
for (int i = startIndex; i < cand.size(); i++) {
path.push_back(cand[i]);
backtrace(cand, path, ret, target - cand[i], i);
path.pop_back();
}
}
public:
vector<vector<int>> combinationSum(vector<int>& cand, int target) {
vector<vector<int>> ret;
vector<int> path;
backtrace(cand, path, ret, target);
return ret;
}
};
执行用时:0 ms, 在所有 C++ 提交中击败了 100.00%的用户
内存消耗:10.3 MB, 在所有 C++ 提交中击败了 98.78%的用户
看起来不错。
N 皇后
N 皇后问题将回溯代入了二维世界 (二次元)。但思路依旧是相同的 。
我们可以尝试各个初始位置,并锁定不能防止的单元:
基本思路
斜向判断
基本代码
解法:
#include <debug.h>
class Solution {
private:
void printState(int n, map<int, bool> &history) {
cout << "state:" << endl;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
cout << history[i * n + j] << " ";
}
cout << endl;
}
}
vector<string> historyToStrings(int n, map<int, bool> &history) {
vector<string> ret;
for (int i = 0; i < n; i++) {
string s(n, '.');
for (int j = 0; j < n; j++) {
if (history[i * n + j]) {
s[j] = 'Q';
}
}
ret.push_back(s);
}
return ret;
}
// 检查棋盘 i,j 位置是否允许落子
bool available(int n, int i, int j, map<int, bool> &history) {
// printState(n, history);
if (history[i * n + j]) {
return false;
}
//row,col 为当前检测的起点
for (int row = 0; row < n; row++) {
for (int col = 0; col < n; col++) {
// 一旦 row, col 处落子,则同行同列禁止落子
if (history[row * n + col]) {
if (row == i || col == j) {
return false;
}
// == 斜向检测,利用和 / 差为定值 ==
//p,q 为以 row,col 为起点的斜向元素的坐标
// 左斜向检测
auto coordSum = row + col;
//p 是临时 row
//q 是临时 col
// col = coordSum - row >= 0
int p = 0, q = coordSum - p;
while (q >= 0) {
if (p == i && q == j) {
return false;
}
p++;
q = (coordSum - p);
}
// 右斜向检测
// p q
auto coordDiff = row - col;
// col = row - coordDiff >= 0
p = 0, q = p - coordDiff;
while (p < n && q < n) {
if (p == i && q == j) {
return false;
}
p++;
q = p - coordDiff;
}
}
}
}
return true;
}
int placedCount(map<int, bool> &history) {
auto itr = history.begin();
int counter = 0;
while (itr != history.end()) {
if ((*itr).second) {
counter++;
}
itr++;
}
return counter;
}
void backtrace(int n, map<int, bool> &history, vector<vector<string>> &ret) {
if (placedCount(history) == n) {
auto state = historyToStrings(n, history);
// 重复则不添加
for (size_t i = 0; i < ret.size(); i++) {
if (state == ret[i]) {
return;
}
}
ret.push_back(state);
return;
}
bool anyAvaliable = false;
for (int i = 0; i < n; i++) {
string s(n, '.');
for (int j = 0; j < n; j++) {
if (available(n, i, j, history)) {
// printf("i,j = %d,%d placed \n", i, j);
anyAvaliable = true;
history[i * n + j] = true;
backtrace(n, history, ret);
history[i * n + j] = false;
}
}
}
if (!anyAvaliable) {
return;
}
}
public:
vector<vector<string>> solveNQueens(int n) {
// key: idx n*i+j, value: availability
map<int, bool> history;
for (int i = 0; i < n * n; i++) {
history[i] = false;
}
vector<vector<string>> ret;
backtrace(n, history, ret);
return ret;
}
};
int main(int argc, char const *argv[]) {
Solution s;
auto ret = s.solveNQueens(5);
print_vec_2d(ret, 0, true);
return 0;
}
输出:
.Q..,
...Q,
Q...,
..Q.
..Q.,
Q...,
...Q,
.Q..
算法是对的,但是超时。
无重复优化
class Solution {
private:
void printState(int n, map<int, bool> &history) {
cout << "state:" << endl;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
cout << history[i * n + j] << " ";
}
cout << endl;
}
}
vector<string> historyToStrings(int n, map<int, bool> &history) {
vector<string> ret;
for (int i = 0; i < n; i++) {
string s(n, '.');
for (int j = 0; j < n; j++) {
if (history[i * n + j]) {
s[j] = 'Q';
}
}
ret.push_back(s);
}
return ret;
}
// 检查棋盘 i,j 位置是否允许落子
bool available(int n, int i, int j, map<int, bool> &history) {
// printState(n, history);
if (history[i * n + j]) {
return false;
}
//row,col 为当前检测的起点
for (int row = 0; row < n; row++) {
for (int col = 0; col < n; col++) {
// 一旦 row, col 处落子,则同行同列禁止落子
if (history[row * n + col]) {
if (row == i || col == j) {
return false;
}
// == 斜向检测,利用和 / 差为定值 ==
//p,q 为以 row,col 为起点的斜向元素的坐标
// 左斜向检测
auto coordSum = row + col;
//p 是临时 row
//q 是临时 col
// col = coordSum - row >= 0
int p = 0, q = coordSum - p;
while (q >= 0) {
if (p == i && q == j) {
return false;
}
p++;
q = (coordSum - p);
}
// 右斜向检测
// p q
auto coordDiff = row - col;
// col = row - coordDiff >= 0
p = 0, q = p - coordDiff;
while (p < n && q < n) {
if (p == i && q == j) {
return false;
}
p++;
q = p - coordDiff;
}
}
}
}
return true;
}
int placedCount(map<int, bool> &history) {
auto itr = history.begin();
int counter = 0;
while (itr != history.end()) {
if ((*itr).second) {
counter++;
}
itr++;
}
return counter;
}
void backtrace(int n, map<int, bool> &history, vector<vector<string>> &ret, int iStart = 0) {
if (placedCount(history) == n) {
auto state = historyToStrings(n, history);
ret.push_back(state);
return;
}
bool anyAvaliable = false;
for (int i = iStart; i < n; i++) {
string s(n, '.');
for (int j = 0; j < n; j++) {
if (available(n, i, j, history)) {
// printf("i,j = %d,%d placed \n", i, j);
anyAvaliable = true;
history[i * n + j] = true;
backtrace(n, history, ret, i + 1);
history[i * n + j] = false;
}
}
}
if (!anyAvaliable) {
return;
}
}
public:
vector<vector<string>> solveNQueens(int n) {
// key: idx n*i+j, value: availability
map<int, bool> history;
for (int i = 0; i < n * n; i++) {
history[i] = false;
}
vector<vector<string>> ret;
backtrace(n, history, ret);
return ret;
}
};
搜索优化
上面的代码依然超时。原因在于我们判断可行区域时的效率太低。优化的方法是采用专门的结构,记录斜向是否可行。
class Solution {
private:
vector<string> historyToStrings(int n, map<int, bool> &history) {
vector<string> ret;
for (int i = 0; i < n; i++) {
string s(n, '.');
for (int j = 0; j < n; j++) {
if (history[i * n + j]) {
s[j] = 'Q';
}
}
ret.push_back(s);
}
return ret;
}
void backtrace(int n, map<int, bool> &history, vector<bool> &curRow,
vector<bool> &diag1, vector<bool> &diag2,
vector<vector<string>> &ret, int iStart = 0) {
if (placedCount(history) == n) {
auto state = historyToStrings(n, history);
ret.push_back(state);
return;
}
bool anyAvaliable = false;
for (int i = iStart; i < n; i++) {
for (int j = 0; j < n; j++) {
if (curRow[j] || diag1[j + i] || diag2[j + n - 1 - i]) {
continue;
}
// printf("i,j = %d,%d placed \n", i, j);
anyAvaliable = true;
history[i * n + j] = true;
curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = true;
backtrace(n, history, curRow, diag1, diag2, ret, i + 1);
history[i * n + j] = false;
curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = false;
}
}
if (!anyAvaliable) {
return;
}
}
public:
vector<vector<string>> solveNQueens(int n) {
// key: idx n*i+j, value: availability
map<int, bool> history;
for (int i = 0; i < n * n; i++) {
history[i] = false;
}
vector<bool> curRow(n);
vector<bool> diag1(2 * n - 1);
vector<bool> diag2(2 * n - 1);
vector<vector<string>> ret;
backtrace(n, history, curRow, diag1, diag2, ret);
return ret;
}
};
执行用时:636 ms, 在所有 C++ 提交中击败了 5.15%的用户
内存消耗:7.8 MB, 在所有 C++ 提交中击败了 32.43%的用户
无效解优化
我们的代码还有优化空间,如果棋盘第一行(或者列)没有放置过,它还会尝试第二行。但既然已经有一行(或者列)没有放置过,那么必然无法放满 N 个。可以通过一个标识来跳过这种情况:
class Solution {
private:
vector<string> historyToStrings(int n, map<int, bool> &history) {
vector<string> ret;
for (int i = 0; i < n; i++) {
string s(n, '.');
for (int j = 0; j < n; j++) {
if (history[i * n + j]) {
s[j] = 'Q';
}
}
ret.push_back(s);
}
return ret;
}
void backtrace(int n, map<int, bool> &history, vector<bool> &curRow,
vector<bool> &diag1, vector<bool> &diag2,
vector<vector<string>> &ret, int iStart = 0) {
if (iStart == n) {
auto state = historyToStrings(n, history);
ret.push_back(state);
return;
}
bool rowPlaced = false;
for (int i = iStart; i < n; i++) {
for (int j = 0; j < n; j++) {
if (curRow[j] || diag1[j + i] || diag2[j + n - 1 - i]) {
continue;
}
// printf("i,j = %d,%d placed \n", i, j);
rowPlaced = true;
history[i * n + j] = true;
curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = true;
backtrace(n, history, curRow, diag1, diag2, ret, i + 1);
rowPlaced = false;
history[i * n + j] = false;
curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = false;
}
if (!rowPlaced) {
return;
}
}
return;
}
public:
vector<vector<string>> solveNQueens(int n) {
// key: idx n*i+j, value: availability
map<int, bool> history;
for (int i = 0; i < n * n; i++) {
history[i] = false;
}
vector<bool> curRow(n);
vector<bool> diag1(2 * n - 1);
vector<bool> diag2(2 * n - 1);
vector<vector<string>> ret;
backtrace(n, history, curRow, diag1, diag2, ret);
return ret;
}
};
执行用时:8 ms, 在所有 C++ 提交中击败了 57.32%的用户
内存消耗:7.8 MB, 在所有 C++ 提交中击败了 31.88%的用户
这次执行时间足足提高了上百倍。
返回值优化
由于我们上面为了代码的结构性,流水式处理,history 状态和状态的展现采用的是不同的方式,后者通过前者经过 historyToStrings
函数转换。这样会增加调用次数。
下面我们采用 history 直接作为返回状态:
class Solution {
private:
void backtrace(int n, vector<string> &history, vector<bool> &curRow,
vector<bool> &diag1, vector<bool> &diag2,
vector<vector<string>> &ret, int iStart = 0) {
if (iStart == n) {
ret.push_back(history);
return;
}
bool rowPlaced = false;
for (int i = iStart; i < n; i++) {
for (int j = 0; j < n; j++) {
if (curRow[j] || diag1[j + i] || diag2[j + n - 1 - i]) {
continue;
}
// printf("i,j = %d,%d placed \n", i, j);
rowPlaced = true;
history[i][j] = 'Q';
curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = true;
backtrace(n, history, curRow, diag1, diag2, ret, i + 1);
rowPlaced = false;
history[i][j] = '.';
curRow[j] = diag1[j + i] = diag2[j + n - 1 - i] = false;
}
if (!rowPlaced) {
return;
}
}
}
public:
vector<vector<string>> solveNQueens(int n) {
// key: idx n*i+j, value: availability
vector<string> history(n);
for (int i = 0; i < n; i++) {
history[i] = string(n, '.');
}
vector<bool> curRow(n);
vector<bool> diag1(2 * n - 1);
vector<bool> diag2(2 * n - 1);
vector<vector<string>> ret;
backtrace(n, history, curRow, diag1, diag2, ret);
return ret;
}
};
执行用时:4 ms, 在所有 C++ 提交中击败了 95.27%的用户
内存消耗:7 MB, 在所有 C++ 提交中击败了 90.23%的用户
执行时间降低了已经比较令人满意了。
回溯实现深度优先搜索
给定一棵树,要求搜索某个节点,并返回其路径。参考代码:
void FindPathImpl(stack<TreeNode *> &history, TreeNode *root,
TreeNode *target, bool &over) {
if (over) {
return;
}
history.push(root);
if (root == nullptr) {
return;
}
if (root == target) {
over = true;
return;
}
FindPathImpl(history, root->left, target, over);
if (over) {
return;
} else {
history.pop();
}
FindPathImpl(history, root->right, target, over);
if (over) {
return;
} else {
history.pop();
}
}
deque<TreeNode *> FindPath(TreeNode *root, TreeNode *target) {
stack<TreeNode *> history;
bool found = false;
FindPathImpl(history, root, target, found);
// reverse
deque<TreeNode *> ret;
while (!history.empty()) {
auto top = history.top();
history.pop();
ret.push_back(top);
}
return ret;
}
参考
(1)【算法】回溯法四步走 - Nemo& - 博客园 (cnblogs.com):比较通俗易懂,推荐。
(2)Next lexicographical permutation algorithm (nayuki.io):“下一个全排列” 算法,很厉害。
(3)回溯算法入门级详解 + 练习(持续更新) - 全排列 - 力扣(LeetCode) (leetcode-cn.com)