aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile1
-rw-r--r--owning_mutex.h125
-rw-r--r--test/owning_mutex.cc41
3 files changed, 167 insertions, 0 deletions
diff --git a/Makefile b/Makefile
index 8d1fb21..e489d12 100644
--- a/Makefile
+++ b/Makefile
@@ -2,6 +2,7 @@ TEST += bitfield
TEST += option
TEST += timer
TEST += log
+TEST += owning_mutex
# -- INTERNALS -----------------------------------------------------------------
diff --git a/owning_mutex.h b/owning_mutex.h
new file mode 100644
index 0000000..82756da
--- /dev/null
+++ b/owning_mutex.h
@@ -0,0 +1,125 @@
+#ifndef UTILS_MUTEX_H
+#define UTILS_MUTEX_H
+
+#include <mutex>
+
+#include <cassert>
+
+template <typename T, typename M>
+struct guard;
+
+// -- OWNING MUTEX -------------------------------------------------------------
+
+/// owning_mutex
+///
+/// An mutex wrapper type that owns a value of type T and provides mutual
+/// exclusive access to that value through guard objects. Guard objects are
+/// obtained from the wrapper API. When a guard goes out of scope, the mutex
+/// will be unlocked automatically.
+///
+/// The mutex type can be controlled by the template type argument M.
+/// M: BasicLockable
+///
+/// EXAMPLE:
+/// struct data { int a; };
+/// owning_mutex<data> val{1};
+///
+/// {
+/// auto guard = val.lock();
+/// guard->a = 1337;
+/// // mutex will be unlocked after this scope
+/// }
+template <typename T, typename M = std::mutex>
+struct owning_mutex {
+ template <typename... Args>
+ constexpr explicit owning_mutex(Args... args)
+ : m_val{std::forward<Args>(args)...} {}
+
+ owning_mutex(const owning_mutex&) = delete;
+ owning_mutex(owning_mutex&&) = delete;
+
+ guard<T, M> lock() {
+ return {m_mtx, m_val};
+ }
+
+ private:
+ M m_mtx;
+ T m_val;
+};
+
+// -- GUARD --------------------------------------------------------------------
+
+#if __cplusplus >= 201703L
+
+template <typename T, typename M>
+struct [[nodiscard]] guard {
+ guard(M& mtx, T& val) : m_lk{mtx}, m_val{val} {}
+
+ // With the guaranteed copy elision (cpp17) we can truly delete the
+ // copy/move constructor of the guard type.
+ //
+ // https://stackoverflow.com/a/38043447
+ guard(const guard&) = delete;
+ guard(guard&&) = delete;
+
+ T& operator*() {
+ return m_val;
+ }
+
+ T* operator->() {
+ return &m_val;
+ }
+
+ private:
+ std::lock_guard<M> m_lk;
+ T& m_val;
+};
+
+#else // before cpp17
+
+template <typename T, typename M>
+struct guard {
+ guard(M& mtx, T& val) : m_mtx{&mtx}, m_val{val} {
+ m_mtx->lock();
+ }
+
+ ~guard() {
+ if (m_mtx) {
+ m_mtx->unlock();
+ }
+ }
+
+ T& operator*() {
+ assert(m_mtx != nullptr);
+ return m_val;
+ }
+
+ T* operator->() {
+ if (!m_mtx) {
+ return nullptr;
+ }
+ return &m_val;
+ }
+
+ guard(const guard&) = delete;
+ // Implement move constructor for cases where the compiler does no copy
+ // elision.
+ // For API compatibility with the cpp17 version, the move constructor
+ // should not be explicitly invoked by the user.
+ //
+ // SAFETY: Exclusive access to T is guaranteed as at any given time only a
+ // single *guard* instance is NOT moved.
+ //
+ // UB: The *guard* must not be moved across thread boundaries and dropped
+ // there.
+ guard(guard&& rhs) noexcept : m_mtx{rhs.m_mtx}, m_val{rhs.m_val} {
+ rhs.m_mtx = nullptr;
+ }
+
+ private:
+ M* m_mtx;
+ T& m_val;
+};
+#endif
+
+#endif
diff --git a/test/owning_mutex.cc b/test/owning_mutex.cc
new file mode 100644
index 0000000..060415f
--- /dev/null
+++ b/test/owning_mutex.cc
@@ -0,0 +1,41 @@
+#include <owning_mutex.h>
+
+#include <limits>
+#include <thread>
+#include <vector>
+
+#include <cassert>
+#include <cstdio>
+
+constexpr unsigned kNumThreads = 8;
+constexpr unsigned kIter = 1 << 18;
+
+static_assert((static_cast<unsigned long>(kNumThreads) *
+ static_cast<unsigned long>(kIter)) <=
+ std::numeric_limits<unsigned>::max(),
+ "Expectate result overflowed!");
+
+int main() {
+ owning_mutex<unsigned> data(0u);
+
+ std::vector<std::thread> threads;
+ threads.reserve(kNumThreads);
+
+ for (unsigned t = 0; t < kNumThreads; ++t) {
+ threads.emplace_back([&data, t]() {
+ for (unsigned i = 0; i < kIter; ++i) {
+ *data.lock() += 1;
+ }
+ std::printf("th%u finished\n", t);
+ });
+ }
+
+ for (auto& th : threads) {
+ th.join();
+ }
+
+ assert(*data.lock() == (kNumThreads * kIter));
+ std::printf("Result %u\n", *data.lock());
+
+ return 0;
+}