From e882dfb8830fae49eccd3d61a80d5d3d33e06463 Mon Sep 17 00:00:00 2001 From: Alexander Krimm Date: Wed, 21 Jun 2023 18:17:22 +0200 Subject: [PATCH] thread_pool: task type as member instead of NTTP simplifies thread pool management --- bench/bm_scheduler.cpp | 4 ++-- include/scheduler.hpp | 18 ++++++++++-------- include/thread_pool.hpp | 23 ++++++++++------------- test/qa_hier_node.cpp | 2 +- test/qa_scheduler.cpp | 2 +- test/qa_settings.cpp | 2 +- test/qa_thread_pool.cpp | 10 +++++----- 7 files changed, 30 insertions(+), 31 deletions(-) diff --git a/bench/bm_scheduler.cpp b/bench/bm_scheduler.cpp index 2775d585b..6773e6e02 100644 --- a/bench/bm_scheduler.cpp +++ b/bench/bm_scheduler.cpp @@ -107,10 +107,10 @@ void exec_bm(auto& scheduler, const std::string& test_case) { [[maybe_unused]] inline const boost::ut::suite scheduler_tests = [] { using namespace boost::ut; using namespace benchmark; - using thread_pool = fair::thread_pool::BasicThreadPool; + using thread_pool = fair::thread_pool::BasicThreadPool; using fg::scheduler::execution_policy::multi_threaded; - auto pool = std::make_shared("custom-pool", 2, 2); + auto pool = std::make_shared("custom-pool", fair::thread_pool::CPU_BOUND, 2, 2); fg::scheduler::simple sched1(test_graph_linear(2 * N_NODES), pool); "linear graph - simple scheduler"_benchmark.repeat(N_SAMPLES) = [&sched1]() { diff --git a/include/scheduler.hpp b/include/scheduler.hpp index 8744d0676..4817e5c1f 100644 --- a/include/scheduler.hpp +++ b/include/scheduler.hpp @@ -90,16 +90,17 @@ void run_on_pool(std::span job, std::size_t n_batches, std::atomic_ /** * Trivial loop based scheduler, which iterates over all nodes in definition order in the graph until no node did any processing */ -template> -class simple : public node>{ +template +class simple : public node>{ using node_t = node_model*; + using thread_pool_type = thread_pool::BasicThreadPool; init_proof _init; fair::graph::graph _graph; std::shared_ptr _pool; std::vector> _job_lists{}; public: - explicit simple(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared>("simple-scheduler-pool")) - : _init{fair::graph::scheduler::init(graph)}, _graph(std::move(graph)), _pool(thread_pool) { + explicit simple(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared("simple-scheduler-pool", thread_pool::CPU_BOUND)) + : _init{fair::graph::scheduler::init(graph)}, _graph(std::move(graph)), _pool(std::move(thread_pool)) { // generate job list const auto n_batches = std::min(static_cast(_pool->maxThreads()), _graph.blocks().size()); _job_lists.reserve(n_batches); @@ -150,17 +151,18 @@ class simple : public node>{ * Breadth first traversal scheduler which traverses the graph starting from the source nodes in a breath first fashion * detecting cycles and nodes which can be reached from several source nodes. */ -template> -class breadth_first : public node> { +template +class breadth_first : public node> { using node_t = node_model*; + using thread_pool_type = thread_pool::BasicThreadPool; init_proof _init; fair::graph::graph _graph; std::vector _nodelist; std::vector> _job_lists; std::shared_ptr _pool; public: - explicit breadth_first(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared>("breadth-first-pool")) - : _init{fair::graph::scheduler::init(graph)}, _graph(std::move(graph)), _pool(thread_pool) { + explicit breadth_first(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared("breadth-first-pool", thread_pool::CPU_BOUND)) + : _init{fair::graph::scheduler::init(graph)}, _graph(std::move(graph)), _pool(std::move(thread_pool)) { std::map> _adjacency_list{}; std::vector _source_nodes{}; // compute the adjacency list diff --git a/include/thread_pool.hpp b/include/thread_pool.hpp index d4aedb662..d1c7184bb 100644 --- a/include/thread_pool.hpp +++ b/include/thread_pool.hpp @@ -321,7 +321,6 @@ concept ThreadPool = requires(T t, std::function &&func) { * } * @endcode */ -template class BasicThreadPool { using Task = thread_pool::detail::Task; using TaskQueue = thread_pool::detail::TaskQueue; @@ -351,6 +350,7 @@ class BasicThreadPool { int _schedulingPriority = 0; const std::string _poolName; + const TaskType _taskType; const uint32_t _minThreads; const uint32_t _maxThreads; @@ -358,8 +358,8 @@ class BasicThreadPool { std::chrono::microseconds sleepDuration = std::chrono::milliseconds(1); std::chrono::milliseconds keepAliveDuration = std::chrono::seconds(10); - BasicThreadPool(const std::string_view &name = generateName(), uint32_t min = std::thread::hardware_concurrency(), uint32_t max = std::thread::hardware_concurrency()) - : _poolName(name), _minThreads(min), _maxThreads(max) { + BasicThreadPool(const std::string_view &name = generateName(), const TaskType taskType = TaskType::CPU_BOUND, uint32_t min = std::thread::hardware_concurrency(), uint32_t max = std::thread::hardware_concurrency()) + : _poolName(name), _taskType(taskType), _minThreads(min), _maxThreads(max) { assert(min > 0 && "minimum number of threads must be > 0"); assert(min <= max && "minimum number of threads must be <= maximum number of threads"); for (uint32_t i = 0; i < _minThreads; ++i) { @@ -376,9 +376,9 @@ class BasicThreadPool { } BasicThreadPool(const BasicThreadPool &) = delete; - BasicThreadPool(BasicThreadPool &&) = default; + BasicThreadPool(BasicThreadPool &&) = delete; BasicThreadPool &operator=(const BasicThreadPool &) = delete; - BasicThreadPool &operator=(BasicThreadPool &&) = default; + BasicThreadPool &operator=(BasicThreadPool &&) = delete; [[nodiscard]] std::string poolName() const noexcept { return _poolName; } [[nodiscard]] uint32_t minThreads() const noexcept { return _minThreads; }; @@ -435,7 +435,7 @@ class BasicThreadPool { _taskQueue.push(createTask(std::forward(func), std::forward(args)...)); _condition.notify_one(); - if constexpr (taskType == TaskType::IO_BOUND) { + if (_taskType == TaskType::IO_BOUND) { spinWait.spinOnce(); spinWait.spinOnce(); while (_taskQueue.size() > 0) { @@ -493,7 +493,7 @@ class BasicThreadPool { thread::setThreadName(fmt::format("{}#{}", _poolName, threadID), thread); thread::setThreadSchedulingParameter(_schedulingPolicy, _schedulingPriority, thread); if (!_affinityMask.empty()) { - if (taskType == TaskType::IO_BOUND) { + if (_taskType == TaskType::IO_BOUND) { thread::setThreadAffinity(_affinityMask); return; } @@ -636,12 +636,9 @@ class BasicThreadPool { } while (running); } }; -template -inline std::atomic BasicThreadPool::_globalPoolId = 0U; -template -inline std::atomic BasicThreadPool::_taskID = 0U; -static_assert(ThreadPool>); -static_assert(ThreadPool>); +inline std::atomic BasicThreadPool::_globalPoolId = 0U; +inline std::atomic BasicThreadPool::_taskID = 0U; +static_assert(ThreadPool); } diff --git a/test/qa_hier_node.cpp b/test/qa_hier_node.cpp index d8cdf73c8..da5f9c266 100644 --- a/test/qa_hier_node.cpp +++ b/test/qa_hier_node.cpp @@ -183,7 +183,7 @@ make_graph(std::size_t events_count) { int main() { - auto thread_pool = std::make_shared>("custom pool", 2,2); + auto thread_pool = std::make_shared("custom pool", fair::thread_pool::CPU_BOUND, 2,2); fg::scheduler::simple scheduler(make_graph(10), thread_pool); diff --git a/test/qa_scheduler.cpp b/test/qa_scheduler.cpp index 398e4c1ab..d7ff8c493 100644 --- a/test/qa_scheduler.cpp +++ b/test/qa_scheduler.cpp @@ -207,7 +207,7 @@ void check_node_names(std::vector joblist, std::set set) const boost::ut::suite SchedulerTests = [] { using namespace boost::ut; using namespace fair::graph; - auto thread_pool = std::make_shared>("custom pool", 2,2); + auto thread_pool = std::make_shared("custom pool", fair::thread_pool::CPU_BOUND, 2, 2); "SimpleScheduler_linear"_test = [&thread_pool] { using scheduler = fair::graph::scheduler::simple<>; diff --git a/test/qa_settings.cpp b/test/qa_settings.cpp index 09ef3e05f..58e0c8599 100644 --- a/test/qa_settings.cpp +++ b/test/qa_settings.cpp @@ -238,7 +238,7 @@ const boost::ut::suite SettingsTests = [] { expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(block1).to<"in">(block2))); expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(block2).to<"in">(sink))); - auto thread_pool = std::make_shared>("custom pool", 2,2); + auto thread_pool = std::make_shared("custom pool", fair::thread_pool::CPU_BOUND, 2, 2); fair::graph::scheduler::simple sched{ std::move(flow_graph), thread_pool }; expect(src.settings().auto_update_parameters().contains("sample_rate")); std::ignore = src.settings().set({ { "sample_rate", 49000.0f } }); diff --git a/test/qa_thread_pool.cpp b/test/qa_thread_pool.cpp index 9f5258a94..00e056ce1 100644 --- a/test/qa_thread_pool.cpp +++ b/test/qa_thread_pool.cpp @@ -11,12 +11,12 @@ const boost::ut::suite ThreadPoolTests = [] { using namespace boost::ut; "Basic ThreadPool tests"_test = [] { - expect(nothrow([] { fair::thread_pool::BasicThreadPool(); })); - expect(nothrow([] { fair::thread_pool::BasicThreadPool(); })); + expect(nothrow([] { fair::thread_pool::BasicThreadPool("test", fair::thread_pool::IO_BOUND); })); + expect(nothrow([] { fair::thread_pool::BasicThreadPool("test2", fair::thread_pool::CPU_BOUND); })); std::atomic enqueueCount{0}; std::atomic executeCount{0}; - fair::thread_pool::BasicThreadPool pool("TestPool", 1, 2); + fair::thread_pool::BasicThreadPool pool("TestPool", fair::thread_pool::IO_BOUND, 1, 2); expect(nothrow([&] { pool.sleepDuration = std::chrono::milliseconds(1); })); expect(nothrow([&] { pool.keepAliveDuration = std::chrono::seconds(10); })); pool.waitUntilInitialised(); @@ -60,7 +60,7 @@ const boost::ut::suite ThreadPoolTests = [] { }; "contention tests"_test = [] { std::atomic counter{0}; - fair::thread_pool::BasicThreadPool pool("contention", 1, 4); + fair::thread_pool::BasicThreadPool pool("contention", fair::thread_pool::IO_BOUND, 1, 4); pool.waitUntilInitialised(); expect(that % pool.isInitialised()); expect(pool.numThreads() == 1_u); @@ -99,7 +99,7 @@ const boost::ut::suite ThreadPoolTests = [] { std::atomic counter{0}; // Pool with min and max thread count - fair::thread_pool::BasicThreadPool pool("count_test", minThreads, maxThreads); + fair::thread_pool::BasicThreadPool pool("count_test", fair::thread_pool::IO_BOUND, minThreads, maxThreads); pool.keepAliveDuration = std::chrono::milliseconds(10); // default is 10 seconds, reducing for testing pool.waitUntilInitialised();