diff --git a/include/scheduler.hpp b/include/scheduler.hpp
index 4817e5c1..84ea8ff3 100644
--- a/include/scheduler.hpp
+++ b/include/scheduler.hpp
@@ -5,124 +5,129 @@
 #include
 #include
 #include
+#include
 
 namespace fair::graph::scheduler {
 
-struct init_proof {
-    explicit init_proof(bool _success) : success(_success) {}
-
-    init_proof(init_proof &&init) noexcept : init_proof(init.success) {}
-
-    bool success = true;
-
-    init_proof &
-    operator=(init_proof &&init) noexcept {
-        this->success = init;
-        return *this;
-    }
-
-    operator bool() const { return success; }
-};
-
-init_proof
-init(fair::graph::graph &graph) {
-    auto result = init_proof(std::all_of(graph.connection_definitions().begin(), graph.connection_definitions().end(),
-                                         [](auto &connection_definition) { return connection_definition() == connection_result_t::SUCCESS; }));
-    graph.clear_connection_definitions();
-    return result;
-}
-
-enum execution_policy { single_threaded, multi_threaded };
-
-namespace detail {
-template<typename node_type>
-work_return_t
-traverse_nodes(std::span<node_type> nodes) {
-    bool something_happened = false;
-    for (auto &node : nodes) {
-        auto result = node->work();
-        if (result == work_return_t::ERROR) {
-            return work_return_t::ERROR;
-        } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) {
-            // nothing
-        } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) {
-            something_happened = true;
-        }
-    }
-    return something_happened ? work_return_t::OK : work_return_t::DONE;
-}
-
-void
-run_on_pool(std::span<node_model *> job, std::size_t n_batches, std::atomic_uint64_t &progress, std::latch &running_threads, std::atomic_bool &stop_requested) {
-    uint32_t done           = 0;
-    uint32_t progress_count = 0;
-    while (done < n_batches && !stop_requested) {
-        bool     something_happened = traverse_nodes(job) == work_return_t::OK;
-        uint64_t progress_local, progress_new;
-        if (something_happened) { // something happened in this thread => increase progress and reset done count
-            do {
-                progress_local = progress.load();
-                progress_count = static_cast<std::uint32_t>((progress_local >> 32) & ((1ULL << 32) - 1));
-                done           = static_cast<std::uint32_t>(progress_local & ((1ULL << 32) - 1));
-                progress_new   = (progress_count + 1ULL) << 32;
-            } while (!progress.compare_exchange_strong(progress_local, progress_new));
-            progress.notify_all();
-        } else { // nothing happened on this thread
-            uint32_t progress_count_old = progress_count;
-            do {
-                progress_local = progress.load();
-                progress_count = static_cast<std::uint32_t>((progress_local >> 32) & ((1ULL << 32) - 1));
-                done           = static_cast<std::uint32_t>(progress_local & ((1ULL << 32) - 1));
-                if (progress_count == progress_count_old) { // nothing happened => increase done count
-                    progress_new = ((progress_count + 0ULL) << 32) + done + 1;
-                } else { // something happened in another thread => keep progress and done count and rerun this task without waiting
-                    progress_new = ((progress_count + 0ULL) << 32) + done;
-                }
-            } while (!progress.compare_exchange_strong(progress_local, progress_new));
-            progress.notify_all();
-            if (progress_count == progress_count_old && done < n_batches) {
-                progress.wait(progress_new);
-            }
-        }
-    } // while (done < n_batches)
-    running_threads.count_down();
-}
-} // namespace detail
+enum execution_policy { single_threaded, multi_threaded };
+enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN };
+
+template<typename Derived>
+class scheduler_base : public node<Derived> {
+protected:
+    using node_t           = node_model *;
+    using thread_pool_type = thread_pool::BasicThreadPool;
+    SchedulerState                    _state = IDLE;
+    fair::graph::graph                _graph;
+    std::shared_ptr<thread_pool_type> _pool;
+    std::vector<std::vector<node_t>>  _job_lists{};
+
+public:
+    explicit scheduler_base(fair::graph::graph &&graph,
+                            std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("simple-scheduler-pool", thread_pool::CPU_BOUND))
+        : _graph(std::move(graph)), _pool(std::move(thread_pool)) {}
+
+    void
+    init() {
+        const bool success = std::all_of(_graph.connection_definitions().begin(), _graph.connection_definitions().end(),
+                                         [](auto &connection_definition) { return connection_definition() == connection_result_t::SUCCESS; });
+        _graph.clear_connection_definitions();
+        _state = success ? INITIALISED : IDLE;
+    }
+
+    template<typename node_type>
+    work_return_t
+    traverse_nodes(std::span<node_type> nodes) {
+        bool something_happened = false;
+        for (auto &currentNode : nodes) {
+            auto result = currentNode->work();
+            if (result == work_return_t::ERROR) {
+                return work_return_t::ERROR;
+            } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) {
+                // nothing to do for this node in this pass
+            } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) {
+                something_happened = true;
+            }
+        }
+        return something_happened ? work_return_t::OK : work_return_t::DONE;
+    }
+
+    void
+    run_on_pool(std::span<node_t> job, std::size_t n_batches, std::atomic_uint64_t &progress, std::latch &running_threads, std::atomic_bool &stop_requested) {
+        uint32_t done           = 0;
+        uint32_t progress_count = 0;
+        while (done < n_batches && !stop_requested) {
+            bool     something_happened = traverse_nodes(job) == work_return_t::OK;
+            uint64_t progress_local, progress_new;
+            if (something_happened) { // something happened in this thread => increase progress and reset done count
+                do {
+                    progress_local = progress.load();
+                    progress_count = static_cast<std::uint32_t>((progress_local >> 32) & ((1ULL << 32) - 1));
+                    done           = static_cast<std::uint32_t>(progress_local & ((1ULL << 32) - 1));
+                    progress_new   = (progress_count + 1ULL) << 32;
+                } while (!progress.compare_exchange_strong(progress_local, progress_new));
+                progress.notify_all();
+            } else { // nothing happened on this thread
+                uint32_t progress_count_old = progress_count;
+                do {
+                    progress_local = progress.load();
+                    progress_count = static_cast<std::uint32_t>((progress_local >> 32) & ((1ULL << 32) - 1));
+                    done           = static_cast<std::uint32_t>(progress_local & ((1ULL << 32) - 1));
+                    if (progress_count == progress_count_old) { // nothing happened => increase done count
+                        progress_new = ((progress_count + 0ULL) << 32) + done + 1;
+                    } else { // something happened in another thread => keep progress and done count and rerun this task without waiting
+                        progress_new = ((progress_count + 0ULL) << 32) + done;
+                    }
+                } while (!progress.compare_exchange_strong(progress_local, progress_new));
+                progress.notify_all();
+                if (progress_count == progress_count_old && done < n_batches) {
+                    progress.wait(progress_new);
+                }
+            }
+        } // while (done < n_batches)
+        running_threads.count_down();
+    }
+
+    [[nodiscard]] const std::vector<std::vector<node_t>> &
+    getJobLists() const {
+        return _job_lists;
+    }
+};
 
 /**
  * Trivial loop based scheduler, which iterates over all nodes in definition order in the graph until no node did any processing
 */
 template<execution_policy executionPolicy = single_threaded>
-class simple : public node<simple<executionPolicy>> {
-    using node_t           = node_model *;
-    using thread_pool_type = thread_pool::BasicThreadPool;
-    init_proof                        _init;
-    fair::graph::graph                _graph;
-    std::shared_ptr<thread_pool_type> _pool;
-    std::vector<std::vector<node_t>>  _job_lists{};
+class simple : public scheduler_base<simple<executionPolicy>> {
+    using S                = scheduler_base<simple<executionPolicy>>;
+    using node_t           = S::node_t;           // node_model *
+    using thread_pool_type = S::thread_pool_type; // thread_pool::BasicThreadPool
 
 public:
-    explicit simple(fair::graph::graph &&graph, std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("simple-scheduler-pool", thread_pool::CPU_BOUND))
-        : _init{fair::graph::scheduler::init(graph)}, _graph(std::move(graph)), _pool(std::move(thread_pool)) {
+    explicit simple(fair::graph::graph &&graph, std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("simple-scheduler-pool", thread_pool::CPU_BOUND))
+        : S(std::move(graph), std::move(thread_pool)) {
         // generate job list
-        const auto n_batches = std::min(static_cast<std::size_t>(_pool->maxThreads()), _graph.blocks().size());
-        _job_lists.reserve(n_batches);
-        for (std::size_t i = 0; i < n_batches; i++) {
-            // create job-set for thread
-            auto &job = _job_lists.emplace_back();
-            job.reserve(_graph.blocks().size() / n_batches + 1);
-            for (std::size_t j = i; j < _graph.blocks().size(); j += n_batches) {
-                job.push_back(_graph.blocks()[j].get());
+        if constexpr (executionPolicy == multi_threaded) {
+            const auto n_batches = std::min(static_cast<std::size_t>(this->_pool->maxThreads()), this->_graph.blocks().size());
+            this->_job_lists.reserve(n_batches);
+            for (std::size_t i = 0; i < n_batches; i++) {
+                // create job-set for thread
+                auto &job = this->_job_lists.emplace_back();
+                job.reserve(this->_graph.blocks().size() / n_batches + 1);
+                for (std::size_t j = i; j < this->_graph.blocks().size(); j += n_batches) {
+                    job.push_back(this->_graph.blocks()[j].get());
+                }
             }
         }
     }
 
     work_return_t
     work() {
-        if (!_init) {
-            return work_return_t::ERROR;
-        }
         if constexpr (executionPolicy == single_threaded) {
             bool run = true;
             while (run) {
-                if (auto result = detail::traverse_nodes(std::span{_graph.blocks()}); result == work_return_t::ERROR) {
+                if (auto result = this->traverse_nodes(std::span{this->_graph.blocks()}); result == work_return_t::ERROR) {
                     return result;
                 } else {
                     run = result == work_return_t::OK;
@@ -131,11 +136,10 @@ class simple : public node<simple<executionPolicy>> {
         } else if (executionPolicy == multi_threaded) {
             std::atomic_bool     stop_requested(false);
             std::atomic_uint64_t progress{ 0 }; // upper uint32_t: progress counter, lower uint32_t: number of workers that finished all their work
-            std::latch           running_threads{ static_cast<std::ptrdiff_t>(_job_lists.size()) }; // latch to wait for completion of the flowgraph
-            for (auto &job : _job_lists) {
-                //_pool->execute(detail::run_on_pool, std::span{job}, _job_lists.size(), progress, running_threads, stoken);
-                _pool->execute([this, &job, &progress, &running_threads, &stop_requested]() {
-                    detail::run_on_pool(std::span{job}, _job_lists.size(), progress, running_threads, stop_requested);
+            std::latch           running_threads{ static_cast<std::ptrdiff_t>(this->_job_lists.size()) }; // latch to wait for completion of the flowgraph
+            for (auto &job : this->_job_lists) {
+                this->_pool->execute([this, &job, &progress, &running_threads, &stop_requested]() {
+                    this->run_on_pool(std::span{job}, this->_job_lists.size(), progress, running_threads, stop_requested);
                 });
             }
             running_threads.wait();
@@ -152,22 +156,19 @@ class simple : public node<simple<executionPolicy>> {
  * detecting cycles and nodes which can be reached from several source nodes.
  */
 template<execution_policy executionPolicy = single_threaded>
-class breadth_first : public node<breadth_first<executionPolicy>> {
+class breadth_first : public scheduler_base<breadth_first<executionPolicy>> {
+    using S                = scheduler_base<breadth_first<executionPolicy>>;
     using node_t           = node_model *;
     using thread_pool_type = thread_pool::BasicThreadPool;
-    init_proof                        _init;
-    fair::graph::graph                _graph;
     std::vector<node_t>               _nodelist;
-    std::vector<std::vector<node_t>>  _job_lists;
-    std::shared_ptr<thread_pool_type> _pool;
 
 public:
     explicit breadth_first(fair::graph::graph &&graph, std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("breadth-first-pool", thread_pool::CPU_BOUND))
-        : _init{fair::graph::scheduler::init(graph)}, _graph(std::move(graph)), _pool(std::move(thread_pool)) {
+        : S(std::move(graph), std::move(thread_pool)) {
         std::map<node_t, std::vector<node_t>> _adjacency_list{};
         std::vector<node_t>                   _source_nodes{};
         // compute the adjacency list
         std::set<node_t> node_reached;
-        for (auto &e : _graph.edges()) {
+        for (auto &e : this->_graph.edges()) {
             _adjacency_list[e._src_node].push_back(e._dst_node);
             _source_nodes.push_back(e._src_node);
             node_reached.insert(e._dst_node);
@@ -198,11 +199,11 @@ class breadth_first : public node<breadth_first<executionPolicy>> {
             }
         }
         // generate job list
-        const auto n_batches = std::min(static_cast<std::size_t>(_pool->maxThreads()), _nodelist.size());
-        _job_lists.reserve(n_batches);
+        const auto n_batches = std::min(static_cast<std::size_t>(this->_pool->maxThreads()), _nodelist.size());
+        this->_job_lists.reserve(n_batches);
         for (std::size_t i = 0; i < n_batches; i++) {
             // create job-set for thread
-            auto &job = _job_lists.emplace_back();
+            auto &job = this->_job_lists.emplace_back();
             job.reserve(_nodelist.size() / n_batches + 1);
             for (std::size_t j = i; j < _nodelist.size(); j += n_batches) {
                 job.push_back(_nodelist[j]);
@@ -212,13 +213,13 @@ class breadth_first : public node<breadth_first<executionPolicy>> {
 
     work_return_t
     work() {
-        if (!_init) {
-            return work_return_t::ERROR;
-        }
         if constexpr (executionPolicy == single_threaded) {
             bool run = true;
             while (run) {
-                if (auto result = detail::traverse_nodes(std::span{_nodelist}); result == work_return_t::ERROR) {
+                if (auto result = this->traverse_nodes(std::span{_nodelist}); result == work_return_t::ERROR) {
                     return work_return_t::ERROR;
                 } else {
                     run = (result == work_return_t::OK);
@@ -227,11 +228,10 @@ class breadth_first : public node<breadth_first<executionPolicy>> {
         } else if (executionPolicy == multi_threaded) {
            std::atomic_bool     stop_requested;
            std::atomic_uint64_t progress{ 0 }; // upper uint32_t: progress counter, lower uint32_t: number of workers that finished all their work
-            std::latch           running_threads{ static_cast<std::ptrdiff_t>(_job_lists.size()) }; // latch to wait for completion of the flowgraph
-            for (auto &job : _job_lists) {
-                //_pool->execute(detail::run_on_pool, std::span{job}, _job_lists.size(), progress, running_threads, stoken);
-                _pool->execute([this, &job, &progress, &running_threads, &stop_requested]() {
-                    detail::run_on_pool(std::span{job}, _job_lists.size(), progress, running_threads, stop_requested);
+            std::latch           running_threads{ static_cast<std::ptrdiff_t>(this->_job_lists.size()) }; // latch to wait for completion of the flowgraph
+            for (auto &job : this->_job_lists) {
+                this->_pool->execute([this, &job, &progress, &running_threads, &stop_requested]() {
+                    this->run_on_pool(std::span{job}, this->_job_lists.size(), progress, running_threads, stop_requested);
                 });
            }
            running_threads.wait();
@@ -242,10 +242,6 @@ class breadth_first : public node<breadth_first<executionPolicy>> {
         }
         return work_return_t::DONE;
     }
-
-    const std::vector<std::vector<node_t>> &
-    getJobLists() const {
-        return _job_lists;
-    }
 };
 
 } // namespace fair::graph::scheduler
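
Note on the progress protocol in `run_on_pool()`: all worker threads share a single `std::atomic_uint64_t` whose upper 32 bits hold a progress generation counter and whose lower 32 bits count the job batches that completed a full pass without any node doing work. Packing both halves into one word lets a thread bump the generation and reset the done count in a single compare-exchange, and lets idle threads `wait()` on the word until either value changes. A minimal, self-contained sketch of that encoding, with illustrative helper names that are not part of the patch:

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// upper 32 bits: progress generation; lower 32 bits: "done" batch count
static std::uint32_t generation_of(std::uint64_t word) { return static_cast<std::uint32_t>(word >> 32); }
static std::uint32_t done_of(std::uint64_t word) { return static_cast<std::uint32_t>(word & ((1ULL << 32) - 1)); }

int main() {
    std::atomic_uint64_t progress{ 0 };

    // a worker that made progress: bump the generation, which implicitly
    // resets the done count to zero, all in one atomic compare-exchange
    std::uint64_t local = progress.load();
    while (!progress.compare_exchange_strong(local, (generation_of(local) + 1ULL) << 32)) {}
    assert(generation_of(progress.load()) == 1 && done_of(progress.load()) == 0);

    // a worker that saw no progress: keep the generation, count itself as done
    local = progress.load();
    while (!progress.compare_exchange_strong(local, (static_cast<std::uint64_t>(generation_of(local)) << 32) + done_of(local) + 1)) {}
    assert(generation_of(progress.load()) == 1 && done_of(progress.load()) == 1);
}
```

If the generation changed between a worker's load and its compare-exchange, the CAS fails and the worker re-reads the word, which is exactly how `run_on_pool()` detects that another thread made progress and that it should rerun its batch instead of waiting.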
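
For context, a sketch of how a caller might drive the refactored schedulers. This assumes a fully connected `fair::graph::graph` and that the patched header is on the include path; graph construction itself is outside the scope of this change, and the helper below is hypothetical:

```cpp
#include <scheduler.hpp>

// Hypothetical driver: run a finished flowgraph to completion.
// `simple`, `breadth_first`, and the execution_policy values come from the
// patch above; `flow_graph` must already have its connections defined.
auto
run_flowgraph(fair::graph::graph &&flow_graph) {
    using namespace fair::graph::scheduler;

    // single_threaded: all nodes are iterated on the calling thread until a
    // full pass makes no progress; multi_threaded would instead partition the
    // nodes round-robin into per-thread job lists (inspectable via
    // getJobLists() on scheduler_base) and block on the latch in work()
    simple<single_threaded> scheduler{ std::move(flow_graph) };
    return scheduler.work();
}
```

Swapping `simple` for `breadth_first` keeps the same call shape; only the order in which nodes land in the job lists changes.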