Skip to content

Commit

Permalink
fix intit for moving the graph
Browse files Browse the repository at this point in the history
  • Loading branch information
wirew0rm committed Jun 23, 2023
1 parent a6eb0f0 commit 391e8e4
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 26 deletions.
17 changes: 11 additions & 6 deletions include/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,9 @@ class edge {

class graph {
private:
std::vector<std::function<connection_result_t()>> _connection_definitions;
std::vector<std::unique_ptr<node_model>> _nodes;
std::vector<edge> _edges;
std::vector<std::function<connection_result_t(graph&)>> _connection_definitions;
std::vector<std::unique_ptr<node_model>> _nodes;
std::vector<edge> _edges;

template<typename Node>
std::unique_ptr<node_model> &
Expand Down Expand Up @@ -689,8 +689,8 @@ class graph {
if (!is_node_known(source) || !is_node_known(destination)) {
throw fmt::format("Source {} and/or destination {} do not belong to this graph\n", source.name(), destination.name());
}
self._connection_definitions.push_back([self = &self, source = &source, source_port = &port, destination = &destination, destination_port = &destination_port]() {
return self->connect_impl<src_port_index, dst_port_index>(*source, *source_port, *destination, *destination_port);
self._connection_definitions.push_back([source = &source, source_port = &port, destination = &destination, destination_port = &destination_port](graph &graph) {
return graph.connect_impl<src_port_index, dst_port_index>(*source, *source_port, *destination, *destination_port);
});
return connection_result_t::SUCCESS;
}
Expand Down Expand Up @@ -745,6 +745,11 @@ class graph {
connect(Source &source, Port Source::*member_ptr);

public:
graph(graph&) = delete;
graph(graph&&) = default;
graph() = default;
graph &operator=(graph&) = delete;
graph &operator=(graph&&) = default;
/**
* @return a list of all blocks contained in this graph
* N.B. some 'blocks' may be (sub-)graphs themselves
Expand Down Expand Up @@ -822,7 +827,7 @@ class graph {
return dynamic_output_port(source, source_index).connect(dynamic_input_port(sink, sink_index));
}

const std::vector<std::function<connection_result_t()>> &
const std::vector<std::function<connection_result_t(graph&)>> &
connection_definitions() {
return _connection_definitions;
}
Expand Down
63 changes: 44 additions & 19 deletions include/scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace fair::graph::scheduler {

enum execution_policy { single_threaded, multi_threaded };
enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSE, SHUTTING_DOWN };
enum SchedulerState { IDLE, INITIALISED, RUNNING, REQUESTED_STOP, REQUESTED_PAUSE, STOPPED, PAUSED, SHUTTING_DOWN };

template<typename scheduler_type, execution_policy = single_threaded>
class scheduler_base : public node<scheduler_type> {
Expand All @@ -20,21 +20,27 @@ class scheduler_base : public node<scheduler_type> {
SchedulerState _state = IDLE;
fair::graph::graph _graph;
std::shared_ptr<thread_pool_type> _pool;
std::vector<std::vector<node_t>> _job_lists{};
std::vector<std::vector<node_t>> _job_lists{}; // move to impl

public:
explicit scheduler_base(fair::graph::graph &&graph,
std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<fair::thread_pool::BasicThreadPool>("simple-scheduler-pool", thread_pool::CPU_BOUND))
: _graph(std::move(graph)), _pool(std::move(thread_pool)){};

void
init(fair::graph::graph &graph) {
auto result = init_proof(std::all_of(graph.connection_definitions().begin(), graph.connection_definitions().end(),
[](auto &connection_definition) { return connection_definition() == connection_result_t::SUCCESS; }));
graph.clear_connection_definitions();
return result;
init() {
if (_state != IDLE) {
return;
}
auto result = std::all_of(_graph.connection_definitions().begin(), _graph.connection_definitions().end(),
[this](auto &connection_definition) { return connection_definition(_graph) == connection_result_t::SUCCESS; });
_graph.clear_connection_definitions();
if (result) {
_state = INITIALISED;
}
}

// todo: move to impl
template<typename node_t>
work_return_t
traverse_nodes(std::span<node_t> nodes) {
Expand Down Expand Up @@ -98,12 +104,16 @@ class scheduler_base : public node<scheduler_type> {
*/
template<execution_policy executionPolicy = single_threaded>
class simple : public scheduler_base<simple<executionPolicy>>{
using S = scheduler_base<simple<executionPolicy>>;
using node_t = S::node_t; //node_model*;
using thread_pool_type = S::thread_pool_type; //thread_pool::BasicThreadPool;
using Base = scheduler_base<simple<executionPolicy>>;
using thread_pool_type = Base::thread_pool_type; //thread_pool::BasicThreadPool;
public:
explicit simple(fair::graph::graph &&graph, std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("simple-scheduler-pool", thread_pool::CPU_BOUND))
: S(std::move(graph), thread_pool) {
: Base(std::forward<fair::graph::graph>(graph), thread_pool) {
}

void
init() {
Base::init();
// generate job list
if constexpr (executionPolicy == multi_threaded) {
const auto n_batches = std::min(static_cast<std::size_t>(this->_pool->maxThreads()), this->_graph.blocks().size());
Expand All @@ -119,11 +129,16 @@ class simple : public scheduler_base<simple<executionPolicy>>{
}
}


work_return_t
work() {
// if (!_init) {
// return work_return_t::ERROR;
// }
if (this->_state == IDLE) {
this->init();
}
if (this->_state != INITIALISED) {
fmt::print("simple scheduler work(): graph not initialised");
return work_return_t::ERROR;
}
if constexpr (executionPolicy == single_threaded) {
bool run = true;
while (run) {
Expand Down Expand Up @@ -157,13 +172,19 @@ class simple : public scheduler_base<simple<executionPolicy>>{
*/
template<execution_policy executionPolicy = single_threaded>
class breadth_first : public scheduler_base<breadth_first<executionPolicy>> {
using S = scheduler_base<breadth_first<executionPolicy>>;
using Base = scheduler_base<breadth_first<executionPolicy>>;
using node_t = node_model*;
using thread_pool_type = thread_pool::BasicThreadPool;
std::vector<node_t> _nodelist;
public:
explicit breadth_first(fair::graph::graph &&graph, std::shared_ptr<thread_pool_type> thread_pool = std::make_shared<thread_pool_type>("breadth-first-pool", thread_pool::CPU_BOUND))
: S(std::move(graph), thread_pool) {
: Base(std::move(graph), thread_pool) {
}

void
init() {
Base::init();
// calculate adjacency list
std::map<node_t, std::vector<node_t>> _adjacency_list{};
std::vector<node_t> _source_nodes{};
// compute the adjacency list
Expand Down Expand Up @@ -213,9 +234,13 @@ class breadth_first : public scheduler_base<breadth_first<executionPolicy>> {

work_return_t
work() {
// if (!_init) {
// return work_return_t::ERROR;
// }
if (this->_state == IDLE) {
this->init();
}
if (this->_state != INITIALISED) {
fmt::print("simple scheduler work(): graph not initialised");
return work_return_t::ERROR;
}
if constexpr (executionPolicy == single_threaded) {
bool run = true;
while (run) {
Expand Down
8 changes: 7 additions & 1 deletion test/qa_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ get_graph_scaled_sum(tracer &trace) {
}

template<typename node_type>
void check_node_names(std::vector<node_type> joblist, std::set<std::string> set) {
void check_node_names(const std::vector<node_type> &joblist, std::set<std::string> set) {
boost::ut::expect(boost::ut::that % joblist.size() == set.size());
for (auto &node: joblist) {
boost::ut::expect(boost::ut::that % set.contains(std::string(node->name()))) << fmt::format("{} not in {}\n", node->name(), set);
Expand Down Expand Up @@ -283,6 +283,8 @@ const boost::ut::suite SchedulerTests = [] {
using scheduler = fair::graph::scheduler::breadth_first<fg::scheduler::execution_policy::multi_threaded>;
tracer trace{};
auto sched = scheduler{get_graph_linear(trace), thread_pool};
sched.init();
expect(sched.getJobLists().size() == 2u);
check_node_names(sched.getJobLists()[0], {"s1", "mult2"});
check_node_names(sched.getJobLists()[1], {"mult1", "out"});
sched.work();
Expand All @@ -303,6 +305,8 @@ const boost::ut::suite SchedulerTests = [] {
using scheduler = fair::graph::scheduler::breadth_first<fg::scheduler::execution_policy::multi_threaded>;
tracer trace{};
auto sched = scheduler{get_graph_parallel(trace), thread_pool};
sched.init();
expect(sched.getJobLists().size() == 2u);
check_node_names(sched.getJobLists()[0], {"s1", "mult1b", "mult2b", "outb"});
check_node_names(sched.getJobLists()[1], {"mult1a", "mult2a", "outa"});
sched.work();
Expand All @@ -324,6 +328,8 @@ const boost::ut::suite SchedulerTests = [] {
using scheduler = fair::graph::scheduler::breadth_first<fg::scheduler::execution_policy::multi_threaded>;
tracer trace{};
auto sched = scheduler{get_graph_scaled_sum(trace), thread_pool};
sched.init();
expect(sched.getJobLists().size() == 2u);
check_node_names(sched.getJobLists()[0], {"s1", "mult", "out"});
check_node_names(sched.getJobLists()[1], {"s2", "add"});
sched.work();
Expand Down

0 comments on commit 391e8e4

Please sign in to comment.