差分数组c++实现与力扣题目总结

 
Category: DSA

写在前面

总结一下经典的差分数组方法, (华为机试刚考了), 思路很简单, 但是没遇到的话想写出来还是有点难度的.

参考了 labuladong 的博客, 里面的代码是 Java 实现的, 这里用 C++重写一下.

小而美的算法技巧:差分数组.

思路

差分数组是一种支持频繁对数组的某一区间进行增减, 以及查询的数据结构.

基本思想就是: (高中学的数列)

设数列为 ${a_i},\ i\in[0,\ n]$, 那么做一新数列${b_j},\ j\in[0, n]$, 且满足

\[\begin{cases} b_0 = a_0\\ b_j = a_i - a_{i-1}, \ j\in[1,\ n] \end{cases}\]

这样的数列(数组) $b_j$ 其实就是差分数组.

那么为什么这样的数组支持对区间的频繁增减呢?

例子

$A={4,7,5,2,9}$, 对应的差分数组$B={4,3,-2,-3,7}$.

初始化

void init(vector<int> &A, vector<int> & B) {
    B[0] = A[0];
    for (int i{1}; i < A.size(); ++i) 
        B[i] = A[i] - A[i - 1];
}

那么有了差分数组, 计算区间和就像前缀和一样简洁了:

查询: 返回原始数组

直接将上述操作反过来即可:

vector<int> get_raw(vector<int> &B) {
    int n = B.size();
    vector<int> C(B);
    for (int i{1}; i < B.size(); ++i) 
        C[i] = B[i] + C[i - 1];
   	return C;
}

修改: 增减

如果要对区间(闭区间) $[2,3]$ 每一个数都增加 1, 那么只需要修改区间的第一个下标位置的元素(加一)和最后一个下标位置的下一个元素(减一)即可, 即 \(B' = \{4,3,-1,-3,6\}\)

一般的, 对差分数组$B$的区间$[i,j]$进行增(减)$k$的操作, 只需将$B[i] \pm= k$且$B[j+1]\mp=k$即可.

void modify(vector<int> &B, int i, int j, int k) {
    B[i] += k;
    if (j + 1 < B.size()) // 考虑数组越界
        B[j + 1] -= k;
}

封装成类

头文件与打印重载:

#include <vector>
#include <iostream>
using namespace std;


ostream& operator<<(ostream& os, const vector<int>& v) {
    for (auto i : v) os << i << " ";
    return os << endl;
}

类: (声明与定义写在一起了)

template <typename T>
class DiffArray {
    vector<T> arr;

public:
    DiffArray() : arr() {}
    DiffArray(vector<T> _arr) { init(_arr); }

    void init(const vector<T>& raw_arr) {
        int n = raw_arr.size();
        arr.resize(n);
        arr[0] = raw_arr[0];
        for (int i{1}; i < n; ++i) arr[i] = raw_arr[i] - raw_arr[i - 1];
    }

    vector<T> get_raw() {
        vector<T> raw_arr(arr);
        for (int i{1}; i < arr.size(); ++i)
            raw_arr[i] = arr[i] + raw_arr[i - 1];
        return raw_arr;
    }

    void modify(int i, int j, int k) {
        arr[i] += k;
        if (j + 1 < arr.size()) arr[j + 1] -= k;
    }

    const vector<T>& query() const { return arr; }
};

差分数组类的打印:

ostream& operator<<(ostream& os, const DiffArray<int>& v) {
    return os << v.query();
}

测试:

int main() {
    //
    vector<int> v{4, 7, 5, 2, 9};
    cout << v;
    DiffArray<int> d1(v);
    cout << d1;
    d1.modify(2, 3, 1); // [2, 3] 增加 1
    cout << d1;
    cout << d1.get_raw();
}

结果:

4 7 5 2 9    // 原始数组
4 3 -2 -3 7  // 差分数组
4 3 -1 -3 6  // 修改后的差分数组
4 7 6 3 9    // 修改后的原数组

力扣题目

1109. 航班预订统计;

class DiffArray {
    vector<int> arr;

public:
    DiffArray() : arr() {}
    DiffArray(vector<int> _arr) : arr(_arr) {}

    vector<int> get_raw() {
        vector<int> raw_arr(arr);
        for (int i{1}; i < arr.size(); ++i)
            raw_arr[i] = arr[i] + raw_arr[i - 1];
        return raw_arr;
    }

    void modify(int i, int j, int k) {
        arr[i] += k;
        if (j + 1 < arr.size()) arr[j + 1] -= k;
    }
};
class Solution {
public:
    vector<int> corpFlightBookings(vector<vector<int>>& bookings, int n) {
        vector<int> ans(n);
        DiffArray da(ans);
        for (auto v : bookings) da.modify(v[0] - 1, v[1] - 1, v[2]);
        return da.get_raw();
    }
};

当然, 可以更快一些:

class Solution {
public:
    vector<int> corpFlightBookings(vector<vector<int>>& bookings, int n) {
        vector<int> ans(n);
        for (auto v : bookings) {
            auto R{v[1]}, val{v[2]};
            ans[v[0] - 1] += val;
            if (R < n) ans[R] -= val;
        }
        for (int i{1}; i < n; ++i) ans[i] = ans[i] + ans[i - 1];
        return ans;
    }
};

最快的方法STL 以及开辟(n+1)长度的数组减少判断:

class Solution {
public:
    vector<int> corpFlightBookings(vector<vector<int>>& bookings, int n) {
        vector<int> ans(n + 1);
        for (vector<int>& i : bookings) {
            ans[i[0] - 1] += i[2];
            ans[i[1]] -= i[2];
        }
        partial_sum(ans.begin(), ans.end(), ans.begin());
        ans.pop_back();
        return ans;
    }
};

6919. 使数组中的所有元素都等于零 - 力扣(Leetcode);

周赛第四题, 没时间了, 否则能做出来…

class Solution {
public:
    bool checkArray(vector<int>& nums, int k) {
        int n = nums.size();
        int diff[n];
        diff[0] = nums[0];
        for (int i{1}; i < n; ++i)
            diff[i] = nums[i] - nums[i - 1];
        for (int i{}; i < n; ++i) {
            if (diff[i] > 0) {
                int tmp{diff[i]};
                diff[i] -= tmp;
                if (i + k < n)
                    diff[i + k] += tmp;
                else if (i + k > n)
                    return 0;
            } else if (diff[i] < 0)
                return 0;
        }
        return 1;
    }
};