C++引用计数智能指针的实现

 
Category: C_C++

写在前面

老师讲得很细，可以参考：

  1. Mark shared_ptr;

下面总结一下面试中常考的知识点。

#include <bits/stdc++.h>
#include "../../utils/chrono.h"

namespace base_version {
/// Minimal reference-counted smart pointer (an interview-style re-implementation
/// of std::shared_ptr). All copies share one heap-allocated std::atomic_size_t
/// counter. Every atomic operation uses the default (sequentially consistent)
/// memory order.
///
/// Ownership notes:
///  - The pointee is freed with scalar `delete`, so it must come from `new T`,
///    never from `new T[]`.
///  - An empty pointer has no counter and reports use_count() == 0. A pointer is
///    empty when default-constructed, moved-from, or built from nullptr.
///  - Copying or destroying *distinct* shared_ptr objects that share a counter is
///    safe from different threads. Concurrently mutating the SAME shared_ptr
///    object is not synchronized.
template <typename T>
class shared_ptr {
    T* _raw_ptr{};
    std::atomic_size_t* _ref_count{};

    // Drops this object's reference and always leaves *this empty.
    // The members are detached *before* the decrement. That way, destroying the
    // pointee (which may destroy other shared_ptrs in turn) never observes stale
    // state in this object.
    void release() noexcept {
        T* ptr = std::exchange(_raw_ptr, nullptr);
        std::atomic_size_t* count = std::exchange(_ref_count, nullptr);
        // fetch_sub returns the previous value: 1 means we were the last owner.
        if (count && count->fetch_sub(1) == 1) {
            delete ptr;
            delete count;
        }
    }

public:
    shared_ptr() = default;

    /// Takes ownership of `ptr`. Passing nullptr yields an empty pointer.
    /// If allocating the counter throws, `ptr` is deleted before the exception
    /// propagates, so the object is never leaked (same as std::shared_ptr).
    explicit shared_ptr(T* ptr) : _raw_ptr(ptr) {
        if (ptr) {
            try {
                _ref_count = new std::atomic_size_t(1);
            } catch (...) {
                delete ptr;
                throw;
            }
        }
    }

    ~shared_ptr() { release(); }

    /// Shares ownership with `other`, incrementing the counter if non-empty.
    shared_ptr(const shared_ptr<T>& other) noexcept
        : _raw_ptr(other._raw_ptr), _ref_count(other._ref_count) {
        if (_ref_count) {
            _ref_count->fetch_add(1);
        }
    }

    /// Copy-and-swap: the new reference is taken before the old one is dropped.
    /// This covers self-assignment, and also the case where `other` lives inside
    /// the object currently owned by *this (e.g. `node = node->next;`).
    shared_ptr& operator=(const shared_ptr<T>& other) noexcept {
        shared_ptr(other).swap(*this);
        return *this;
    }

    /// Steals `other`'s ownership and leaves `other` empty. The counter is unchanged.
    shared_ptr(shared_ptr<T>&& other) noexcept
        : _raw_ptr(std::exchange(other._raw_ptr, nullptr)),
          _ref_count(std::exchange(other._ref_count, nullptr)) {}

    /// Move-and-swap: `other` is emptied before the old pointee is released,
    /// so `node = std::move(node->next);` is safe.
    shared_ptr& operator=(shared_ptr<T>&& other) noexcept {
        shared_ptr(std::move(other)).swap(*this);
        return *this;
    }

    /// Exchanges the managed object and counter with `other`.
    void swap(shared_ptr<T>& other) noexcept {
        std::swap(_raw_ptr, other._raw_ptr);
        std::swap(_ref_count, other._ref_count);
    }

    /// Dereferences the managed object. Undefined if *this is empty.
    T& operator*() const { return *_raw_ptr; }

    T* operator->() const noexcept { return _raw_ptr; }

    T* get() const noexcept { return _raw_ptr; }

    /// Number of owners sharing the counter, or 0 when empty.
    /// Under concurrency the value may already be stale when the caller reads it.
    size_t use_count() const { return _ref_count ? _ref_count->load() : 0; }

    /// Replaces the managed object with `p`. Passing nullptr just drops ownership.
    /// Construct-then-swap: if allocating the new counter throws, `p` is deleted
    /// and *this is left unchanged.
    void reset(T* p = nullptr) { shared_ptr(p).swap(*this); }
};

} // namespace base_version

namespace mem_order_version {
/// Same reference-counted smart pointer as the default-ordering variant, but
/// each atomic operation on the shared counter uses an explicitly chosen,
/// weaker memory order (see the comments at each call site).
///
/// Ownership notes:
///  - The pointee is freed with scalar `delete`, so it must come from `new T`,
///    never from `new T[]`.
///  - An empty pointer has no counter and reports use_count() == 0. A pointer is
///    empty when default-constructed, moved-from, or built from nullptr.
///  - Copying or destroying *distinct* shared_ptr objects that share a counter is
///    safe from different threads. Concurrently mutating the SAME shared_ptr
///    object is not synchronized.
template <typename T>
class shared_ptr {
    T* _raw_ptr{};
    std::atomic_size_t* _ref_count{};

    // Drops this object's reference and always leaves *this empty.
    // The members are detached before the decrement, so destroying the pointee
    // can never observe stale state in this object.
    void release() noexcept {
        T* ptr = std::exchange(_raw_ptr, nullptr);
        std::atomic_size_t* count = std::exchange(_ref_count, nullptr);
        // acq_rel serves two purposes:
        //  - The release half publishes this owner's writes to the pointee.
        //  - The acquire half lets the thread performing the final decrement
        //    see every other owner's writes before it deletes the object.
        if (count && count->fetch_sub(1, std::memory_order_acq_rel) == 1) {
            delete ptr;
            delete count;
        }
    }

public:
    shared_ptr() = default;

    /// Takes ownership of `ptr`. Passing nullptr yields an empty pointer.
    /// If allocating the counter throws, `ptr` is deleted before the exception
    /// propagates, so the object is never leaked (same as std::shared_ptr).
    explicit shared_ptr(T* ptr) : _raw_ptr(ptr) {
        if (ptr) {
            try {
                _ref_count = new std::atomic_size_t(1);
            } catch (...) {
                delete ptr;
                throw;
            }
        }
    }

    ~shared_ptr() { release(); }

    /// Shares ownership with `other`.
    /// A relaxed increment suffices: a new reference is only ever created from
    /// an existing one, so the counter cannot concurrently reach zero here.
    shared_ptr(const shared_ptr<T>& other) noexcept
        : _raw_ptr(other._raw_ptr), _ref_count(other._ref_count) {
        if (_ref_count) {
            _ref_count->fetch_add(1, std::memory_order_relaxed);
        }
    }

    /// Copy-and-swap: the new reference is taken before the old one is dropped.
    /// This covers self-assignment and `node = node->next;`.
    shared_ptr& operator=(const shared_ptr<T>& other) noexcept {
        shared_ptr(other).swap(*this);
        return *this;
    }

    /// Steals `other`'s ownership and leaves `other` empty. The counter is unchanged.
    shared_ptr(shared_ptr<T>&& other) noexcept
        : _raw_ptr(std::exchange(other._raw_ptr, nullptr)),
          _ref_count(std::exchange(other._ref_count, nullptr)) {}

    /// Move-and-swap: `other` is emptied before the old pointee is released,
    /// so `node = std::move(node->next);` is safe.
    shared_ptr& operator=(shared_ptr<T>&& other) noexcept {
        shared_ptr(std::move(other)).swap(*this);
        return *this;
    }

    /// Exchanges the managed object and counter with `other`.
    void swap(shared_ptr<T>& other) noexcept {
        std::swap(_raw_ptr, other._raw_ptr);
        std::swap(_ref_count, other._ref_count);
    }

    /// Dereferences the managed object. Undefined if *this is empty.
    T& operator*() const { return *_raw_ptr; }

    T* operator->() const noexcept { return _raw_ptr; }

    T* get() const noexcept { return _raw_ptr; }

    /// Number of owners sharing the counter, or 0 when empty.
    /// Under concurrency the value may already be stale when the caller reads it.
    size_t use_count() const {
        return _ref_count ? _ref_count->load(std::memory_order_acquire) : 0;
    }

    /// Replaces the managed object with `p`. Passing nullptr just drops ownership.
    /// Construct-then-swap: if allocating the new counter throws, `p` is deleted
    /// and *this is left unchanged.
    void reset(T* p = nullptr) { shared_ptr(p).swap(*this); }
};

} // namespace mem_order_version

/// Manual demo: two shared_ptr copies refer to the same int, and the value is
/// printed through the copy.
/// The int is allocated with scalar `new int(29)`, which also initializes it.
/// shared_ptr frees its pointee with scalar `delete`, so an array from `new[]`
/// would be undefined behavior.
void t1() {
    using namespace base_version;
    // using namespace mem_order_version;
    auto pi = shared_ptr<int>(new int(29));
    auto pii = pi;
    std::cout << (*pii) << std::endl;
}

/// Stress-tests the reference counting of a shared_ptr implementation.
///
/// Spawns 1000 threads. Each thread copy-constructs 1000 short-lived copies of
/// the same `SharedPtr<int>` and sleeps 1ms while each copy is alive. After all
/// threads are joined, only `ptr` itself should remain, so the count must be 1.
/// Any other value means increments or decrements were lost.
/// Prints the final count, a pass/fail line, and the time reported by TimeCost
/// (presumably milliseconds, matching the printed unit — confirm in chrono.h).
template <template <typename> class SharedPtr>
void test_shared_ptr_impl() {
    SharedPtr<int> ptr(new int(42));
    TimeCost tc;
    constexpr int kThreadCount = 1000;
    constexpr int kCopiesPerThread = 1000;

    // Workers only read `ptr` (they copy-construct from it), and every worker is
    // joined before `ptr` goes out of scope, so capturing by reference is fine.
    auto worker = [&ptr]() {
        for (int copy_idx = 0; copy_idx < kCopiesPerThread; ++copy_idx) {
            SharedPtr<int> local_ptr{ptr};
            std::this_thread::sleep_for(1ms);
        }
    };

    std::vector<std::thread> workers;
    for (int thread_idx = 0; thread_idx < kThreadCount; ++thread_idx) {
        workers.emplace_back(worker);
    }
    for (std::thread& th : workers) {
        th.join();
    }

    const auto final_count = ptr.use_count();
    std::cout << "use_count: " << final_count << std::endl;
    std::cout << (final_count == 1 ? "passed: safe.\n" : "failed\n");
    std::cout << "total time:" << tc.get_time_cost() << "ms\n";
}

/// Runs the concurrency stress test against one of the two implementations.
/// @param mode  2 selects mem_order_version::shared_ptr (explicit memory orders).
///              Any other value selects base_version::shared_ptr (default
///              sequentially consistent atomics).
void t2(int mode = 0) {
    const bool use_explicit_orders = (mode == 2);
    if (!use_explicit_orders) {
        printf("use default mem order\n");
        test_shared_ptr_impl<base_version::shared_ptr>();
        return;
    }
    printf("using self defined mem order\n");
    test_shared_ptr_impl<mem_order_version::shared_ptr>();
}

/// Entry point. The argument count selects the implementation under test:
/// with exactly one command-line argument (argc == 2), t2 tests
/// mem_order_version; otherwise it tests base_version.
/// The t1() demo is left disabled.
int main(int argc, char** argv) {
    // t1();
    const int mode = argc;
    t2(mode);
    return 0;
}