Commit

restructure base scheduler and add state machine

wirew0rm committed Jun 26, 2023
1 parent 391e8e4 commit 6c62ede
Showing 1 changed file: include/scheduler.hpp (127 additions, 69 deletions)
@@ -4,13 +4,13 @@
 #include <set>
 #include <queue>
 #include <thread_pool.hpp>
-#include <latch>
+#include <barrier>
 #include <utility>

 namespace fair::graph::scheduler {

 enum execution_policy { single_threaded, multi_threaded };
-enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN };
+enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN, ERROR};

 template<typename scheduler_type, execution_policy = single_threaded>
 class scheduler_base : public node<scheduler_type> {
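
The new ERROR enumerator completes the lifecycle that the methods added below implement: init() takes IDLE to INITIALISED (or to ERROR on a failed connection), start() takes INITIALISED to RUNNING, request_stop()/request_pause() mark the transitional states, stop() settles in STOPPED, and the destructor ends in SHUTTING_DOWN. A minimal sketch of these transitions, assuming this reading of the diff; isValidTransition() is a hypothetical helper and not part of the commit:

// Hypothetical sketch, not part of this commit: the transitions implied by
// init(), start(), request_stop(), request_pause(), stop(), and ~scheduler_base().
enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN, ERROR };

constexpr bool isValidTransition(SchedulerState from, SchedulerState to) {
    switch (to) {
        case INITIALISED:     return from == IDLE;                               // init()
        case RUNNING:         return from == INITIALISED || from == PAUSED;      // start()
        case REQUESTED_STOP:  return from == RUNNING;                            // request_stop()
        case REQUESTED_PAUSE: return from == RUNNING;                            // request_pause()
        case STOPPED:         return from == RUNNING || from == REQUESTED_STOP;  // stop(); the real code is more permissive
        case PAUSED:          return from == REQUESTED_PAUSE;                    // presumably set by the worker loop
        case SHUTTING_DOWN:   return from == STOPPED;                            // ~scheduler_base()
        case ERROR:           return true;                                       // any state can fail
        case IDLE:            return false;                                      // initial state only
    }
    return false;
}
static_assert(isValidTransition(IDLE, INITIALISED) && !isValidTransition(IDLE, RUNNING));
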
@@ -20,78 +20,106 @@ class scheduler_base : public node<scheduler_type> {
     SchedulerState _state = IDLE;
     fair::graph::graph _graph;
     std::shared_ptr<thread_pool_type> _pool;
+    std::atomic_uint64_t _progress;
+    std::atomic_size_t _running_threads;
+    std::atomic_bool _stop_requested;
     std::vector<std::vector<node_t>> _job_lists{}; // move to impl

 public:
     explicit scheduler_base(fair::graph::graph &&graph,
                             std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<fair::thread_pool::BasicThreadPool>("simple-scheduler-pool", thread_pool::CPU_BOUND))
         : _graph(std::move(graph)), _pool(std::move(thread_pool)){};

+    ~scheduler_base() {
+        stop();
+        _state = SHUTTING_DOWN;
+    }

+    void stop(){
+        if (_state == STOPPED || _state == ERROR){
+            return;
+        }
+        if (_state == RUNNING) {
+            request_stop();
+        }
+        for (auto running = _running_threads.load(); running > 0ul; running = _running_threads.load()) {
+            //_running_threads.wait(running);
+        }
+        _state = STOPPED;
+    }

+    void request_stop() {
+        _stop_requested = true;
+        _state = REQUESTED_STOP;
+    }

+    void request_pause() {
+        _stop_requested = true;
+        _state = REQUESTED_PAUSE;
+    }

     void
     init() {
+        if (_state != IDLE) {
+            return;
+        }
         auto result = std::all_of(_graph.connection_definitions().begin(), _graph.connection_definitions().end(),
                                   [this](auto &connection_definition) { return connection_definition(_graph) == connection_result_t::SUCCESS; });
-        _graph.clear_connection_definitions();
         if (result) {
+            _graph.clear_connection_definitions();
             _state = INITIALISED;
+        } else {
+            _state = ERROR;
         }
     }

-    // todo: move to impl
-    template<typename node_t>
-    work_return_t
-    traverse_nodes(std::span<node_t> nodes) {
-        bool something_happened = false;
-        for (auto &currentNode : nodes) {
-            auto result = currentNode->work();
-            if (result == work_return_t::ERROR) {
-                return work_return_t::ERROR;
-            } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) {
-                // nothing
-            } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) {
-                something_happened = true;
-            }
+    void
+    run_on_pool(const std::vector<std::vector<node_t>> &jobs, const std::function<work_return_t(const std::span<const node_t>&)> work_function) {
+        _progress = 0;
+        _running_threads = jobs.size();
+        for (auto &jobset : jobs) {
+            _pool->execute([this, &jobset, work_function, &jobs]() {
+                pool_worker([&work_function, &jobset]() {
+                    return work_function(jobset);
+                }, jobs.size());
+            });
+        }
-        return something_happened ? work_return_t::OK : work_return_t::DONE;
     }

     void
-    run_on_pool(std::span<node_model *> job, std::size_t n_batches, std::atomic_uint64_t &progress, std::latch &running_threads, std::atomic_bool &stop_requested) {
+    pool_worker(const std::function<work_return_t()>& work, std::size_t n_batches) {
         uint32_t done = 0;
         uint32_t progress_count = 0;
-        while (done < n_batches && !stop_requested) {
-            bool something_happened = traverse_nodes(job) == work_return_t::OK;
+        while (done < n_batches && !_stop_requested) {
+            bool something_happened = work() == work_return_t::OK;
             uint64_t progress_local, progress_new;
             if (something_happened) { // something happened in this thread => increase progress and reset done count
                 do {
-                    progress_local = progress.load();
+                    progress_local = _progress.load();
                     progress_count = static_cast<std::uint32_t>((progress_local >> 32) & ((1ULL << 32) - 1));
                     done = static_cast<std::uint32_t>(progress_local & ((1ULL << 32) - 1));
                     progress_new = (progress_count + 1ULL) << 32;
-                } while (!progress.compare_exchange_strong(progress_local, progress_new));
-                progress.notify_all();
+                } while (!_progress.compare_exchange_strong(progress_local, progress_new));
+                _progress.notify_all();
             } else { // nothing happened on this thread
                 uint32_t progress_count_old = progress_count;
                 do {
-                    progress_local = progress.load();
+                    progress_local = _progress.load();
                     progress_count = static_cast<std::uint32_t>((progress_local >> 32) & ((1ULL << 32) - 1));
                     done = static_cast<std::uint32_t>(progress_local & ((1ULL << 32) - 1));
                     if (progress_count == progress_count_old) { // nothing happened => increase done count
                         progress_new = ((progress_count + 0ULL) << 32) + done + 1;
                     } else { // something happened in another thread => keep progress and done count and rerun this task without waiting
                         progress_new = ((progress_count + 0ULL) << 32) + done;
                     }
-                } while (!progress.compare_exchange_strong(progress_local, progress_new));
-                progress.notify_all();
+                } while (!_progress.compare_exchange_strong(progress_local, progress_new));
+                _progress.notify_all();
                 if (progress_count == progress_count_old && done < n_batches) {
-                    progress.wait(progress_new);
+                    _progress.wait(progress_new);
                }
            }
        } // while (done < n_batches)
-        running_threads.count_down();
+        _running_threads.fetch_sub(1);
    }

    [[nodiscard]] const std::vector<std::vector<node_t>> &getJobLists() const {
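
pool_worker() above replaces the per-call progress/latch/stop parameters with the new member atomics: _progress packs two 32-bit counters into one 64-bit word (upper half: a progress counter bumped whenever some worker did useful work; lower half: the number of workers that found nothing to do), so one atomic wait/notify channel doubles as progress broadcast and termination detection. A standalone sketch of that packing, not part of the commit, and simplified in that it drops the progress_count_old bookkeeping that lets a worker re-run immediately when another thread progressed in the meantime:

#include <atomic>
#include <cstdint>

// Sketch only. Layout of the 64-bit progress word:
// bits 63..32 = progress counter, bits 31..0 = count of idle workers.
constexpr std::uint64_t pack(std::uint32_t progress, std::uint32_t done) {
    return (static_cast<std::uint64_t>(progress) << 32) | done;
}
constexpr std::uint32_t progress_of(std::uint64_t word) { return static_cast<std::uint32_t>(word >> 32); }
constexpr std::uint32_t done_of(std::uint64_t word)     { return static_cast<std::uint32_t>(word); }

void report(std::atomic<std::uint64_t> &word, bool something_happened, std::size_t n_batches) {
    std::uint64_t expected = word.load(), desired;
    do {
        desired = something_happened
                        ? pack(progress_of(expected) + 1, 0)                  // progress++, idle count resets
                        : pack(progress_of(expected), done_of(expected) + 1); // one more idle worker
    } while (!word.compare_exchange_strong(expected, desired)); // retry on concurrent update
    word.notify_all();                                          // wake workers blocked in wait()
    if (!something_happened && done_of(desired) < n_batches) {
        word.wait(desired); // C++20: block until the word changes again
    }
}
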
@@ -105,11 +133,11 @@
 template<execution_policy executionPolicy = single_threaded>
 class simple : public scheduler_base<simple<executionPolicy>>{
     using Base = scheduler_base<simple<executionPolicy>>;
-    using thread_pool_type = Base::thread_pool_type; //thread_pool::BasicThreadPool;
+    using node_t = node_model*;
+    using thread_pool_type = typename Base::thread_pool_type; //thread_pool::BasicThreadPool;
 public:
     explicit simple(fair::graph::graph &&graph, std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("simple-scheduler-pool", thread_pool::CPU_BOUND))
-        : Base(std::forward<fair::graph::graph>(graph), thread_pool) {
-    }
+        : Base(std::forward<fair::graph::graph>(graph), thread_pool) { }

     void
     init() {
@@ -129,40 +157,55 @@ class simple : public scheduler_base<simple<executionPolicy>>{
         }
     }

+    template<typename node_type>
+    work_return_t
+    work_once(const std::span<node_type> &nodes) {
+        bool something_happened = false;
+        for (auto &currentNode : nodes) {
+            auto result = currentNode->work();
+            if (result == work_return_t::ERROR) {
+                return work_return_t::ERROR;
+            } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) {
+                // nothing
+            } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) {
+                something_happened = true;
+            }
+        }
+        return something_happened ? work_return_t::OK : work_return_t::DONE;
     }

     work_return_t
     work() {
+        start();
+        this->stop();
+        return work_return_t::DONE;
+    }
+
+    void
+    start() {
         if (this->_state == IDLE) {
             this->init();
         }
         if (this->_state != INITIALISED) {
             fmt::print("simple scheduler work(): graph not initialised");
-            return work_return_t::ERROR;
+            return;
         }
         if constexpr (executionPolicy == single_threaded) {
-            bool run = true;
-            while (run) {
-                if (auto result = this->traverse_nodes(std::span{this->_graph.blocks()}); result == work_return_t::ERROR) {
-                    return result;
-                } else {
-                    run = result == work_return_t::OK;
+            this->_state = RUNNING;
+            work_return_t result;
+            auto nodelist = std::span{this->_graph.blocks()};
+            while ((result = work_once(nodelist)) == work_return_t::OK) {
+                if (result == work_return_t::ERROR) {
+                    this->_state = ERROR;
+                    return;
                 }
             }
+            this->_state = STOPPED;
         } else if (executionPolicy == multi_threaded) {
-            std::atomic_bool stop_requested(false);
-            std::atomic_uint64_t progress{0}; // upper uint32t: progress counter, lower uint32t: number of workers that finished all their work
-            std::latch running_threads{static_cast<std::ptrdiff_t>(this->_job_lists.size())}; // latch to wait for completion of the flowgraph
-            for (auto &job: this->_job_lists) {
-                this->_pool->execute([this, &job, &progress, &running_threads, &stop_requested]() {
-                    this->run_on_pool(std::span{job}, this->_job_lists.size(), progress, running_threads, stop_requested);
-                });
-            }
-            running_threads.wait();
-            return work_return_t::DONE;
+            this->run_on_pool(this->_job_lists, [this](auto &job) {return this->work_once(job);} );
         } else {
             throw std::invalid_argument("Unknown execution Policy");
         }
-        return work_return_t::DONE;
     }
 };
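
work_once() folds the per-node work() results into one verdict: ERROR aborts the pass, OK or INSUFFICIENT_OUTPUT_ITEMS count as progress, INSUFFICIENT_INPUT_ITEMS and DONE do not; the single-threaded start() then repeats passes until one makes no progress. A usage sketch, where make_graph() is a hypothetical helper returning a connected fair::graph::graph:

fair::graph::graph flow = make_graph();                  // hypothetical helper, not in this commit
fair::graph::scheduler::simple<> sched(std::move(flow)); // execution_policy defaults to single_threaded
sched.work(); // start() repeats passes until work_once() returns DONE, stop() settles the state in STOPPED
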

@@ -232,41 +275,56 @@ class breadth_first : public scheduler_base<breadth_first<executionPolicy>> {
         }
     }

+    template<typename node_type>
+    work_return_t
+    work_once(const std::span<node_type> &nodes) {
+        bool something_happened = false;
+        for (auto &currentNode : nodes) {
+            auto result = currentNode->work();
+            if (result == work_return_t::ERROR) {
+                return work_return_t::ERROR;
+            } else if (result == work_return_t::INSUFFICIENT_INPUT_ITEMS || result == work_return_t::DONE) {
+                // nothing
+            } else if (result == work_return_t::OK || result == work_return_t::INSUFFICIENT_OUTPUT_ITEMS) {
+                something_happened = true;
+            }
+        }
+        return something_happened ? work_return_t::OK : work_return_t::DONE;
     }

     work_return_t
     work() {
+        start();
+        this->stop();
+        return work_return_t::DONE;
+    }
+
+    void
+    start() {
         if (this->_state == IDLE) {
             this->init();
         }
         if (this->_state != INITIALISED) {
             fmt::print("simple scheduler work(): graph not initialised");
-            return work_return_t::ERROR;
+            return;
         }
         if constexpr (executionPolicy == single_threaded) {
-            bool run = true;
-            while (run) {
-                if (auto result = this->traverse_nodes(std::span{_nodelist}); result == work_return_t::ERROR) {
-                    return work_return_t::ERROR;
-                } else {
-                    run = (result == work_return_t::OK);
+            this->_state = RUNNING;
+            work_return_t result;
+            auto nodelist = std::span{this->_nodelist};
+            while ((result = work_once(nodelist)) == work_return_t::OK) {
+                if (result == work_return_t::ERROR) {
+                    this->_state = ERROR;
+                    return;
                 }
             }
+            this->_state = STOPPED;
         } else if (executionPolicy == multi_threaded) {
-            std::atomic_bool stop_requested;
-            std::atomic_uint64_t progress{0}; // upper uint32t: progress counter, lower uint32t: number of workers that finished all their work
-            std::latch running_threads{static_cast<std::ptrdiff_t>(this->_job_lists.size())}; // latch to wait for completion of the flowgraph
-            for (auto &job: this->_job_lists) {
-                this->_pool->execute([this, &job, &progress, &running_threads, &stop_requested]() {
-                    this->run_on_pool(std::span{job}, this->_job_lists.size(), progress, running_threads, stop_requested);
-                });
-            }
-            running_threads.wait();
-            return work_return_t::DONE;
+            this->run_on_pool(this->_job_lists, [this](auto &job) {return this->work_once(job);});
         } else {
             throw std::invalid_argument("Unknown execution Policy");
         }
-        return work_return_t::DONE;
     }

 };
 } // namespace fair::graph::scheduler
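
With the work() = start() + stop() split, the multi-threaded path no longer blocks inside work(): run_on_pool() only enqueues one pool_worker() per job list and returns, and stop() spins until _running_threads reaches zero. A sketch of driving the scheduler asynchronously, again with the hypothetical make_graph(), and assuming breadth_first::init() has filled _job_lists:

using namespace fair::graph::scheduler;
breadth_first<multi_threaded> sched(make_graph()); // hypothetical make_graph(), as above
sched.start();        // non-blocking: enqueues one pool_worker() per job list
// ... the flowgraph is now running on the thread pool ...
sched.request_stop(); // sets _stop_requested, workers leave their while loop
sched.stop();         // busy-waits until _running_threads == 0, then _state = STOPPED
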
