diff --git a/bench/bm_case1.cpp b/bench/bm_case1.cpp index 22138d48..bba76f39 100644 --- a/bench/bm_case1.cpp +++ b/bench/bm_case1.cpp @@ -471,7 +471,7 @@ invoke_work(auto &sched) { using namespace benchmark; test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - sched.work(); + sched.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; } @@ -598,7 +598,7 @@ inline const boost::ut::suite _runtime_tests = [] { ::benchmark::benchmark<1LU>{ test_name }.repeat(N_SAMPLES) = [&sched]() { test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - sched.work(); + sched.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; }; @@ -629,7 +629,7 @@ inline const boost::ut::suite _simd_tests = [] { "runtime src->mult(2.0)->mult(0.5)->add(-1)->sink (SIMD)"_benchmark.repeat(N_SAMPLES) = [&sched]() { test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - sched.work(); + sched.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; }; @@ -665,7 +665,8 @@ inline const boost::ut::suite _simd_tests = [] { "runtime src->(mult(2.0)->mult(0.5)->add(-1))^10->sink (SIMD)"_benchmark.repeat(N_SAMPLES) = [&sched]() { test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - sched.work(); + sched.run_and_wait(); + sched.reset(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; }; @@ -693,7 +694,7 @@ inline const boost::ut::suite _sample_by_sample_vs_bulk_access_tests = [] { test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; fg::scheduler::simple sched{ std::move(flow_graph) }; - sched.work(); + sched.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; }; @@ -719,7 +720,7 @@ inline const boost::ut::suite _sample_by_sample_vs_bulk_access_tests = [] { ::benchmark::benchmark<1LU>{ test_name }.repeat(N_SAMPLES) = [&sched]() { test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - sched.work(); + sched.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; }; diff --git a/bench/bm_filter.cpp b/bench/bm_filter.cpp index 4f92bff0..3f0bbe44 100644 --- a/bench/bm_filter.cpp +++ b/bench/bm_filter.cpp @@ -36,7 +36,7 @@ invoke_work(auto &sched) { using namespace benchmark; test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - sched.work(); + sched.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << "did not produce enough output samples"; expect(eq(test::n_samples_consumed, N_SAMPLES)) << "did not consume enough input samples"; } diff --git a/bench/bm_scheduler.cpp b/bench/bm_scheduler.cpp index bdafeda8..2df37bab 100644 --- a/bench/bm_scheduler.cpp +++ b/bench/bm_scheduler.cpp @@ -10,6 +10,7 @@ namespace fg = fair::graph; inline constexpr std::size_t N_ITER = 10; inline constexpr std::size_t N_SAMPLES = gr::util::round_up(10'000'000, 1024); +inline constexpr std::size_t N_NODES = 5; template class math_op : public fg::node, fg::IN, fg::OUT> { @@ -98,34 +99,59 @@ void exec_bm(auto& scheduler, const std::string& test_case) { using namespace benchmark; test::n_samples_produced = 0LU; test::n_samples_consumed = 0LU; - scheduler.work(); + scheduler.run_and_wait(); expect(eq(test::n_samples_produced, N_SAMPLES)) << fmt::format("did not produce enough output samples for {}", test_case); expect(ge(test::n_samples_consumed, N_SAMPLES)) << fmt::format("did not consume enough input samples for {}", test_case); + scheduler.reset(); } -[[maybe_unused]] inline const boost::ut::suite _scheduler = [] { +[[maybe_unused]] inline const boost::ut::suite scheduler_tests = [] { using namespace boost::ut; using namespace benchmark; - - fg::scheduler::simple sched1(test_graph_linear(10)); - "linear graph - simple scheduler"_benchmark.repeat(N_SAMPLES) = [&sched1]() { - exec_bm(sched1, "linear-graph simple-sched"); - }; - - fg::scheduler::breadth_first sched2(test_graph_linear(10)); - "linear graph - BFS scheduler"_benchmark.repeat(N_SAMPLES) = [&sched2]() { - exec_bm(sched2, "linear-graph BFS-sched"); - }; - - fg::scheduler::simple sched3(test_graph_bifurcated(5)); - "bifurcated graph - simple scheduler"_benchmark.repeat(N_SAMPLES) = [&sched3]() { - exec_bm(sched3, "bifurcated-graph simple-sched"); - }; - - fg::scheduler::breadth_first sched4(test_graph_bifurcated(5)); - "bifurcated graph - BFS scheduler"_benchmark.repeat(N_SAMPLES) = [&sched4]() { - exec_bm(sched4, "bifurcated-graph BFS-sched"); - }; + using thread_pool = fair::thread_pool::BasicThreadPool; + using fg::scheduler::execution_policy::multi_threaded; + + auto pool = std::make_shared("custom-pool", fair::thread_pool::CPU_BOUND, 2, 2); + + fg::scheduler::simple sched1(test_graph_linear(2 * N_NODES), pool); + "linear graph - simple scheduler"_benchmark.repeat(N_SAMPLES) = [&sched1]() { + exec_bm(sched1, "linear-graph simple-sched"); + }; + + fg::scheduler::breadth_first sched2(test_graph_linear(2 * N_NODES), pool); + "linear graph - BFS scheduler"_benchmark.repeat(N_SAMPLES) = [&sched2]() { + exec_bm(sched2, "linear-graph BFS-sched"); + }; + + fg::scheduler::simple sched3(test_graph_bifurcated(N_NODES), pool); + "bifurcated graph - simple scheduler"_benchmark.repeat(N_SAMPLES) = [&sched3]() { + exec_bm(sched3, "bifurcated-graph simple-sched"); + }; + + fg::scheduler::breadth_first sched4(test_graph_bifurcated(N_NODES), pool); + "bifurcated graph - BFS scheduler"_benchmark.repeat(N_SAMPLES) = [&sched4]() { + exec_bm(sched4, "bifurcated-graph BFS-sched"); + }; + + fg::scheduler::simple sched1_mt(test_graph_linear(2 * N_NODES), pool); + "linear graph - simple scheduler (multi-threaded)"_benchmark.repeat(N_SAMPLES) = [&sched1_mt]() { + exec_bm(sched1_mt, "linear-graph simple-sched (multi-threaded)"); + }; + + fg::scheduler::breadth_first sched2_mt(test_graph_linear(2 * N_NODES), pool); + "linear graph - BFS scheduler (multi-threaded)"_benchmark.repeat(N_SAMPLES) = [&sched2_mt]() { + exec_bm(sched2_mt, "linear-graph BFS-sched (multi-threaded)"); + }; + + fg::scheduler::simple sched3_mt(test_graph_bifurcated(N_NODES), pool); + "bifurcated graph - simple scheduler (multi-threaded)"_benchmark.repeat(N_SAMPLES) = [&sched3_mt]() { + exec_bm(sched3_mt, "bifurcated-graph simple-sched (multi-threaded)"); + }; + + fg::scheduler::breadth_first sched4_mt(test_graph_bifurcated(N_NODES), pool); + "bifurcated graph - BFS scheduler (multi-threaded)"_benchmark.repeat(N_SAMPLES) = [&sched4_mt]() { + exec_bm(sched4_mt, "bifurcated-graph BFS-sched (multi-threaded)"); + }; }; int diff --git a/include/graph.hpp b/include/graph.hpp index 35cd6c9c..443e5f0c 100644 --- a/include/graph.hpp +++ b/include/graph.hpp @@ -629,9 +629,9 @@ class edge { class graph { private: - std::vector> _connection_definitions; - std::vector> _nodes; - std::vector _edges; + std::vector> _connection_definitions; + std::vector> _nodes; + std::vector _edges; template std::unique_ptr & @@ -705,8 +705,8 @@ class graph { if (!is_node_known(source) || !is_node_known(destination)) { throw fmt::format("Source {} and/or destination {} do not belong to this graph\n", source.name(), destination.name()); } - self._connection_definitions.push_back([self = &self, source = &source, source_port = &port, destination = &destination, destination_port = &destination_port]() { - return self->connect_impl(*source, *source_port, *destination, *destination_port); + self._connection_definitions.push_back([source = &source, source_port = &port, destination = &destination, destination_port = &destination_port](graph &graph) { + return graph.connect_impl(*source, *source_port, *destination, *destination_port); }); return connection_result_t::SUCCESS; } @@ -761,6 +761,11 @@ class graph { connect(Source &source, Port Source::*member_ptr); public: + graph(graph&) = delete; + graph(graph&&) = default; + graph() = default; + graph &operator=(graph&) = delete; + graph &operator=(graph&&) = default; /** * @return a list of all blocks contained in this graph * N.B. some 'blocks' may be (sub-)graphs themselves @@ -846,7 +851,7 @@ class graph { return result; } - const std::vector> & + const std::vector> & connection_definitions() { return _connection_definitions; } diff --git a/include/scheduler.hpp b/include/scheduler.hpp index 0f02925d..b58e51e3 100644 --- a/include/scheduler.hpp +++ b/include/scheduler.hpp @@ -3,69 +3,246 @@ #include #include #include +#include +#include +#include namespace fair::graph::scheduler { +using fair::thread_pool::BasicThreadPool; -struct init_proof { - init_proof(bool _success) : success(_success) {} +enum execution_policy { single_threaded, multi_threaded }; +enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN, ERROR}; - init_proof(init_proof &&init) : init_proof(init.success) {} +class scheduler_base { +protected: + SchedulerState _state = IDLE; + fair::graph::graph _graph; + std::shared_ptr _pool; + std::atomic_uint64_t _progress; + std::atomic_size_t _running_threads; + std::atomic_bool _stop_requested; - bool success = true; +public: + explicit scheduler_base(fair::graph::graph &&graph, + std::shared_ptr thread_pool = std::make_shared("simple-scheduler-pool", thread_pool::CPU_BOUND)) + : _graph(std::move(graph)), _pool(std::move(thread_pool)){}; - init_proof & - operator=(init_proof &&init) noexcept { - this->success = init; - return *this; + ~scheduler_base() { + stop(); + _state = SHUTTING_DOWN; } - operator bool() const { return success; } -}; + void stop(){ + if (_state == STOPPED || _state == ERROR){ + return; + } + if (_state == RUNNING) { + request_stop(); + } + wait_done(); + _state = STOPPED; + } -init_proof -init(fair::graph::graph &graph) { - auto result = init_proof(std::all_of(graph.connection_definitions().begin(), graph.connection_definitions().end(), - [](auto &connection_definition) { return connection_definition() == connection_result_t::SUCCESS; })); - graph.clear_connection_definitions(); - return result; -} + void pause(){ + if (_state == PAUSED || _state == ERROR){ + return; + } + if (_state == RUNNING) { + request_pause(); + } + wait_done(); + _state = PAUSED; + } + + void wait_done() { + for (auto running = _running_threads.load(); running > 0ul; running = _running_threads.load()) { + _running_threads.wait(running); + } + _state = STOPPED; + } + + void request_stop() { + _stop_requested = true; + _state = REQUESTED_STOP; + } + + void request_pause() { + _stop_requested = true; + _state = REQUESTED_PAUSE; + } + + void + init() { + if (_state != IDLE) { + return; + } + auto result = std::all_of(_graph.connection_definitions().begin(), _graph.connection_definitions().end(), + [this](auto &connection_definition) { return connection_definition(_graph) == connection_result_t::SUCCESS; }); + if (result) { + _graph.clear_connection_definitions(); + _state = INITIALISED; + } else { + _state = ERROR; + } + } + + void + reset() { + // since it is not possible to setup the graph connections a second time, this method leaves the graph in the initialized state with clear buffers. + switch (_state) { + case IDLE: + init(); + break; + case RUNNING: + case REQUESTED_STOP: + case REQUESTED_PAUSE: + pause(); + // intentional fallthrough + case STOPPED: + case PAUSED: + // clear buffers + // std::for_each(_graph.edges().begin(), _graph.edges().end(), [](auto &edge) { + // + // }); + _state = INITIALISED; + break; + case SHUTTING_DOWN: + case INITIALISED: + case ERROR: + break; + } + } + + void + run_on_pool(const std::vector> &jobs, const std::function&)> work_function) { + _progress = 0; + _running_threads = jobs.size(); + for (auto &jobset : jobs) { + _pool->execute([this, &jobset, work_function, &jobs]() { + pool_worker([&work_function, &jobset]() { + return work_function(jobset); + }, jobs.size()); + }); + } + } + + void + pool_worker(const std::function& work, std::size_t n_batches) { + uint32_t done = 0; + uint32_t progress_count = 0; + while (done < n_batches && !_stop_requested) { + bool something_happened = work() == work_return_t::OK; + uint64_t progress_local, progress_new; + if (something_happened) { // something happened in this thread => increase progress and reset done count + do { + progress_local = _progress.load(); + progress_count = static_cast((progress_local >> 32) & ((1ULL << 32) - 1)); + done = static_cast(progress_local & ((1ULL << 32) - 1)); + progress_new = (progress_count + 1ULL) << 32; + } while (!_progress.compare_exchange_strong(progress_local, progress_new)); + _progress.notify_all(); + } else { // nothing happened on this thread + uint32_t progress_count_old = progress_count; + do { + progress_local = _progress.load(); + progress_count = static_cast((progress_local >> 32) & ((1ULL << 32) - 1)); + done = static_cast(progress_local & ((1ULL << 32) - 1)); + if (progress_count == progress_count_old) { // nothing happened => increase done count + progress_new = ((progress_count + 0ULL) << 32) + done + 1; + } else { // something happened in another thread => keep progress and done count and rerun this task without waiting + progress_new = ((progress_count + 0ULL) << 32) + done; + } + } while (!_progress.compare_exchange_strong(progress_local, progress_new)); + _progress.notify_all(); + if (progress_count == progress_count_old && done < n_batches) { + _progress.wait(progress_new); + } + } + } // while (done < n_batches) + _running_threads.fetch_sub(1); + _running_threads.notify_all(); + } + +}; /** * Trivial loop based scheduler, which iterates over all nodes in definition order in the graph until no node did any processing */ -class simple : public node { - init_proof _init; - fair::graph::graph _graph; - +template +class simple : public scheduler_base { + std::vector> _job_lists{}; public: - explicit simple(fair::graph::graph &&graph) : _init{ fair::graph::scheduler::init(graph) }, _graph(std::move(graph)) {} + explicit simple(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared("simple-scheduler-pool", thread_pool::CPU_BOUND)) + : scheduler_base(std::forward(graph), thread_pool) { } + void + init() { + scheduler_base::init(); + // generate job list + if constexpr (executionPolicy == multi_threaded) { + const auto n_batches = std::min(static_cast(this->_pool->maxThreads()), this->_graph.blocks().size()); + this->_job_lists.reserve(n_batches); + for (std::size_t i = 0; i < n_batches; i++) { + // create job-set for thread + auto &job = this->_job_lists.emplace_back(); + job.reserve(this->_graph.blocks().size() / n_batches + 1); + for (std::size_t j = i; j < this->_graph.blocks().size(); j += n_batches) { + job.push_back(this->_graph.blocks()[j].get()); + } + } + } + } + + template work_return_t - work() { - if (!_init) { - return work_return_t::ERROR; - } - bool run = true; - while (run) { - bool something_happened = false; - for (const auto ¤t_node : _graph.blocks()) { - auto result = current_node->work(); + work_once(const std::span &nodes) { + bool something_happened = false; + for (auto ¤tNode : nodes) { + auto result = currentNode->work(); + if (result == work_return_t::ERROR) { + return work_return_t::ERROR; + } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) { + // nothing + } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) { + something_happened = true; + } + } + return something_happened ? work_return_t::OK : work_return_t::DONE; + } + + // todo: could be moved to base class, but would make `start()` virtual or require CRTP + // todo: iterate api for continuous flowgraphs vs ones that become "DONE" at some point + void + run_and_wait() { + start(); + this->wait_done(); + } + + void + start() { + if (this->_state == IDLE) { + this->init(); + } + if (this->_state != INITIALISED) { + fmt::print("simple scheduler work(): graph not initialised"); + return; + } + if constexpr (executionPolicy == single_threaded) { + this->_state = RUNNING; + work_return_t result; + auto nodelist = std::span{this->_graph.blocks()}; + while ((result = work_once(nodelist)) == work_return_t::OK) { if (result == work_return_t::ERROR) { - return work_return_t::ERROR; - } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS) { - // nothing - } else if (result == work_return_t::DONE) { - // nothing - } else if (result == work_return_t::OK) { - something_happened = true; - } else if (result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) { - something_happened = true; + this->_state = ERROR; + return; } } - run = something_happened; + this->_state = STOPPED; + } else if (executionPolicy == multi_threaded) { + this->run_on_pool(this->_job_lists, [this](auto &job) {return this->work_once(job);} ); + } else { + throw std::invalid_argument("Unknown execution Policy"); } - - return work_return_t::DONE; } }; @@ -73,19 +250,25 @@ class simple : public node { * Breadth first traversal scheduler which traverses the graph starting from the source nodes in a breath first fashion * detecting cycles and nodes which can be reached from several source nodes. */ -class breadth_first : public node { - using node_t = fair::graph::node_model *; - init_proof _init; - fair::graph::graph _graph; - std::vector _nodelist; - +template +class breadth_first : public scheduler_base { + std::vector _nodelist; + std::vector> _job_lists{}; public: - explicit breadth_first(fair::graph::graph &&graph) : _init{ fair::graph::scheduler::init(graph) }, _graph(std::move(graph)) { + explicit breadth_first(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared("breadth-first-pool", thread_pool::CPU_BOUND)) + : scheduler_base(std::move(graph), thread_pool) { + } + + void + init() { + using node_t = node_model *; + scheduler_base::init(); + // calculate adjacency list std::map> _adjacency_list{}; std::vector _source_nodes{}; // compute the adjacency list std::set node_reached; - for (auto &e : _graph.edges()) { + for (auto &e : this->_graph.edges()) { _adjacency_list[e._src_node].push_back(e._dst_node); _source_nodes.push_back(e._src_node); node_reached.insert(e._dst_node); @@ -96,7 +279,9 @@ class breadth_first : public node { std::set reached; // add all source nodes to queue for (node_t source_node : _source_nodes) { - queue.push(source_node); + if (!reached.contains(source_node)) { + queue.push(source_node); + } reached.insert(source_node); } // process all nodes, adding all unvisited child nodes to the queue @@ -104,7 +289,7 @@ class breadth_first : public node { node_t current_node = queue.front(); queue.pop(); _nodelist.push_back(current_node); - if (_adjacency_list.contains(current_node)) { // current_node has outgoing edges + if (_adjacency_list.contains(current_node)) { // node has outgoing edges for (auto &dst : _adjacency_list.at(current_node)) { if (!reached.contains(dst)) { // detect cycles. this could be removed if we guarantee cycle free graphs earlier queue.push(dst); @@ -113,24 +298,74 @@ class breadth_first : public node { } } } + // generate job list + if constexpr (executionPolicy == multi_threaded) { + const auto n_batches = std::min(static_cast(this->_pool->maxThreads()), _nodelist.size()); + this->_job_lists.reserve(n_batches); + for (std::size_t i = 0; i < n_batches; i++) { + // create job-set for thread + auto &job = this->_job_lists.emplace_back(); + job.reserve(_nodelist.size() / n_batches + 1); + for (std::size_t j = i; j < _nodelist.size(); j += n_batches) { + job.push_back(_nodelist[j]); + } + } + } } + template work_return_t - work() { - if (!_init) { - return work_return_t::ERROR; - } - while (true) { - bool anything_happened = false; - for (auto current_node : _nodelist) { - auto res = current_node->work(); - anything_happened |= (res == work_return_t::OK || res == work_return_t::INSUFFICIENT_OUTPUT_ITEMS); + work_once(const std::span &nodes) { + bool something_happened = false; + for (auto ¤tNode : nodes) { + auto result = currentNode->work(); + if (result == work_return_t::ERROR) { + return work_return_t::ERROR; + } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) { + // nothing + } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) { + something_happened = true; } - if (!anything_happened) { - return work_return_t::DONE; + } + return something_happened ? work_return_t::OK : work_return_t::DONE; + } + + void + run_and_wait() { + start(); + this->wait_done(); + } + + void + start() { + if (this->_state == IDLE) { + this->init(); + } + if (this->_state != INITIALISED) { + fmt::print("simple scheduler work(): graph not initialised"); + return; + } + if constexpr (executionPolicy == single_threaded) { + this->_state = RUNNING; + work_return_t result; + auto nodelist = std::span{this->_nodelist}; + while ((result = work_once(nodelist)) == work_return_t::OK) { + if (result == work_return_t::ERROR) { + this->_state = ERROR; + return; + } } + this->_state = STOPPED; + } else if (executionPolicy == multi_threaded) { + this->run_on_pool(this->_job_lists, [this](auto &job) {return this->work_once(job);}); + } else { + throw std::invalid_argument("Unknown execution Policy"); } } + + [[nodiscard]] const std::vector> &getJobLists() const { + return _job_lists; + } }; } // namespace fair::graph::scheduler diff --git a/include/thread_pool.hpp b/include/thread_pool.hpp index 21fa045c..d1c7184b 100644 --- a/include/thread_pool.hpp +++ b/include/thread_pool.hpp @@ -321,7 +321,6 @@ concept ThreadPool = requires(T t, std::function &&func) { * } * @endcode */ -template class BasicThreadPool { using Task = thread_pool::detail::Task; using TaskQueue = thread_pool::detail::TaskQueue; @@ -351,6 +350,7 @@ class BasicThreadPool { int _schedulingPriority = 0; const std::string _poolName; + const TaskType _taskType; const uint32_t _minThreads; const uint32_t _maxThreads; @@ -358,8 +358,8 @@ class BasicThreadPool { std::chrono::microseconds sleepDuration = std::chrono::milliseconds(1); std::chrono::milliseconds keepAliveDuration = std::chrono::seconds(10); - BasicThreadPool(const std::string_view &name = generateName(), uint32_t min = std::thread::hardware_concurrency(), uint32_t max = std::thread::hardware_concurrency()) - : _poolName(name), _minThreads(min), _maxThreads(max) { + BasicThreadPool(const std::string_view &name = generateName(), const TaskType taskType = TaskType::CPU_BOUND, uint32_t min = std::thread::hardware_concurrency(), uint32_t max = std::thread::hardware_concurrency()) + : _poolName(name), _taskType(taskType), _minThreads(min), _maxThreads(max) { assert(min > 0 && "minimum number of threads must be > 0"); assert(min <= max && "minimum number of threads must be <= maximum number of threads"); for (uint32_t i = 0; i < _minThreads; ++i) { @@ -435,7 +435,7 @@ class BasicThreadPool { _taskQueue.push(createTask(std::forward(func), std::forward(args)...)); _condition.notify_one(); - if constexpr (taskType == TaskType::IO_BOUND) { + if (_taskType == TaskType::IO_BOUND) { spinWait.spinOnce(); spinWait.spinOnce(); while (_taskQueue.size() > 0) { @@ -493,7 +493,7 @@ class BasicThreadPool { thread::setThreadName(fmt::format("{}#{}", _poolName, threadID), thread); thread::setThreadSchedulingParameter(_schedulingPolicy, _schedulingPriority, thread); if (!_affinityMask.empty()) { - if (taskType == TaskType::IO_BOUND) { + if (_taskType == TaskType::IO_BOUND) { thread::setThreadAffinity(_affinityMask); return; } @@ -636,12 +636,9 @@ class BasicThreadPool { } while (running); } }; -template -inline std::atomic BasicThreadPool::_globalPoolId = 0U; -template -inline std::atomic BasicThreadPool::_taskID = 0U; -static_assert(ThreadPool>); -static_assert(ThreadPool>); +inline std::atomic BasicThreadPool::_globalPoolId = 0U; +inline std::atomic BasicThreadPool::_taskID = 0U; +static_assert(ThreadPool); } diff --git a/test/app_grc.cpp b/test/app_grc.cpp index 2dbd5f35..5b0d8510 100644 --- a/test/app_grc.cpp +++ b/test/app_grc.cpp @@ -54,7 +54,7 @@ main(int argc, char *argv[]) { assert(graph_saved_source + "\n" == graph_expected_source); fair::graph::scheduler::simple scheduler(std::move(graph)); - scheduler.work(); + scheduler.run_and_wait(); } // Test if we get the same graph when saving it and loading the saved diff --git a/test/qa_dynamic_port.cpp b/test/qa_dynamic_port.cpp index 9e7c067d..9b957a8d 100644 --- a/test/qa_dynamic_port.cpp +++ b/test/qa_dynamic_port.cpp @@ -169,7 +169,7 @@ const boost::ut::suite PortApiTests = [] { expect(eq(connection_result_t::SUCCESS, flow.connect<"sum">(added).to<"sink">(out))); fair::graph::scheduler::simple sched{ std::move(flow) }; - sched.work(); + sched.run_and_wait(); }; #ifdef ENABLE_DYNAMIC_PORTS diff --git a/test/qa_hier_node.cpp b/test/qa_hier_node.cpp index c782cc28..b79354b5 100644 --- a/test/qa_hier_node.cpp +++ b/test/qa_hier_node.cpp @@ -50,7 +50,7 @@ class hier_node : public fg::node_model { using in_port_t = fg::IN; - fg::scheduler::simple _scheduler; + fg::scheduler::simple<> _scheduler; fg::graph make_graph() { @@ -87,7 +87,8 @@ class hier_node : public fg::node_model { fg::work_return_t work() override { - return _scheduler.work(); + _scheduler.run_and_wait(); + return fair::graph::work_return_t::DONE; } void * @@ -194,7 +195,9 @@ make_graph(std::size_t events_count) { int main() { - fg::scheduler::simple scheduler(make_graph(10)); + auto thread_pool = std::make_shared("custom pool", fair::thread_pool::CPU_BOUND, 2,2); // use custom pool to limit number of threads for emscripten - scheduler.work(); + fg::scheduler::simple scheduler(make_graph(10), thread_pool); + + scheduler.run_and_wait(); } diff --git a/test/qa_scheduler.cpp b/test/qa_scheduler.cpp index 67d1cef5..14607f65 100644 --- a/test/qa_scheduler.cpp +++ b/test/qa_scheduler.cpp @@ -4,145 +4,157 @@ #if defined(__clang__) && __clang_major__ >= 16 // clang 16 does not like ut's default reporter_junit due to some issues with stream buffers and output redirection -template<> +template <> auto boost::ut::cfg = boost::ut::runner>{}; #endif -namespace fg = fair::graph; +namespace fg = fair::graph; -using trace_vector = std::vector; +using trace_vector_type = std::vector; +class tracer{ + std::mutex _trace_mutex; + trace_vector_type _trace_vector; +public: + void trace(std::string_view id) { + std::scoped_lock lock{ _trace_mutex }; + if (_trace_vector.empty() || _trace_vector.back() != id) { + _trace_vector.emplace_back(id); + } + } -static void -trace(trace_vector &traceVector, std::string_view id) { - if (traceVector.empty() || traceVector.back() != id) { - traceVector.emplace_back(id); + trace_vector_type get_vec() { + std::scoped_lock lock{_trace_mutex}; + return {_trace_vector}; } -} +}; // define some example graph nodes template class count_source : public fg::node, fg::OUT::max(), "out">> { - trace_vector &tracer; - std::size_t count = 0; - + tracer &_tracer; + std::size_t _count = 0; public: - count_source(trace_vector &trace, std::string_view name) : tracer{ trace } { this->_name = name; } + count_source(tracer &trace, std::string_view name) : _tracer{trace} { this->_name = name;} constexpr std::make_signed_t - available_samples(const count_source & /*d*/) noexcept { - const auto ret = static_cast>(N - count); + available_samples(const count_source &/*d*/) noexcept { + const auto ret = static_cast>(N - _count); return ret > 0 ? ret : -1; // '-1' -> DONE, produced enough samples } constexpr T process_one() { - trace(tracer, this->name()); - return static_cast(count++); + _tracer.trace(this->name()); + return static_cast(_count++); } }; -template -class expect_sink : public fg::node, fg::IN::max(), "in">> { - trace_vector &tracer; - std::int64_t count = 0; +template +class expect_sink : public fg::node, fg::IN::max(), "in">> { + tracer &_tracer; + std::int64_t _count = 0; std::function _checker; - public: - expect_sink(trace_vector &trace, std::string_view name, std::function &&checker) : tracer{ trace }, _checker(std::move(checker)) { this->_name = name; } + expect_sink(tracer &trace, std::string_view name, std::function &&checker) : _tracer{trace}, _checker(std::move(checker)) { this->_name = name;} + + ~expect_sink() { + boost::ut::expect(boost::ut::that % _count == N); + } [[nodiscard]] fg::work_return_t process_bulk(std::span input) noexcept { - trace(tracer, this->name()); - for (auto data : input) { - _checker(count, data); - count++; + _tracer.trace(this->name()); + for (auto data: input) { + _checker(_count, data); + _count++; } return fg::work_return_t::OK; } - constexpr void - process_one(T /*a*/) const noexcept { - trace(tracer, this->name()); + process_one(T /*a*/) noexcept { + _tracer.trace(this->name()); } }; template() * std::declval())> class scale : public fg::node, fg::IN::max(), "original">, fg::OUT::max(), "scaled">> { - trace_vector &tracer; - + tracer &_tracer; public: - scale(trace_vector &trace, std::string_view name) : tracer{ trace } { this->_name = name; } - + scale(tracer &trace, std::string_view name) : _tracer{trace} {this->_name = name;} template V> [[nodiscard]] constexpr auto - process_one(V a) const noexcept { - trace(tracer, this->name()); + process_one(V a) noexcept { + _tracer.trace(this->name()); return a * Scale; } }; template() + std::declval())> -class adder : public fg::node, fg::IN::max(), "addend0">, fg::IN::max(), "addend1">, - fg::OUT::max(), "sum">> { - trace_vector &tracer; - +class adder : public fg::node, fg::IN::max(), "addend0">, fg::IN::max(), "addend1">, fg::OUT::max(), "sum">> { + tracer &_tracer; public: - adder(trace_vector &trace, std::string_view name) : tracer{ trace } { this->_name = name; } - + adder(tracer &trace, std::string_view name) : _tracer(trace) {this->_name = name;} template V> [[nodiscard]] constexpr auto - process_one(V a, V b) const noexcept { - trace(tracer, this->name()); + process_one(V a, V b) noexcept { + _tracer.trace(this->name()); return a + b; } }; fair::graph::graph -get_graph_linear(trace_vector &traceVector) { +get_graph_linear(tracer &trace) { using fg::port_direction_t::INPUT; using fg::port_direction_t::OUTPUT; - // Nodes need to be alive for as long as the flow is +// Nodes need to be alive for as long as the flow is fg::graph flow; - // Generators - auto &source1 = flow.make_node>(traceVector, "s1"); - auto &scale_block1 = flow.make_node>(traceVector, "mult1"); - auto &scale_block2 = flow.make_node>(traceVector, "mult2"); - auto &sink = flow.make_node>(traceVector, "out", [](std::uint64_t count, std::uint64_t data) { boost::ut::expect(boost::ut::that % data == 8 * count); }); - - std::ignore = flow.connect<"scaled">(scale_block2).to<"in">(sink); - std::ignore = flow.connect<"scaled">(scale_block1).to<"original">(scale_block2); - std::ignore = flow.connect<"out">(source1).to<"original">(scale_block1); +// Generators + auto& source1 = flow.make_node>(trace, "s1"); + auto& scale_block1 = flow.make_node>(trace, "mult1"); + auto& scale_block2 = flow.make_node>(trace, "mult2"); + auto& sink = flow.make_node>(trace, "out", [](std::uint64_t count, std::uint64_t data) { + boost::ut::expect(boost::ut::that % data == 8 * count); + } ); + + std::ignore = flow.connect<"scaled">(scale_block2).to<"in">(sink); + std::ignore = flow.connect<"scaled">(scale_block1).to<"original">(scale_block2); + std::ignore = flow.connect<"out">(source1).to<"original">(scale_block1); return flow; } fair::graph::graph -get_graph_parallel(trace_vector &traceVector) { +get_graph_parallel(tracer &trace) { using fg::port_direction_t::INPUT; using fg::port_direction_t::OUTPUT; - // Nodes need to be alive for as long as the flow is +// Nodes need to be alive for as long as the flow is fg::graph flow; - // Generators - auto &source1 = flow.make_node>(traceVector, "s1"); - auto &scale_block1a = flow.make_node>(traceVector, "mult1a"); - auto &scale_block2a = flow.make_node>(traceVector, "mult2a"); - auto &sink_a = flow.make_node>(traceVector, "outa", [](std::uint64_t count, std::uint64_t data) { boost::ut::expect(boost::ut::that % data == 6 * count); }); - auto &scale_block1b = flow.make_node>(traceVector, "mult1b"); - auto &scale_block2b = flow.make_node>(traceVector, "mult2b"); - auto &sink_b = flow.make_node>(traceVector, "outb", [](std::uint64_t count, std::uint64_t data) { boost::ut::expect(boost::ut::that % data == 15 * count); }); - - std::ignore = flow.connect<"scaled">(scale_block1a).to<"original">(scale_block2a); - std::ignore = flow.connect<"scaled">(scale_block1b).to<"original">(scale_block2b); - std::ignore = flow.connect<"scaled">(scale_block2b).to<"in">(sink_b); - std::ignore = flow.connect<"out">(source1).to<"original">(scale_block1a); - std::ignore = flow.connect<"scaled">(scale_block2a).to<"in">(sink_a); - std::ignore = flow.connect<"out">(source1).to<"original">(scale_block1b); +// Generators + auto& source1 = flow.make_node>(trace, "s1"); + auto& scale_block1a = flow.make_node>(trace, "mult1a"); + auto& scale_block2a = flow.make_node>(trace, "mult2a"); + auto& sink_a = flow.make_node>(trace, "outa", [](std::uint64_t count, std::uint64_t data) { + boost::ut::expect(boost::ut::that % data == 6 * count); + } ); + auto& scale_block1b = flow.make_node>(trace, "mult1b"); + auto& scale_block2b = flow.make_node>(trace, "mult2b"); + auto& sink_b = flow.make_node>(trace, "outb", [](std::uint64_t count, std::uint64_t data) { + boost::ut::expect(boost::ut::that % data == 15 * count); + } ); + + std::ignore = flow.connect<"scaled">(scale_block1a).to<"original">(scale_block2a); + std::ignore = flow.connect<"scaled">(scale_block1b).to<"original">(scale_block2b); + std::ignore = flow.connect<"scaled">(scale_block2b).to<"in">(sink_b); + std::ignore = flow.connect<"out">(source1).to<"original">(scale_block1a); + std::ignore = flow.connect<"scaled">(scale_block2a).to<"in">(sink_a); + std::ignore = flow.connect<"out">(source1).to<"original">(scale_block1b); return flow; } + /** * sets up an example graph * ┌───────────┐ @@ -160,101 +172,169 @@ get_graph_parallel(trace_vector &traceVector) { * └───────────┘ */ fair::graph::graph -get_graph_scaled_sum(trace_vector &traceVector) { +get_graph_scaled_sum(tracer &trace) { using fg::port_direction_t::INPUT; using fg::port_direction_t::OUTPUT; - // Nodes need to be alive for as long as the flow is +// Nodes need to be alive for as long as the flow is fg::graph flow; - // Generators - auto &source1 = flow.make_node>(traceVector, "s1"); - auto &source2 = flow.make_node>(traceVector, "s2"); - auto &scale_block = flow.make_node>(traceVector, "mult"); - auto &add_block = flow.make_node>(traceVector, "add"); - auto &sink = flow.make_node>(traceVector, "out", [](std::uint64_t count, std::uint64_t data) { boost::ut::expect(boost::ut::that % data == (2 * count) + count); }); +// Generators + auto& source1 = flow.make_node>(trace, "s1"); + auto& source2 = flow.make_node>(trace, "s2"); + auto& scale_block = flow.make_node>(trace, "mult"); + auto& add_block = flow.make_node>(trace, "add"); + auto& sink = flow.make_node>(trace, "out", [](std::uint64_t count, std::uint64_t data) { + boost::ut::expect(boost::ut::that % data == (2 * count) + count); + } ); - std::ignore = flow.connect<"out">(source1).to<"original">(scale_block); - std::ignore = flow.connect<"scaled">(scale_block).to<"addend0">(add_block); - std::ignore = flow.connect<"out">(source2).to<"addend1">(add_block); - std::ignore = flow.connect<"sum">(add_block).to<"in">(sink); + std::ignore = flow.connect<"out">(source1).to<"original">(scale_block); + std::ignore = flow.connect<"scaled">(scale_block).to<"addend0">(add_block); + std::ignore = flow.connect<"out">(source2).to<"addend1">(add_block); + std::ignore = flow.connect<"sum">(add_block).to<"in">(sink); return flow; } +template +void check_node_names(const std::vector &joblist, std::set set) { + boost::ut::expect(boost::ut::that % joblist.size() == set.size()); + for (auto &node: joblist) { + boost::ut::expect(boost::ut::that % set.contains(std::string(node->name()))) << fmt::format("{} not in {}\n", node->name(), set); + } +} + const boost::ut::suite SchedulerTests = [] { using namespace boost::ut; using namespace fair::graph; - - "SimpleScheduler_linear"_test = [] { - using scheduler = fair::graph::scheduler::simple; - trace_vector t{}; - auto sched = scheduler{ get_graph_linear(t) }; - sched.work(); + auto thread_pool = std::make_shared("custom pool", fair::thread_pool::CPU_BOUND, 2, 2); + + "SimpleScheduler_linear"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::simple<>; + tracer trace{}; + auto sched = scheduler{get_graph_linear(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); expect(boost::ut::that % t.size() == 8u); - expect(boost::ut::that % t == trace_vector{ "s1", "mult1", "mult2", "out", "s1", "mult1", "mult2", "out" }); + expect(boost::ut::that % t == trace_vector_type{ "s1", "mult1", "mult2", "out", "s1", "mult1", "mult2", "out" }); }; - "BreadthFirstScheduler_linear"_test = [] { - using scheduler = fair::graph::scheduler::breadth_first; - trace_vector t{}; - auto sched = scheduler{ get_graph_linear(t) }; - sched.work(); + "BreadthFirstScheduler_linear"_test = [&] { + using scheduler = fair::graph::scheduler::breadth_first<>; + tracer trace{}; + auto sched = scheduler{get_graph_linear(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); expect(boost::ut::that % t.size() == 8u); - expect(boost::ut::that % t == trace_vector{ "s1", "mult1", "mult2", "out", "s1", "mult1", "mult2", "out" }); + expect(boost::ut::that % t == trace_vector_type{ "s1", "mult1", "mult2", "out", "s1", "mult1", "mult2", "out"}); }; - "SimpleScheduler_parallel"_test = [] { - using scheduler = fair::graph::scheduler::simple; - trace_vector t{}; - auto sched = scheduler{ get_graph_parallel(t) }; - sched.work(); + "SimpleScheduler_parallel"_test = [&] { + using scheduler = fair::graph::scheduler::simple<>; + tracer trace{}; + auto sched = scheduler{get_graph_parallel(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); expect(boost::ut::that % t.size() == 14u); - expect(boost::ut::that % t == trace_vector{ "s1", "mult1a", "mult2a", "outa", "mult1b", "mult2b", "outb", "s1", "mult1a", "mult2a", "outa", "mult1b", "mult2b", "outb" }); + expect(boost::ut::that % t == trace_vector_type{ "s1", "mult1a", "mult2a", "outa", "mult1b", "mult2b", "outb", "s1", "mult1a", "mult2a", "outa", "mult1b", "mult2b", "outb"}); }; - "BreadthFirstScheduler_parallel"_test = [] { - using scheduler = fair::graph::scheduler::breadth_first; - trace_vector t{}; - auto sched = scheduler{ get_graph_parallel(t) }; - sched.work(); + "BreadthFirstScheduler_parallel"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::breadth_first<>; + tracer trace{}; + auto sched = scheduler{get_graph_parallel(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); expect(boost::ut::that % t.size() == 14u); - expect(boost::ut::that % t - == trace_vector{ - "s1", - "mult1a", - "mult1b", - "mult2a", - "mult2b", - "outa", - "outb", - "s1", - "mult1a", - "mult1b", - "mult2a", - "mult2b", - "outa", - "outb", - }); + expect(boost::ut::that % t == trace_vector_type{"s1", "mult1a", "mult1b", "mult2a", "mult2b", "outa", "outb", "s1", "mult1a", "mult1b", "mult2a", "mult2b", "outa", "outb", }); }; - "SimpleScheduler_scaled_sum"_test = [] { - using scheduler = fair::graph::scheduler::simple; + "SimpleScheduler_scaled_sum"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::simple<>; // construct an example graph and get an adjacency list for it - trace_vector t{}; - auto sched = scheduler{ get_graph_scaled_sum(t) }; - sched.work(); + tracer trace{}; + auto sched = scheduler{get_graph_scaled_sum(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); expect(boost::ut::that % t.size() == 10u); - expect(boost::ut::that % t == trace_vector{ "s1", "s2", "mult", "add", "out", "s1", "s2", "mult", "add", "out" }); + expect(boost::ut::that % t == trace_vector_type{ "s1", "s2", "mult", "add", "out", "s1", "s2", "mult", "add", "out"}); }; - "BreadthFirstScheduler_scaled_sum"_test = [] { - using scheduler = fair::graph::scheduler::breadth_first; - trace_vector t{}; - auto sched = scheduler{ get_graph_scaled_sum(t) }; - sched.work(); + "BreadthFirstScheduler_scaled_sum"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::breadth_first<>; + tracer trace{}; + auto sched = scheduler{get_graph_scaled_sum(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); expect(boost::ut::that % t.size() == 10u); - expect(boost::ut::that % t == trace_vector{ "s1", "s2", "mult", "add", "out", "s1", "s2", "mult", "add", "out" }); + expect(boost::ut::that % t == trace_vector_type{ "s1", "s2", "mult", "add", "out", "s1", "s2", "mult", "add", "out"}); + }; + + "SimpleScheduler_linear_multi_threaded"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::simple; + tracer trace{}; + auto sched = scheduler{get_graph_linear(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); + expect(that % t.size() >= 8u); + }; + + "BreadthFirstScheduler_linear_multi_threaded"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::breadth_first; + tracer trace{}; + auto sched = scheduler{get_graph_linear(trace), thread_pool}; + sched.init(); + expect(sched.getJobLists().size() == 2u); + check_node_names(sched.getJobLists()[0], {"s1", "mult2"}); + check_node_names(sched.getJobLists()[1], {"mult1", "out"}); + sched.run_and_wait(); + auto t = trace.get_vec(); + expect(boost::ut::that % t.size() >= 8u); + }; + + "SimpleScheduler_parallel_multi_threaded"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::simple; + tracer trace{}; + auto sched = scheduler{get_graph_parallel(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); + expect(boost::ut::that % t.size() >= 14u); + }; + + "BreadthFirstScheduler_parallel_multi_threaded"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::breadth_first; + tracer trace{}; + auto sched = scheduler{get_graph_parallel(trace), thread_pool}; + sched.init(); + expect(sched.getJobLists().size() == 2u); + check_node_names(sched.getJobLists()[0], {"s1", "mult1b", "mult2b", "outb"}); + check_node_names(sched.getJobLists()[1], {"mult1a", "mult2a", "outa"}); + sched.run_and_wait(); + auto t = trace.get_vec(); + expect(boost::ut::that % t.size() >= 14u); + }; + + "SimpleScheduler_scaled_sum_multi_threaded"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::simple; + // construct an example graph and get an adjacency list for it + tracer trace{}; + auto sched = scheduler{get_graph_scaled_sum(trace), thread_pool}; + sched.run_and_wait(); + auto t = trace.get_vec(); + expect(boost::ut::that % t.size() >= 10u); + }; + + "BreadthFirstScheduler_scaled_sum_multi_threaded"_test = [&thread_pool] { + using scheduler = fair::graph::scheduler::breadth_first; + tracer trace{}; + auto sched = scheduler{get_graph_scaled_sum(trace), thread_pool}; + sched.init(); + expect(sched.getJobLists().size() == 2u); + check_node_names(sched.getJobLists()[0], {"s1", "mult", "out"}); + check_node_names(sched.getJobLists()[1], {"s2", "add"}); + sched.run_and_wait(); + auto t = trace.get_vec(); + expect(boost::ut::that % t.size() >= 10u); }; }; diff --git a/test/qa_settings.cpp b/test/qa_settings.cpp index 3427d3a0..6ac6028b 100644 --- a/test/qa_settings.cpp +++ b/test/qa_settings.cpp @@ -238,10 +238,11 @@ const boost::ut::suite SettingsTests = [] { expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(block1).to<"in">(block2))); expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(block2).to<"in">(sink))); - fair::graph::scheduler::simple sched{ std::move(flow_graph) }; + auto thread_pool = std::make_shared("custom pool", fair::thread_pool::CPU_BOUND, 2, 2); // use custom pool to limit number of threads for emscripten + fair::graph::scheduler::simple sched{ std::move(flow_graph), thread_pool }; expect(src.settings().auto_update_parameters().contains("sample_rate")); std::ignore = src.settings().set({ { "sample_rate", 49000.0f } }); - sched.work(); + sched.run_and_wait(); expect(eq(src.n_samples_produced, n_samples)) << "did not produce enough output samples"; expect(eq(sink.n_samples_consumed, n_samples)) << "did not consume enough input samples"; diff --git a/test/qa_tags.cpp b/test/qa_tags.cpp index 51f78e20..9decdb23 100644 --- a/test/qa_tags.cpp +++ b/test/qa_tags.cpp @@ -90,7 +90,7 @@ const boost::ut::suite TagPropagation = [] { expect(eq(connection_result_t::SUCCESS, flow_graph.connect<"out">(monitor2).to<"in">(sink2))); scheduler::simple sched{ std::move(flow_graph) }; - sched.work(); + sched.run_and_wait(); expect(eq(src.n_samples_produced, n_samples)) << "src did not produce enough output samples"; expect(eq(monitor1.n_samples_produced, n_samples)) << "monitor1 did not consume enough input samples"; diff --git a/test/qa_thread_pool.cpp b/test/qa_thread_pool.cpp index 9f5258a9..00e056ce 100644 --- a/test/qa_thread_pool.cpp +++ b/test/qa_thread_pool.cpp @@ -11,12 +11,12 @@ const boost::ut::suite ThreadPoolTests = [] { using namespace boost::ut; "Basic ThreadPool tests"_test = [] { - expect(nothrow([] { fair::thread_pool::BasicThreadPool(); })); - expect(nothrow([] { fair::thread_pool::BasicThreadPool(); })); + expect(nothrow([] { fair::thread_pool::BasicThreadPool("test", fair::thread_pool::IO_BOUND); })); + expect(nothrow([] { fair::thread_pool::BasicThreadPool("test2", fair::thread_pool::CPU_BOUND); })); std::atomic enqueueCount{0}; std::atomic executeCount{0}; - fair::thread_pool::BasicThreadPool pool("TestPool", 1, 2); + fair::thread_pool::BasicThreadPool pool("TestPool", fair::thread_pool::IO_BOUND, 1, 2); expect(nothrow([&] { pool.sleepDuration = std::chrono::milliseconds(1); })); expect(nothrow([&] { pool.keepAliveDuration = std::chrono::seconds(10); })); pool.waitUntilInitialised(); @@ -60,7 +60,7 @@ const boost::ut::suite ThreadPoolTests = [] { }; "contention tests"_test = [] { std::atomic counter{0}; - fair::thread_pool::BasicThreadPool pool("contention", 1, 4); + fair::thread_pool::BasicThreadPool pool("contention", fair::thread_pool::IO_BOUND, 1, 4); pool.waitUntilInitialised(); expect(that % pool.isInitialised()); expect(pool.numThreads() == 1_u); @@ -99,7 +99,7 @@ const boost::ut::suite ThreadPoolTests = [] { std::atomic counter{0}; // Pool with min and max thread count - fair::thread_pool::BasicThreadPool pool("count_test", minThreads, maxThreads); + fair::thread_pool::BasicThreadPool pool("count_test", fair::thread_pool::IO_BOUND, minThreads, maxThreads); pool.keepAliveDuration = std::chrono::milliseconds(10); // default is 10 seconds, reducing for testing pool.waitUntilInitialised();