diff --git a/include/scheduler.hpp b/include/scheduler.hpp index db810f73..8f1492a1 100644 --- a/include/scheduler.hpp +++ b/include/scheduler.hpp @@ -4,13 +4,13 @@ #include #include #include -#include +#include #include namespace fair::graph::scheduler { enum execution_policy { single_threaded, multi_threaded }; -enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN }; +enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN, ERROR}; template class scheduler_base : public node { @@ -20,6 +20,9 @@ class scheduler_base : public node { SchedulerState _state = IDLE; fair::graph::graph _graph; std::shared_ptr _pool; + std::atomic_uint64_t _progress; + std::atomic_size_t _running_threads; + std::atomic_bool _stop_requested; std::vector> _job_lists{}; // move to impl public: @@ -27,6 +30,34 @@ class scheduler_base : public node { std::shared_ptr thread_pool = std::make_shared("simple-scheduler-pool", thread_pool::CPU_BOUND)) : _graph(std::move(graph)), _pool(std::move(thread_pool)){}; + ~scheduler_base() { + stop(); + _state = SHUTTING_DOWN; + } + + void stop(){ + if (_state == STOPPED || _state == ERROR){ + return; + } + if (_state == RUNNING) { + request_stop(); + } + for (auto running = _running_threads.load(); running > 0ul; running = _running_threads.load()) { + //_running_threads.wait(running); + } + _state = STOPPED; + } + + void request_stop() { + _stop_requested = true; + _state = REQUESTED_STOP; + } + + void request_pause() { + _stop_requested = true; + _state = REQUESTED_PAUSE; + } + void init() { if (_state != IDLE) { @@ -34,49 +65,46 @@ class scheduler_base : public node { } auto result = std::all_of(_graph.connection_definitions().begin(), _graph.connection_definitions().end(), [this](auto &connection_definition) { return connection_definition(_graph) == connection_result_t::SUCCESS; }); - _graph.clear_connection_definitions(); if (result) { + _graph.clear_connection_definitions(); _state = INITIALISED; + } else { + _state = ERROR; } } - // todo: move to impl - template - work_return_t - traverse_nodes(std::span nodes) { - bool something_happened = false; - for (auto ¤tNode : nodes) { - auto result = currentNode->work(); - if (result == work_return_t::ERROR) { - return work_return_t::ERROR; - } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) { - // nothing - } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) { - something_happened = true; - } + void + run_on_pool(const std::vector> &jobs, const std::function&)> work_function) { + _progress = 0; + _running_threads = jobs.size(); + for (auto &jobset : jobs) { + _pool->execute([this, &jobset, work_function, &jobs]() { + pool_worker([&work_function, &jobset]() { + return work_function(jobset); + }, jobs.size()); + }); } - return something_happened ? work_return_t::OK : work_return_t::DONE; } void - run_on_pool(std::span job, std::size_t n_batches, std::atomic_uint64_t &progress, std::latch &running_threads, std::atomic_bool &stop_requested) { + pool_worker(const std::function& work, std::size_t n_batches) { uint32_t done = 0; uint32_t progress_count = 0; - while (done < n_batches && !stop_requested) { - bool something_happened = traverse_nodes(job) == work_return_t::OK; + while (done < n_batches && !_stop_requested) { + bool something_happened = work() == work_return_t::OK; uint64_t progress_local, progress_new; if (something_happened) { // something happened in this thread => increase progress and reset done count do { - progress_local = progress.load(); + progress_local = _progress.load(); progress_count = static_cast((progress_local >> 32) & ((1ULL << 32) - 1)); done = static_cast(progress_local & ((1ULL << 32) - 1)); progress_new = (progress_count + 1ULL) << 32; - } while (!progress.compare_exchange_strong(progress_local, progress_new)); - progress.notify_all(); + } while (!_progress.compare_exchange_strong(progress_local, progress_new)); + _progress.notify_all(); } else { // nothing happened on this thread uint32_t progress_count_old = progress_count; do { - progress_local = progress.load(); + progress_local = _progress.load(); progress_count = static_cast((progress_local >> 32) & ((1ULL << 32) - 1)); done = static_cast(progress_local & ((1ULL << 32) - 1)); if (progress_count == progress_count_old) { // nothing happened => increase done count @@ -84,14 +112,14 @@ class scheduler_base : public node { } else { // something happened in another thread => keep progress and done count and rerun this task without waiting progress_new = ((progress_count + 0ULL) << 32) + done; } - } while (!progress.compare_exchange_strong(progress_local, progress_new)); - progress.notify_all(); + } while (!_progress.compare_exchange_strong(progress_local, progress_new)); + _progress.notify_all(); if (progress_count == progress_count_old && done < n_batches) { - progress.wait(progress_new); + _progress.wait(progress_new); } } } // while (done < n_batches) - running_threads.count_down(); + _running_threads.fetch_sub(1); } [[nodiscard]] const std::vector> &getJobLists() const { @@ -105,11 +133,11 @@ class scheduler_base : public node { template class simple : public scheduler_base>{ using Base = scheduler_base>; - using thread_pool_type = Base::thread_pool_type; //thread_pool::BasicThreadPool; + using node_t = node_model*; + using thread_pool_type = typename Base::thread_pool_type; //thread_pool::BasicThreadPool; public: explicit simple(fair::graph::graph &&graph, std::shared_ptr thread_pool = std::make_shared("simple-scheduler-pool", thread_pool::CPU_BOUND)) - : Base(std::forward(graph), thread_pool) { - } + : Base(std::forward(graph), thread_pool) { } void init() { @@ -129,40 +157,55 @@ class simple : public scheduler_base>{ } } + template + work_return_t + work_once(const std::span &nodes) { + bool something_happened = false; + for (auto ¤tNode : nodes) { + auto result = currentNode->work(); + if (result == work_return_t::ERROR) { + return work_return_t::ERROR; + } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) { + // nothing + } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) { + something_happened = true; + } + } + return something_happened ? work_return_t::OK : work_return_t::DONE; + } work_return_t work() { + start(); + this->stop(); + return work_return_t::DONE; + } + + void + start() { if (this->_state == IDLE) { this->init(); } if (this->_state != INITIALISED) { fmt::print("simple scheduler work(): graph not initialised"); - return work_return_t::ERROR; + return; } if constexpr (executionPolicy == single_threaded) { - bool run = true; - while (run) { - if (auto result = this->traverse_nodes(std::span{this->_graph.blocks()}); result == work_return_t::ERROR) { - return result; - } else { - run = result == work_return_t::OK; + this->_state = RUNNING; + work_return_t result; + auto nodelist = std::span{this->_graph.blocks()}; + while ((result = work_once(nodelist)) == work_return_t::OK) { + if (result == work_return_t::ERROR) { + this->_state = ERROR; + return; } } + this->_state = STOPPED; } else if (executionPolicy == multi_threaded) { - std::atomic_bool stop_requested(false); - std::atomic_uint64_t progress{0}; // upper uint32t: progress counter, lower uint32t: number of workers that finished all their work - std::latch running_threads{static_cast(this->_job_lists.size())}; // latch to wait for completion of the flowgraph - for (auto &job: this->_job_lists) { - this->_pool->execute([this, &job, &progress, &running_threads, &stop_requested]() { - this->run_on_pool(std::span{job}, this->_job_lists.size(), progress, running_threads, stop_requested); - }); - } - running_threads.wait(); - return work_return_t::DONE; + this->run_on_pool(this->_job_lists, [this](auto &job) {return this->work_once(job);} ); } else { throw std::invalid_argument("Unknown execution Policy"); } - return work_return_t::DONE; } }; @@ -232,41 +275,56 @@ class breadth_first : public scheduler_base> { } } + template + work_return_t + work_once(const std::span &nodes) { + bool something_happened = false; + for (auto ¤tNode : nodes) { + auto result = currentNode->work(); + if (result == work_return_t::ERROR) { + return work_return_t::ERROR; + } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) { + // nothing + } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) { + something_happened = true; + } + } + return something_happened ? work_return_t::OK : work_return_t::DONE; + } + work_return_t work() { + start(); + this->stop(); + return work_return_t::DONE; + } + + void + start() { if (this->_state == IDLE) { this->init(); } if (this->_state != INITIALISED) { fmt::print("simple scheduler work(): graph not initialised"); - return work_return_t::ERROR; + return; } if constexpr (executionPolicy == single_threaded) { - bool run = true; - while (run) { - if (auto result = this->traverse_nodes(std::span{_nodelist}); result == work_return_t::ERROR) { - return work_return_t::ERROR; - } else { - run = (result == work_return_t::OK); + this->_state = RUNNING; + work_return_t result; + auto nodelist = std::span{this->_nodelist}; + while ((result = work_once(nodelist)) == work_return_t::OK) { + if (result == work_return_t::ERROR) { + this->_state = ERROR; + return; } } + this->_state = STOPPED; } else if (executionPolicy == multi_threaded) { - std::atomic_bool stop_requested; - std::atomic_uint64_t progress{0}; // upper uint32t: progress counter, lower uint32t: number of workers that finished all their work - std::latch running_threads{static_cast(this->_job_lists.size())}; // latch to wait for completion of the flowgraph - for (auto &job: this->_job_lists) { - this->_pool->execute([this, &job, &progress, &running_threads, &stop_requested]() { - this->run_on_pool(std::span{job}, this->_job_lists.size(), progress, running_threads, stop_requested); - }); - } - running_threads.wait(); - return work_return_t::DONE; + this->run_on_pool(this->_job_lists, [this](auto &job) {return this->work_once(job);}); } else { throw std::invalid_argument("Unknown execution Policy"); } - return work_return_t::DONE; } - }; } // namespace fair::graph::scheduler