From f3775dc2df0e927aa99c852fd2d8b613a33f91b0 Mon Sep 17 00:00:00 2001 From: Johannes Stoelp Date: Tue, 7 Nov 2023 21:51:26 +0100 Subject: mutex: add owning mutex utility --- Makefile | 1 + owning_mutex.h | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++ test/owning_mutex.cc | 41 +++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 owning_mutex.h create mode 100644 test/owning_mutex.cc diff --git a/Makefile b/Makefile index 8d1fb21..e489d12 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ TEST += bitfield TEST += option TEST += timer TEST += log +TEST += owning_mutex # -- INTERNALS ----------------------------------------------------------------- diff --git a/owning_mutex.h b/owning_mutex.h new file mode 100644 index 0000000..82756da --- /dev/null +++ b/owning_mutex.h @@ -0,0 +1,125 @@ +#ifndef UTILS_MUTEX_H +#define UTILS_MUTEX_H + +#include + +#include + +template +struct guard; + +// -- OWNING MUTEX ------------------------------------------------------------- + +/// owning_mutex +/// +/// An mutex wrapper type that owns a value of type T and provides mutual +/// exclusive access to that value through guard objects. Guard objects are +/// obtained from the wrapper API. When a guard goes out of scope, the mutex +/// will be unlocked automatically. +/// +/// The mutex type can be controlled by the template type argument M. +/// M: BasicLockable +/// +/// EXAMPLE: +/// struct data { int a; }; +/// owning_mutex val{1}; +/// +/// { +/// auto guard = val.lock(); +/// guard->a = 1337; +/// // mutex will be unlocked after this scope +/// } +template +struct owning_mutex { + template + constexpr explicit owning_mutex(Args... args) + : m_val{std::forward(args)...} {} + + owning_mutex(const owning_mutex&) = delete; + owning_mutex(owning_mutex&&) = delete; + + guard lock() { + return {m_mtx, m_val}; + } + + private: + M m_mtx; + T m_val; +}; + +// -- GUARD -------------------------------------------------------------------- + +#if __cplusplus >= 201703L + +template +struct [[nodiscard]] guard { + guard(M& mtx, T& val) : m_lk{mtx}, m_val{val} {} + + // With the guaranteed copy elision (cpp17) we can truly delete the + // copy/move constructor of the guard type. + // + // https://stackoverflow.com/a/38043447 + guard(const guard&) = delete; + guard(guard&&) = delete; + + T& operator*() { + return m_val; + } + + T* operator->() { + return &m_val; + } + + private: + std::lock_guard m_lk; + T& m_val; +}; + +#else // before cpp17 + +template +struct guard { + guard(M& mtx, T& val) : m_mtx{&mtx}, m_val{val} { + m_mtx->lock(); + } + + ~guard() { + if (m_mtx) { + m_mtx->unlock(); + } + } + + T& operator*() { + assert(m_mtx != nullptr); + return m_val; + } + + T* operator->() { + if (!m_mtx) { + return nullptr; + } + return &m_val; + } + + guard(const guard&) = delete; + // Implement move constructor for cases where the compiler does no copy + // elision. + // For API compatibility with the cpp17 version, the move constructor + // should not be explicitly invoked by the user. + // + // SAFETY: Exclusive access to T is guaranteed as at any given time only a + // single *guard* instance is NOT moved. + // + // UB: The *guard* must not be moved across thread boundaries and dropped + // there. + guard(guard&& rhs) noexcept : m_mtx{rhs.m_mtx}, m_val{rhs.m_val} { + rhs.m_mtx = nullptr; + } + + private: + M* m_mtx; + T& m_val; +}; +#endif + +#endif diff --git a/test/owning_mutex.cc b/test/owning_mutex.cc new file mode 100644 index 0000000..060415f --- /dev/null +++ b/test/owning_mutex.cc @@ -0,0 +1,41 @@ +#include + +#include +#include +#include + +#include +#include + +constexpr unsigned kNumThreads = 8; +constexpr unsigned kIter = 1 << 18; + +static_assert((static_cast(kNumThreads) * + static_cast(kIter)) <= + std::numeric_limits::max(), + "Expectate result overflowed!"); + +int main() { + owning_mutex data(0u); + + std::vector threads; + threads.reserve(kNumThreads); + + for (unsigned t = 0; t < kNumThreads; ++t) { + threads.emplace_back([&data, t]() { + for (unsigned i = 0; i < kIter; ++i) { + *data.lock() += 1; + } + std::printf("th%u finished\n", t); + }); + } + + for (auto& th : threads) { + th.join(); + } + + assert(*data.lock() == (kNumThreads * kIter)); + std::printf("Result %u\n", *data.lock()); + + return 0; +} -- cgit v1.2.3