工具函数
#include <bits/stdc++.h>
using namespace std;
/// Streams one row: every element of `v` followed by a tab, then a newline.
///
/// Elements are visited by const reference so that printing never copies
/// them (the previous by-value loop copied each element on every print).
///
/// @param os  destination stream
/// @param v   elements to print; an empty vector prints just the newline
/// @return    `os`, to allow chaining
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
  for (const auto& item : v) {
    os << item << "\t";
  }
  os << '\n';
  return os;
}

/// Streams a vector of vectors as a table, one inner vector per line.
///
/// Each row is delegated to the single-vector overload above, so every row is
/// tab-separated and newline-terminated. Rows are visited by const reference;
/// iterating by value would deep-copy every inner vector just to print it.
///
/// @param os  destination stream
/// @param v   rows to print
/// @return    `os`, to allow chaining
template <typename T>
std::ostream& operator<<(std::ostream& os,
                         const std::vector<std::vector<T>>& v) {
  for (const auto& row : v) {
    os << row;
  }
  return os;
}
基本写法
namespace base_version {
/// Computes the minimum number of single-character insertions, deletions and
/// replacements needed to turn `s1` into `s2`.
///
/// Side effect: the filled table is written to `std::cout` before returning.
///
/// @param s1  source string
/// @param s2  target string
/// @return    the edit distance between `s1` and `s2`
int edit_dist(const std::string& s1, const std::string& s2) {
  const int src_len = s1.size();
  const int dst_len = s2.size();

  // table[r][c]: fewest edits that turn the first c characters of s1 into the
  // first r characters of s2.
  std::vector<std::vector<int>> table(dst_len + 1,
                                      std::vector<int>(src_len + 1));

  // Borders: an empty s1 prefix needs r insertions, an empty s2 prefix needs
  // c deletions.
  for (int c = 1; c <= src_len; ++c) {
    table[0][c] = c;
  }
  for (int r = 0; r <= dst_len; ++r) {
    table[r][0] = r;
  }

  for (int r = 1; r <= dst_len; ++r) {
    for (int c = 1; c <= src_len; ++c) {
      const int keep_or_replace = table[r - 1][c - 1];
      if (s1[c - 1] == s2[r - 1]) {
        table[r][c] = keep_or_replace;  // characters agree: nothing to pay
        continue;
      }
      const int insert_cost = table[r - 1][c];
      const int delete_cost = table[r][c - 1];
      table[r][c] = 1 + std::min({insert_cost, delete_cost, keep_or_replace});
    }
  }

  std::cout << table;
  return table[dst_len][src_len];
}
}  // namespace base_version
进阶写法: 找最小编辑距离的编辑路径
/// Edit operation recorded for a table cell, stored in a single `char`.
/// Value-initialisation (`Choice{}`) yields `Choice::no`.
enum class Choice : char {
  no = 0,   // characters matched; nothing to do
  add = 1,  // insert a character taken from the target string
  del = 2,  // delete a character from the source string
  sub = 3,  // replace a source character with the target character
};
/// One table cell: the distance accumulated so far plus the operation that
/// produced it.
struct Node {
  int dist = 0;                // edit distance stored in this cell
  Choice choice = Choice::no;  // operation chosen for this cell

  /// Builds a node with the given distance and `Choice::no`.
  /// Deliberately not `explicit`: an `int` converts implicitly to a `Node`.
  Node(int dist_ = 0) : dist{dist_} {}
};
/// Adds the distances of two nodes. The result carries the default choice
/// (`Choice::no`); the operands' choices are not propagated.
Node operator+(const Node& n1, const Node& n2) {
  return Node(n1.dist + n2.dist);
}
/// Streams the one-character symbol of an edit operation:
/// `" "` for no, `"+"` for add, `"-"` for del, `"^"` for sub.
/// Any value outside the enumerators prints `"unknown"`.
std::ostream& operator<<(std::ostream& os, const Choice& c) {
  const char* symbol = "unknown";
  switch (c) {
    case Choice::no:
      symbol = " ";
      break;
    case Choice::add:
      symbol = "+";
      break;
    case Choice::del:
      symbol = "-";
      break;
    case Choice::sub:
      symbol = "^";
      break;
  }
  return os << symbol;
}
/// Streams a node as "<dist> <choice symbol>".
std::ostream& operator<<(std::ostream& os, const Node& node) {
  return os << node.dist << " " << node.choice;
}
namespace find_path_version {
/// Orders nodes by distance only; the recorded choice is ignored.
bool operator<(const Node& n1, const Node& n2) { return n1.dist < n2.dist; }

/// Fills the (|s2|+1) x (|s1|+1) table. dp[i][j] holds the minimum number of
/// edits that turn the first j characters of s1 into the first i characters
/// of s2, together with the operation that achieved it.
static std::vector<std::vector<Node>> build_table(const std::string& s1,
                                                  const std::string& s2) {
  const int m = s1.size();
  const int n = s2.size();
  std::vector<std::vector<Node>> dp(n + 1, std::vector<Node>(m + 1));

  // Borders: reaching s2[0:i] from an empty prefix takes i insertions,
  // reaching an empty prefix from s1[0:j] takes j deletions.
  for (int i = 1; i <= n; ++i) {
    dp[i][0].dist = i;
    dp[i][0].choice = Choice::add;
  }
  for (int j = 1; j <= m; ++j) {
    dp[0][j].dist = j;
    dp[0][j].choice = Choice::del;
  }

  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= m; ++j) {
      Node& cell = dp[i][j];
      if (s1[j - 1] == s2[i - 1]) {
        cell = dp[i - 1][j - 1];
        cell.choice = Choice::no;
        continue;
      }
      // Start from "delete"; "insert" and then "replace" take over only when
      // strictly cheaper, so ties keep the earlier candidate.
      cell.dist = dp[i][j - 1].dist;
      cell.choice = Choice::del;
      if (dp[i - 1][j] < cell) {
        cell.dist = dp[i - 1][j].dist;
        cell.choice = Choice::add;
      }
      if (dp[i - 1][j - 1] < cell) {
        cell.dist = dp[i - 1][j - 1].dist;
        cell.choice = Choice::sub;
      }
      ++cell.dist;  // pay for the chosen operation
    }
  }
  return dp;
}

/// Walks the table from the bottom-right cell back to the origin and prints
/// the edit script to `std::cout` in forward order, one operation per line.
/// Indices in the messages are 0-based positions in s1 / s2.
static void print_edit_path(const std::vector<std::vector<Node>>& dp,
                            const std::string& s1, const std::string& s2) {
  std::vector<std::string> ops;
  int i = s2.size();
  int j = s1.size();

  while (i > 0 && j > 0) {
    // Characters are read before the indices are decremented, so after the
    // decrement i / j are exactly the 0-based positions of c2 / c1.
    const char c1 = s1[j - 1];
    const char c2 = s2[i - 1];
    switch (dp[i][j].choice) {
      case Choice::no:
        --i;
        --j;
        ops.emplace_back(std::format("skip s1[{}]:{} (s2[{}])\n", j, c1, i));
        break;
      case Choice::add:
        --i;
        ops.emplace_back(std::format("insert s2[{}]:{}\n", i, c2));
        break;
      case Choice::del:
        --j;
        ops.emplace_back(std::format("delete s1[{}]:{}\n", j, c1));
        break;
      case Choice::sub:
        --i;
        --j;
        ops.emplace_back(std::format("replace s1[{}]:{} with s2[{}]:{}\n", j,
                                     c1, i, c2));
        break;
      default:
        // Not reached: build_table assigns one of the four choices above to
        // every cell with i > 0 and j > 0.
        break;
    }
  }
  // One of the strings is exhausted; the rest is pure deletion or insertion.
  while (j > 0) {
    --j;
    ops.emplace_back(std::format("delete s1[{}]:{}\n", j, s1[j]));
  }
  while (i > 0) {
    --i;
    ops.emplace_back(std::format("insert s2[{}]:{} to s1\n", i, s2[i]));
  }

  // Operations were collected back-to-front.
  std::reverse(ops.begin(), ops.end());
  for (const auto& op : ops) {
    std::cout << op;
  }
}

/// Computes the edit distance from `s1` to `s2` and reports how to get there.
///
/// Side effects: prints the filled table to `std::cout`, followed by the edit
/// script (skip / insert / delete / replace lines).
///
/// @param s1  source string
/// @param s2  target string
/// @return    the minimum number of insertions, deletions and replacements
int edit_dist(const std::string& s1, const std::string& s2) {
  const auto dp = build_table(s1, s2);
  std::cout << dp;
  print_edit_path(dp, s1, s2);
  return dp.back().back().dist;
}
}  // namespace find_path_version
测试
void t1() {
// using namespace base_version;
using namespace find_path_version;
string s1 = "rap";
string s2 = "apple";
cout << "change s1:[" << s1 << "] to s2:[" << s2 << "]\n";
auto ret = edit_dist(s1, s2);
cout << "\nret:" << ret << '\n';
}
/*
change s1:[rap] to s2:[apple]
0 1 - 2 - 3 -
1 + 1 ^ 1 2 -
2 + 2 + 2 + 1
3 + 3 + 3 + 2
4 + 4 + 4 + 3 +
5 + 5 + 5 + 4 +
delete s1[0]:r
skip s1[1]:a (s2[0])
insert s2[1]:p
skip s1[2]:p (s2[2])
insert s2[3]:l
insert s2[4]:e
ret:4
*/
可交换版本
namespace swap_version {
/// Edit distance from `s1` to `s2` where, besides insert / delete / replace,
/// swapping two adjacent characters also counts as a single edit.
///
/// Side effect: the filled table is written to `std::cout` before returning.
///
/// @param s1  source string
/// @param s2  target string
/// @return    the minimum number of edits
int edit_dist(const std::string& s1, const std::string& s2) {
  const int m = s1.size();
  const int n = s2.size();

  // dp[i][j]: fewest edits that turn the first j characters of s1 into the
  // first i characters of s2.
  std::vector<std::vector<int>> dp(n + 1, std::vector<int>(m + 1));
  for (int i = 0; i <= n; ++i) {
    dp[i][0] = i;  // insert i characters
  }
  for (int j = 1; j <= m; ++j) {
    dp[0][j] = j;  // delete j characters
  }

  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= m; ++j) {
      if (s1[j - 1] == s2[i - 1]) {
        dp[i][j] = dp[i - 1][j - 1];  // unchanged
        continue;
      }
      int best = std::min({dp[i - 1][j],        // insert
                           dp[i][j - 1],        // delete
                           dp[i - 1][j - 1]});  // replace
      const bool can_swap = i > 1 && j > 1 && s1[j - 1] == s2[i - 2] &&
                            s1[j - 2] == s2[i - 1];
      if (can_swap) {
        // A swap is only one more candidate and must compete with the other
        // three. Taking it unconditionally overcounts: "xb" -> "xbx" needs a
        // single insertion, but dp[1][0] + 1 would give 2.
        best = std::min(best, dp[i - 2][j - 2]);
      }
      dp[i][j] = best + 1;
    }
  }

  std::cout << dp;
  return dp.back().back();
}
}  // namespace swap_version
void t2() {
// using namespace base_version;
using namespace swap_version;
string s1 = "aplpe";
string s2 = "apple";
cout << "change s1:[" << s1 << "] to s2:[" << s2 << "]\n";
auto ret = edit_dist(s1, s2);
cout << "\nret:" << ret << '\n';
}