C++定制比较规则的各种方法总结

 
Category: C++

写在前面

写一下刷题中常用的 C++ 算法库函数 sort(), lower_bound()等函数以及优先队列等需要自己定制比较规则的方法.

  #include <algorithm>

utils

输出

template <typename T> // nd-array can use this
ostream& operator<<(ostream& os, const vector<T>& v) {
    for (auto i : v) os << i << " ";
    return os << '\n';
}

void print_arr(int arr[], size_t n) { // C-style array need length
    for (int i{}; i < n; ++i) cout << arr[i] << " ";
    cout << '\n';
}

sort() 篇

一般排序

STL 容器

void t1() {
    vector<int> v{1, 32, 9, 8, 6, 3};
    sort(v.begin(), v.end());
    cout << v;
    sort(v.rbegin(), v.rend()); // reverse order
    cout << v;
    // 1 3 6 8 9 32
    // 32 9 8 6 3 1
}

C 数组

总体来说还是很方便的, 传入首地址和尾地址, 算是迭代器的向下兼容了

void t2() {
    int arr[]{1, 9, 3, 32, 8, 6};
    auto n = sizeof(arr) / sizeof(arr[0]);
    sort(arr, arr + n); // 顺序(增序)
    print_arr(arr, n);
    sort(arr, arr + n, greater<>()); // 逆序
    print_arr(arr, n);
    // 1 3 6 8 9 32
    // 32 9 8 6 3 1
}

指定比较规则

这里直接用 auto-lambda, 好像是 C++17 的新语法.

void t3() {
    vector<int> v{1, 32, 9, 8, 6, 3};
    sort(v.begin(), v.end(),
         [](const auto& lhs, const auto& rhs) { return lhs % 4 < rhs % 4; });
    cout << v; // 32 8 1 9 6 3
}

双键比较(多用于二维数组, 较常用)

这里举一个最基本的例子, 就是先比较第一个值, 第一个值如果相等, 才比较第二个值.

void t4() {
    vector<vector<int>> vv{
        {1, 2}, {2, 9}, {2, 3}, {0, 8}, {1, 1} //
    };
    sort(vv.begin(), vv.end(), [](const auto& lhs, const auto& rhs) {
        return lhs[0] < rhs[0] || (lhs[0] == rhs[0] && lhs[1] < rhs[1]);
        // return lhs[0] == rhs[0] ? lhs[1] < rhs[1] : lhs[0] < rhs[0];
    });
    cout << vv;
    // 0 8
    //  1 1
    //  1 2
    //  2 3
    //  2 9
}

特定结构比较

有时候上面的 lambda 指定比较规则还是略显不足, 这时候就要用仿函数(函数对象)了:

以两个题为例:

791. 自定义字符串排序;

class Solution {
public:
    string customSortString(string order, string s) {
        int val[26]{};
        for (int i{}; i < order.size(); ++i) val[order[i] - 'a'] = i;
        sort(s.begin(), s.end(),
             [&](char& l, char& r) { return val[l - 'a'] < val[r - 'a']; });
        return s;
    }
};

其实我一开始想到的是哈希计数然后重建字符串, 后来发现可以原地排序.

1122. 数组的相对排序

class Solution {
public:
    vector<int> relativeSortArray(vector<int>& arr1, vector<int>& arr2) {
        int rank[1001];
        memset(rank, -1, sizeof(rank));
        for (int i = 0; i < arr2.size(); ++i) rank[arr2[i]] = i;

        auto mycmp = [&](int x) {
            // 0 表示在 arr2 比较规则中的键,
            // 1 表示之后要按升序放在末尾的键
            return ~rank[x] ? pair{0, rank[x]} : pair{1, x};
        };
        sort(arr1.begin(), arr1.end(),
             [&](int x, int y) { return mycmp(x) < mycmp(y); });
        return arr1;
    }
};

事实上这两个题都是计数排序的题, 如果用默认排序函数的话反而不是很快.

lower_bound() 篇

这里就以lower_bound()为例了, 其他类似的函数可以一样做.

可以这样定制比较规则:

lower_bound(v.begin(), v.end(), value_to_found,
            [&](const auto &lhs, const auto &) { // 捕获列表是为了待查找的值, 第二参数为了占位
                return lhs.first < value_to_found;
            });

一个例子:

假如有一个 pair 类型数组, 需要找对应的key(即 pair 的 first 值)

// for C-array
#define SIZE(x) ((sizeof(x)) / (sizeof(x[0])))

void t5() {
    // user definition compare rule
    using MapType = pair<int, string>;
    MapType marr[]{
        {1, "one"}, {2, "two1"}, {2, "two2"}, {2, "two3"}, {3, "three"} //
    };
    int value_to_found = 2;
    auto it = lower_bound(marr, marr + SIZE(marr), value_to_found,
                          [&](const auto &lhs, const auto &) {
                              return lhs.first < value_to_found;
                          });
    cout << it << endl;
    cout << it->second << endl;         // two1
    cout << it - marr << endl;          // 1
    cout << distance(marr, it) << endl; // 1
}

事实上第二参数不是必须的(查看llvm源码得知):

  template <class _Compare, class _ForwardIterator, class _Tp>
  _LIBCPP_CONSTEXPR_AFTER_CXX17 _ForwardIterator
  __lower_bound(_ForwardIterator __first, _ForwardIterator __last, const _Tp& __value_, _Compare __comp)
  {
      typedef typename iterator_traits<_ForwardIterator>::difference_type difference_type;
      difference_type __len = _VSTD::distance(__first, __last);
      while (__len != 0)
      {
          difference_type __l2 = _VSTD::__half_positive(__len);
          _ForwardIterator __m = __first;
          _VSTD::advance(__m, __l2);
          if (__comp(*__m, __value_)) // 这里的比较函数其实就用了第一参数, 第二参数是要找的 value
          {
              __first = ++__m;
              __len -= __l2 + 1;
          }
          else
              __len = __l2;
      }
      return __first;
  }

优先队列

这里有两种写法, 第一种比较传统, 就是直接写struct , 里面套operator(), 这里推荐第二种写法, 就是decltype(cmp)写法, 看着就舒服.

方法一: 函数对象

struct cmp {
    using pii = pair<int, int>;
    bool operator()(const pii &lhs, const pii &rhs) const {
        return lhs.first < rhs.first; // 这里就是举个例子, 实际上这种简单的比较可以直接用默认比较函数
    }
};

void t1() {
    priority_queue<pair<int, int>, vector<pair<int, int>>, cmp> pq;
    pq.emplace(1, 3);
    pq.emplace(4, 2);
    pq.emplace(6, 2);
    for (; !pq.empty(); pq.pop()) {
        auto [k, v] = pq.top();
        cout << k << " : " << v << endl;
    }
    // 6 : 2
    // 4 : 2
    // 1 : 3
}

题外话: 关于 C++ pair对象比较函数的默认规则:

  using pii = pair<int, int>;
  
  void t1() {
      vector<pii> v{
          {1, 2}, {1, 1}, {3, 2}, {3, 0}, {2, 2} //
      };
      sort(v.begin(), v.end());
      for (auto [p1, p2] : v) cout << p1 << " : " << p2 << endl;
      // 1 : 1
      // 1 : 2
      // 2 : 2
      // 3 : 0
      // 3 : 2
  }

说明 pair 这样的对象在默认比较规则下就是从小到大, 若第一键相同, 则第二键从小到大.

方法二: lambda 函数

注意第二种写法需要构造时候传入比较函数作为参数.

void t2() {
    auto cmp = [](const auto &lhs, const auto &rhs) {
        return lhs.first < rhs.first; // lambda 比较规则
    };
    using pii = pair<int, int>;
    // 算是 decltype 的一个妙用了
    priority_queue<pii, vector<pii>, decltype(cmp)> pq(cmp);
    pq.emplace(1, 3);
    pq.emplace(4, 2);
    pq.emplace(6, 2);
    for (; !pq.empty(); pq.pop()) {
        auto [k, v] = pq.top();
        cout << k << " : " << v << endl;
    }
    // 6 : 2
    // 4 : 2
    // 1 : 3
}

做题练练手

692. 前K个高频单词 - 力扣(Leetcode);

class Solution {
public:
    vector<string> topKFrequent(vector<string> &words, int k) {
        unordered_map<string, int> cnt;
        for (auto w : words)
            ++cnt[w];
        using pis = pair<int, string>;
        auto cmp = [&](const auto &lhs, const auto &rhs) {
            return lhs.first == rhs.first ? lhs.second < rhs.second
                                          : lhs.first > rhs.first;
        };
        priority_queue<pis, vector<pis>, decltype(cmp)> pq(cmp);
        for (auto &[key, val] : cnt) {
            pq.emplace(val, key);
            if (pq.size() > k)
                pq.pop();
        }
        vector<string> ans;
        for (; !pq.empty(); pq.pop()) {
            ans.emplace_back(pq.top().second);
        }
        reverse(ans.begin(), ans.end());
        return ans;
    }
};

786. 第 K 个最小的素数分数 - 力扣(Leetcode);

class Solution {
public:
    vector<int> kthSmallestPrimeFraction(vector<int>& arr, int k) {
        int n = arr.size();
        using ptype = pair<int, int>;
        auto cmp = [&](const auto& lhs, const auto& rhs) {
            return arr[lhs.first] * arr[rhs.second] >
                   arr[lhs.second] * arr[rhs.first];
        };
        priority_queue<ptype, vector<ptype>, decltype(cmp)> pq(cmp);

        for (int i{1}; i < n; ++i)
            pq.emplace(0, i);
        for (int _{1}; _ < k; ++_) {
            auto [i, j] = pq.top();
            pq.pop();
            if (i + 1 < j)
                pq.emplace(i + 1, j);
        }
        return {arr[pq.top().first], arr[pq.top().second]};
    }
};

总结

可以说是各有千秋, 不过我还是喜欢 lambda!