// Copyright 2020 yuzu emulator team // Licensed under GPLv2 or any later version // Refer to the license.txt file included. #pragma once #include <atomic> #include <condition_variable> #include <functional> #include <mutex> #include <stop_token> #include <string> #include <thread> #include <type_traits> #include <vector> #include <queue> #include "common/thread.h" #include "common/unique_function.h" namespace Common { template <class StateType = void> class StatefulThreadWorker { static constexpr bool with_state = !std::is_same_v<StateType, void>; struct DummyCallable { int operator()() const noexcept { return 0; } }; using Task = std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>; using StateMaker = std::conditional_t<with_state, std::function<StateType()>, DummyCallable>; public: explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {}) : workers_queued{num_workers}, thread_name{std::move(name)} { const auto lambda = [this, func](std::stop_token stop_token) { Common::SetCurrentThreadName(thread_name.c_str()); { [[maybe_unused]] std::conditional_t<with_state, StateType, int> state{func()}; while (!stop_token.stop_requested()) { Task task; { std::unique_lock lock{queue_mutex}; if (requests.empty()) { wait_condition.notify_all(); } condition.wait(lock, stop_token, [this] { return !requests.empty(); }); if (stop_token.stop_requested()) { break; } task = std::move(requests.front()); requests.pop(); } if constexpr (with_state) { task(&state); } else { task(); } ++work_done; } } ++workers_stopped; wait_condition.notify_all(); }; threads.reserve(num_workers); for (size_t i = 0; i < num_workers; ++i) { threads.emplace_back(lambda); } } StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete; StatefulThreadWorker(const StatefulThreadWorker&) = delete; StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete; StatefulThreadWorker(StatefulThreadWorker&&) = delete; void QueueWork(Task work) { { std::unique_lock lock{queue_mutex}; requests.emplace(std::move(work)); ++work_scheduled; } condition.notify_one(); } void WaitForRequests(std::stop_token stop_token = {}) { std::stop_callback callback(stop_token, [this] { for (auto& thread : threads) { thread.request_stop(); } }); std::unique_lock lock{queue_mutex}; wait_condition.wait(lock, [this] { return workers_stopped >= workers_queued || work_done >= work_scheduled; }); } private: std::queue<Task> requests; std::mutex queue_mutex; std::condition_variable_any condition; std::condition_variable wait_condition; std::atomic<size_t> work_scheduled{}; std::atomic<size_t> work_done{}; std::atomic<size_t> workers_stopped{}; std::atomic<size_t> workers_queued{}; std::string thread_name; std::vector<std::jthread> threads; }; using ThreadWorker = StatefulThreadWorker<>; } // namespace Common