diff options
-rw-r--r-- | Makefile | 1 | ||||
-rw-r--r-- | owning_mutex.h | 125 | ||||
-rw-r--r-- | test/owning_mutex.cc | 41 |
3 files changed, 167 insertions, 0 deletions
@@ -2,6 +2,7 @@ TEST += bitfield TEST += option TEST += timer TEST += log +TEST += owning_mutex # -- INTERNALS ----------------------------------------------------------------- diff --git a/owning_mutex.h b/owning_mutex.h new file mode 100644 index 0000000..82756da --- /dev/null +++ b/owning_mutex.h @@ -0,0 +1,125 @@ +#ifndef UTILS_MUTEX_H +#define UTILS_MUTEX_H + +#include <mutex> + +#include <cassert> + +template <typename T, typename M> +struct guard; + +// -- OWNING MUTEX ------------------------------------------------------------- + +/// owning_mutex +/// +/// An mutex wrapper type that owns a value of type T and provides mutual +/// exclusive access to that value through guard objects. Guard objects are +/// obtained from the wrapper API. When a guard goes out of scope, the mutex +/// will be unlocked automatically. +/// +/// The mutex type can be controlled by the template type argument M. +/// M: BasicLockable +/// +/// EXAMPLE: +/// struct data { int a; }; +/// owning_mutex<data> val{1}; +/// +/// { +/// auto guard = val.lock(); +/// guard->a = 1337; +/// // mutex will be unlocked after this scope +/// } +template <typename T, typename M = std::mutex> +struct owning_mutex { + template <typename... Args> + constexpr explicit owning_mutex(Args... args) + : m_val{std::forward<Args>(args)...} {} + + owning_mutex(const owning_mutex&) = delete; + owning_mutex(owning_mutex&&) = delete; + + guard<T, M> lock() { + return {m_mtx, m_val}; + } + + private: + M m_mtx; + T m_val; +}; + +// -- GUARD -------------------------------------------------------------------- + +#if __cplusplus >= 201703L + +template <typename T, typename M> +struct [[nodiscard]] guard { + guard(M& mtx, T& val) : m_lk{mtx}, m_val{val} {} + + // With the guaranteed copy elision (cpp17) we can truly delete the + // copy/move constructor of the guard type. + // + // https://stackoverflow.com/a/38043447 + guard(const guard&) = delete; + guard(guard&&) = delete; + + T& operator*() { + return m_val; + } + + T* operator->() { + return &m_val; + } + + private: + std::lock_guard<M> m_lk; + T& m_val; +}; + +#else // before cpp17 + +template <typename T, typename M> +struct guard { + guard(M& mtx, T& val) : m_mtx{&mtx}, m_val{val} { + m_mtx->lock(); + } + + ~guard() { + if (m_mtx) { + m_mtx->unlock(); + } + } + + T& operator*() { + assert(m_mtx != nullptr); + return m_val; + } + + T* operator->() { + if (!m_mtx) { + return nullptr; + } + return &m_val; + } + + guard(const guard&) = delete; + // Implement move constructor for cases where the compiler does no copy + // elision. + // For API compatibility with the cpp17 version, the move constructor + // should not be explicitly invoked by the user. + // + // SAFETY: Exclusive access to T is guaranteed as at any given time only a + // single *guard* instance is NOT moved. + // + // UB: The *guard* must not be moved across thread boundaries and dropped + // there. + guard(guard&& rhs) noexcept : m_mtx{rhs.m_mtx}, m_val{rhs.m_val} { + rhs.m_mtx = nullptr; + } + + private: + M* m_mtx; + T& m_val; +}; +#endif + +#endif diff --git a/test/owning_mutex.cc b/test/owning_mutex.cc new file mode 100644 index 0000000..060415f --- /dev/null +++ b/test/owning_mutex.cc @@ -0,0 +1,41 @@ +#include <owning_mutex.h> + +#include <limits> +#include <thread> +#include <vector> + +#include <cassert> +#include <cstdio> + +constexpr unsigned kNumThreads = 8; +constexpr unsigned kIter = 1 << 18; + +static_assert((static_cast<unsigned long>(kNumThreads) * + static_cast<unsigned long>(kIter)) <= + std::numeric_limits<unsigned>::max(), + "Expectate result overflowed!"); + +int main() { + owning_mutex<unsigned> data(0u); + + std::vector<std::thread> threads; + threads.reserve(kNumThreads); + + for (unsigned t = 0; t < kNumThreads; ++t) { + threads.emplace_back([&data, t]() { + for (unsigned i = 0; i < kIter; ++i) { + *data.lock() += 1; + } + std::printf("th%u finished\n", t); + }); + } + + for (auto& th : threads) { + th.join(); + } + + assert(*data.lock() == (kNumThreads * kIter)); + std::printf("Result %u\n", *data.lock()); + + return 0; +} |