编辑距离与进阶题型解析

 
Category: DSA

工具函数

#include <bits/stdc++.h>

using namespace std;

// Prints the elements of v on a single line, each followed by a tab, and
// terminates the line with '\n'. Elements are bound by const reference so
// printing never copies them (matters for T = string, Node rows, ...).
template <typename T>
ostream& operator<<(ostream& os, const vector<T>& v) {
    for (const auto& x : v) {
        os << x << "\t";
    }
    os << '\n';
    return os;
}

// Prints a 2-D table row by row, delegating each row to the single-vector
// operator<< (one line per row). Rows are bound by const reference so the
// whole table is not copied row by row on every print.
template <typename T>
ostream& operator<<(ostream& os, const vector<vector<T>>& v) {
    for (const auto& row : v) {
        os << row;
    }
    return os;
}

基本写法

namespace base_version {
// Classic Levenshtein distance: the minimum number of single-character
// insertions, deletions and substitutions that turn s1 into s2.
//
// Layout: rows index s2, columns index s1, so table[r][c] is the distance
// between the prefix s1[0, c) and the prefix s2[0, r). The whole table is
// printed to stdout (for illustration) before the result is returned.
int edit_dist(const string& s1, const string& s2) {
    const int cols = static_cast<int>(s1.size());
    const int rows = static_cast<int>(s2.size());
    vector<vector<int>> table(rows + 1, vector<int>(cols + 1));

    // Base cases: producing s2[0, r) from an empty s1 takes r insertions;
    // reducing s1[0, c) to an empty s2 takes c deletions.
    for (int r = 0; r <= rows; ++r) table[r][0] = r;
    for (int c = 1; c <= cols; ++c) table[0][c] = c;

    for (int r = 1; r <= rows; ++r) {
        for (int c = 1; c <= cols; ++c) {
            if (s1[c - 1] != s2[r - 1]) {
                const int via_insert = table[r - 1][c];      // insert s2[r-1]
                const int via_delete = table[r][c - 1];      // delete s1[c-1]
                const int via_replace = table[r - 1][c - 1]; // replace s1[c-1]
                table[r][c] = 1 + min({via_insert, via_delete, via_replace});
            } else {
                // Equal characters are kept for free.
                table[r][c] = table[r - 1][c - 1];
            }
        }
    }
    cout << table;
    return table[rows][cols];
}
} // namespace base_version

进阶写法: 找最小编辑距离的编辑路径

// Edit operation recorded in each DP cell of the path-finding version:
//   no  - the characters match, nothing to do (printed as " ")
//   add - insert a character of s2 (printed as "+")
//   del - delete a character of s1 (printed as "-")
//   sub - replace a character of s1 with one of s2 (printed as "^")
// A value-initialized Choice is Choice::no.
enum class Choice : char {
    no = 0,
    add = 1,
    del = 2,
    sub = 3,
};

// One DP cell: the edit distance reached so far plus the last operation of
// an optimal edit sequence ending here. The constructor is not explicit, so
// an int converts implicitly to a Node; the choice always starts as no.
struct Node {
    int dist{};
    Choice choice{Choice::no};

    Node(int dist_ = 0) : dist{dist_} {}
};

// Sums the distances of two nodes. The result records no operation: its
// choice is the default Choice::no whatever the operands' choices were.
Node operator+(const Node& n1, const Node& n2) {
    return Node(n1.dist + n2.dist);
}

// Writes the symbol used for a Choice in the DP-table dump:
// " " (no), "+" (add), "-" (del), "^" (sub), or "unknown" for any value
// outside the enumerators.
std::ostream& operator<<(std::ostream& os, const Choice& c) {
    const char* symbol = "unknown";
    switch (c) {
        case Choice::no:  symbol = " "; break;
        case Choice::add: symbol = "+"; break;
        case Choice::del: symbol = "-"; break;
        case Choice::sub: symbol = "^"; break;
    }
    return os << symbol;
}

// Writes a DP cell as "<dist> <choice symbol>", e.g. "3 +".
ostream& operator<<(ostream& os, const Node& node) {
    return os << node.dist << " " << node.choice;
}

namespace find_path_version {

bool operator<(const Node& n1, const Node& n2) { return n1.dist < n2.dist; }

int edit_dist(const string& s1, const string& s2) {
    //
    int m = s1.size();
    int n = s2.size();
    vector<vector<Node>> dp(n + 1, vector<Node>(m + 1));
    // 将s1[1:m] 变成 s2[1:n] 的最小编辑距离

    // init
    for (int i = 1; i <= n; ++i) {
        dp[i][0].dist = i;
        dp[i][0].choice = Choice::add;
    }
    for (int j = 1; j <= m; ++j) {
        dp[0][j].dist = j;
        dp[0][j].choice = Choice::del;
    }
    // cout << dp << '\n';

    // calc
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            if (s1[j - 1] == s2[i - 1]) {
                dp[i][j] = dp[i - 1][j - 1];
                dp[i][j].choice = Choice::no;
            } else {
                dp[i][j].dist = dp[i][j - 1].dist;
                dp[i][j].choice = Choice::del;
                if (dp[i - 1][j] < dp[i][j]) {
                    dp[i][j].dist = dp[i - 1][j].dist;
                    dp[i][j].choice = Choice::add;
                }
                if (dp[i - 1][j - 1] < dp[i][j]) {
                    dp[i][j].dist = dp[i - 1][j - 1].dist;
                    dp[i][j].choice = Choice::sub;
                }
                dp[i][j].dist++;
            }
        }
    }
    cout << dp;
    auto find_best_edition_path = [&]() -> int {
        vector<string> ops;
        int i = n, j = m;
        while (i > 0 and j > 0) {

            char c1 = s1[j - 1];
            char c2 = s2[i - 1];
            switch (dp[i][j].choice) {
                case Choice::no:
                    --i;
                    --j;
                    ops.emplace_back(format("skip s1[{}]:{} (s2[{}])\n", j, c1, i));
                    break;
                case Choice::add:
                    --i;
                    ops.emplace_back(format("insert s2[{}]:{}\n", i, c2));
                    break;
                case Choice::del:
                    --j;
                    ops.emplace_back(format("delete s1[{}]:{}\n", j, c1));
                    break;
                case Choice::sub:
                    --i;
                    --j;
                    ops.emplace_back(format(
                        "replace s1[{}]:{} with s2[{}]:{}\n", j, c1, i, c2));
                    break;
                default:
                    break;
            }
        }
        while (j > 0) { // s1 need delete
            --j;
            ops.emplace_back(format("delete s1[{}]:{}\n", j, s1[j]));
        }
        while (i > 0) { // s1 need insert with s2
            --i;
            ops.emplace_back(format("insert s2[{}]:{} to s1\n", i, s2[i]));
        }
        reverse(ops.begin(), ops.end());
        for (auto& op : ops) {
            cout << op;
        }
        return 0;
    }();


    return dp.back().back().dist;
}
} // namespace find_path_version

测试

// Demo: turns "rap" into "apple" with the path-reconstructing version,
// printing the DP table, one minimal edit script and the final distance.
// To run the plain table-only version instead, enable the base_version
// using-directive and comment out the find_path_version one.
void t1() {
    // using namespace base_version;
    using namespace find_path_version;
    const string source = "rap";
    const string target = "apple";
    cout << "change s1:[" << source << "] to s2:[" << target << "]\n";
    const int distance = edit_dist(source, target);
    cout << "\nret:" << distance << '\n';
}
/* Expected output of t1():
change s1:[rap] to s2:[apple]
0       1 -     2 -     3 -
1 +     1 ^     1       2 -
2 +     2 +     2 +     1  
3 +     3 +     3 +     2  
4 +     4 +     4 +     3 +
5 +     5 +     5 +     4 +
delete s1[0]:r
skip s1[1]:a (s2[0])
insert s2[1]:p
skip s1[2]:p (s2[2])
insert s2[3]:l
insert s2[4]:e

ret:4
*/

可交换版本（支持相邻字符交换：Optimal String Alignment，即受限 Damerau–Levenshtein 距离）

namespace swap_version {
// Optimal string alignment (restricted Damerau-Levenshtein) distance: the
// minimum number of insertions, deletions, substitutions and transpositions
// of two adjacent characters that turn s1 into s2.
//
// Rows index s2 and columns index s1: dp[i][j] is the distance between
// s1[0, j) and s2[0, i). The whole table is printed to stdout before the
// result is returned.
int edit_dist(const string& s1, const string& s2) {
    const int m = static_cast<int>(s1.size());
    const int n = static_cast<int>(s2.size());
    vector<vector<int>> dp(n + 1, vector<int>(m + 1));

    for (int i = 0; i <= n; ++i) dp[i][0] = i; // insert all of s2[0, i)
    for (int j = 1; j <= m; ++j) dp[0][j] = j; // delete all of s1[0, j)

    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            const int replace_cost = (s1[j - 1] == s2[i - 1]) ? 0 : 1;
            dp[i][j] = min({dp[i - 1][j] + 1,                  // insert s2[i-1]
                            dp[i][j - 1] + 1,                  // delete s1[j-1]
                            dp[i - 1][j - 1] + replace_cost}); // keep / replace
            // Adjacent transposition ("...xy" vs "...yx"). It is one more
            // candidate for the minimum, not a replacement for the others:
            // e.g. "xy" -> "xyx" is a single insertion (distance 1), whereas
            // taking the swap unconditionally would yield 2.
            if (i > 1 and j > 1 and s1[j - 1] == s2[i - 2] and
                s1[j - 2] == s2[i - 1]) {
                dp[i][j] = min(dp[i][j], dp[i - 2][j - 2] + 1);
            }
        }
    }
    cout << dp;
    return dp.back().back();
}
} // namespace swap_version

// Demo for the transposition-aware version: turns "aplpe" into "apple",
// printing the DP table and the final distance.
// To compare with plain Levenshtein, enable the base_version using-directive
// and comment out the swap_version one.
void t2() {
    // using namespace base_version;
    using namespace swap_version;
    const string source = "aplpe";
    const string target = "apple";
    cout << "change s1:[" << source << "] to s2:[" << target << "]\n";
    const int distance = edit_dist(source, target);
    cout << "\nret:" << distance << '\n';
}