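// 写在前面 (Preface)
// 老师讲的很细, 可以参考: (The lecturer covers this topic in great detail and
// is worth consulting; note: the reference link is missing from these notes.)
// 总结一下面试常考知识点. (Below is a summary of frequently asked interview topics.)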
#include <bits/stdc++.h>
#include "../../utils/chrono.h"
namespace base_version {
/// Minimal reference-counted smart pointer that illustrates how
/// std::shared_ptr works. The reference count lives in a separately
/// heap-allocated std::atomic_size_t. Every atomic operation uses the default
/// (sequentially consistent) memory order.
///
/// Thread-safety: distinct shared_ptr objects that share one pointee may be
/// copied and destroyed concurrently, because the count is atomic. A single
/// shared_ptr object must not be modified by one thread while another thread
/// reads or modifies it.
///
/// Limitations: the pointee is freed with scalar `delete`, so it must not come
/// from `new T[]`. There is no weak_ptr support.
template <typename T>
class shared_ptr {
  T* _raw_ptr{};
  std::atomic_size_t* _ref_count{};

  // Drops this object's reference. The last owner deletes both the pointee
  // and the counter. The members are always cleared afterwards, so *this
  // never keeps a dangling pointer.
  void release() noexcept {
    if (_ref_count && _ref_count->fetch_sub(1) == 1) {
      delete _raw_ptr;
      delete _ref_count;
    }
    _raw_ptr = nullptr;
    _ref_count = nullptr;
  }

 public:
  shared_ptr() = default;

  /// Takes ownership of `ptr` (may be null; a null pointer yields an empty
  /// shared_ptr). If allocating the counter throws, `ptr` is deleted before
  /// the exception propagates, so it never leaks.
  explicit shared_ptr(T* ptr) : _raw_ptr(ptr) {
    if (ptr) {
      try {
        _ref_count = new std::atomic_size_t(1);
      } catch (...) {
        delete ptr;
        throw;
      }
    }
  }

  ~shared_ptr() { release(); }

  /// Shares ownership with `other`, incrementing the count if it is non-empty.
  shared_ptr(const shared_ptr<T>& other) noexcept
      : _raw_ptr(other._raw_ptr), _ref_count(other._ref_count) {
    if (_ref_count) {
      _ref_count->fetch_add(1);
    }
  }

  /// Copy-and-swap. The new reference is taken before the old one is dropped,
  /// which keeps this correct for self-assignment and for `p = p->next`.
  /// In that case, releasing the old pointee first would destroy `other`
  /// before it is read.
  shared_ptr& operator=(const shared_ptr<T>& other) noexcept {
    shared_ptr(other).swap(*this);
    return *this;
  }

  /// Steals `other`'s reference and leaves `other` empty.
  shared_ptr(shared_ptr<T>&& other) noexcept
      : _raw_ptr(std::exchange(other._raw_ptr, nullptr)),
        _ref_count(std::exchange(other._ref_count, nullptr)) {}

  /// Move-and-swap. Safe for self-move and for `p = std::move(p->next)`.
  shared_ptr& operator=(shared_ptr<T>&& other) noexcept {
    shared_ptr(std::move(other)).swap(*this);
    return *this;
  }

  /// Exchanges the managed pointer and counter with `other`.
  void swap(shared_ptr<T>& other) noexcept {
    std::swap(_raw_ptr, other._raw_ptr);
    std::swap(_ref_count, other._ref_count);
  }

  T& operator*() const { return *_raw_ptr; }
  T* operator->() const { return _raw_ptr; }
  T* get() const { return _raw_ptr; }

  /// Number of owners sharing the pointee, or 0 when empty. While other
  /// threads hold copies, this is only a snapshot.
  size_t use_count() const { return _ref_count ? _ref_count->load() : 0; }

  /// Replaces the managed pointer with `p` (null by default). The new owner is
  /// built first. If counter allocation throws, `p` is deleted and *this is
  /// left unchanged.
  void reset(T* p = nullptr) { shared_ptr(p).swap(*this); }
};
}  // namespace base_version
namespace mem_order_version {
/// Same reference-counted pointer as base_version, but with hand-chosen
/// memory orders on the counter instead of the seq_cst default:
///  - increments are relaxed: a new reference can only be made from an
///    existing one, which already keeps the pointee alive, so no ordering
///    with other memory is needed;
///  - decrements are acq_rel: each owner's release publishes its writes to
///    the pointee, and the owner that drops the count to zero acquires them
///    before deleting.
///
/// Thread-safety: distinct shared_ptr objects that share one pointee may be
/// copied and destroyed concurrently. A single shared_ptr object must not be
/// modified by one thread while another thread reads or modifies it.
///
/// Limitations: the pointee is freed with scalar `delete`, so it must not come
/// from `new T[]`. There is no weak_ptr support.
template <typename T>
class shared_ptr {
  T* _raw_ptr{};
  std::atomic_size_t* _ref_count{};

  // Drops this object's reference. The last owner deletes both the pointee
  // and the counter. The members are always cleared afterwards.
  void release() noexcept {
    if (_ref_count &&
        _ref_count->fetch_sub(1, std::memory_order_acq_rel) == 1) {
      delete _raw_ptr;
      delete _ref_count;
    }
    _raw_ptr = nullptr;
    _ref_count = nullptr;
  }

 public:
  shared_ptr() = default;

  /// Takes ownership of `ptr` (may be null). If allocating the counter throws,
  /// `ptr` is deleted before the exception propagates, so it never leaks.
  explicit shared_ptr(T* ptr) : _raw_ptr(ptr) {
    if (ptr) {
      try {
        _ref_count = new std::atomic_size_t(1);
      } catch (...) {
        delete ptr;
        throw;
      }
    }
  }

  ~shared_ptr() { release(); }

  /// Shares ownership with `other`. The increment is relaxed (see the class
  /// comment).
  shared_ptr(const shared_ptr<T>& other) noexcept
      : _raw_ptr(other._raw_ptr), _ref_count(other._ref_count) {
    if (_ref_count) {
      _ref_count->fetch_add(1, std::memory_order_relaxed);
    }
  }

  /// Copy-and-swap. The new reference is taken before the old one is dropped,
  /// which keeps this correct for self-assignment and for `p = p->next`.
  shared_ptr& operator=(const shared_ptr<T>& other) noexcept {
    shared_ptr(other).swap(*this);
    return *this;
  }

  /// Steals `other`'s reference and leaves `other` empty.
  shared_ptr(shared_ptr<T>&& other) noexcept
      : _raw_ptr(std::exchange(other._raw_ptr, nullptr)),
        _ref_count(std::exchange(other._ref_count, nullptr)) {}

  /// Move-and-swap. Safe for self-move and for `p = std::move(p->next)`.
  shared_ptr& operator=(shared_ptr<T>&& other) noexcept {
    shared_ptr(std::move(other)).swap(*this);
    return *this;
  }

  /// Exchanges the managed pointer and counter with `other`.
  void swap(shared_ptr<T>& other) noexcept {
    std::swap(_raw_ptr, other._raw_ptr);
    std::swap(_ref_count, other._ref_count);
  }

  T& operator*() const { return *_raw_ptr; }
  T* operator->() const { return _raw_ptr; }
  T* get() const { return _raw_ptr; }

  /// Number of owners sharing the pointee, or 0 when empty. This is only a
  /// snapshot while other threads hold copies.
  size_t use_count() const {
    return _ref_count ? _ref_count->load(std::memory_order_acquire) : 0;
  }

  /// Replaces the managed pointer with `p` (null by default). The new owner is
  /// built first. If counter allocation throws, `p` is deleted and *this is
  /// left unchanged.
  void reset(T* p = nullptr) { shared_ptr(p).swap(*this); }
};
}  // namespace mem_order_version
// Smoke test: copy a shared_ptr and dereference the copy.
void t1() {
  using namespace base_version;
  // using namespace mem_order_version;
  // Allocate a single int (`new int(29)`), not an array (`new int[29]`).
  // shared_ptr<int> frees with scalar `delete`, so an array allocation would
  // be undefined behaviour, and `*pii` would read an uninitialized element.
  auto pi = shared_ptr<int>(new int(29));
  auto pii = pi;
  std::cout << (*pii) << std::endl;
}
/// Concurrency stress test for a shared_ptr-like class template.
///
/// Each of `num_threads` threads repeatedly copies the same shared pointer
/// `iterations` times. It holds every copy across a 1 ms sleep, then drops it.
/// After all threads join, the only remaining owner is the local `ptr`, so
/// use_count() must be exactly 1. Any other value means the count lost or
/// duplicated updates. Prints the result and the elapsed time.
///
/// @tparam SharedPtr  class template exposing a T* constructor, a copy
///                    constructor and use_count().
/// @param num_threads number of concurrently copying threads (default 1000).
/// @param iterations  copies made by each thread (default 1000).
template <template <typename> class SharedPtr>
void test_shared_ptr_impl(int num_threads = 1000, int iterations = 1000) {
  SharedPtr<int> ptr(new int(42));
  TimeCost tc;
  std::vector<std::thread> vp;
  vp.reserve(static_cast<std::size_t>(std::max(num_threads, 0)));
  for (int i{}; i < num_threads; ++i) {
    // Threads only read `ptr` (copy construction); they never modify it, so
    // the only shared mutable state is the atomic reference count.
    vp.emplace_back([&ptr, iterations]() {
      for (int j{}; j < iterations; ++j) {
        auto local_ptr(ptr);
        std::this_thread::sleep_for(1ms);
      }
    });
  }
  for (auto& t : vp) {
    t.join();
  }
  const auto count = ptr.use_count();
  std::cout << "use_count: " << count << std::endl;
  if (count == 1) {
    std::cout << "passed: safe.\n";
  } else {
    std::cout << "failed\n";
  }
  std::cout << "total time:" << tc.get_time_cost() << "ms\n";
}
// Runs the concurrency stress test against one of the two implementations.
// mode == 2 selects mem_order_version (hand-picked memory orders). Any other
// value selects base_version (default seq_cst orders).
void t2(int mode = 0) {
  const bool custom_order = (mode == 2);
  if (!custom_order) {
    printf("use default mem order\n");
    test_shared_ptr_impl<base_version::shared_ptr>();
    return;
  }
  printf("using self defined mem order\n");
  test_shared_ptr_impl<mem_order_version::shared_ptr>();
}
// Entry point. The number of command-line arguments (argc) is forwarded as
// t2's mode, so passing exactly one argument (argc == 2) selects the
// custom memory-order implementation.
int main(int argc, char** argv) {
  static_cast<void>(argv);  // argument values are not inspected
  // t1();
  const int mode = argc;
  t2(mode);
  return 0;
}