From 950da68225b17146ec4f0d269d376eb7b90144f9 Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Mon, 5 Aug 2024 03:18:48 -0400 Subject: [PATCH 01/29] pass existing tests --- .../{cost_estimate.h => cost_estimator.h} | 0 lib/compiler/include/compiler/graph_utils.h | 37 -- .../include/compiler/machine_mapping.h | 17 +- .../compiler/optimal_cost_state.struct.toml | 9 +- .../include/compiler/unity_algorithm.h | 2 +- lib/compiler/src/graph_utils.cc | 153 --------- lib/compiler/src/machine_mapping.cc | 315 +++++++++--------- lib/compiler/src/unity_algorithm.cc | 1 - lib/compiler/test/src/test_cost_estimator.h | 2 +- lib/compiler/test/src/test_optimal_cost.cc | 128 ++++--- .../sub_parallel_computation_graph.h | 3 + .../sub_parallel_computation_graph.cc | 24 ++ 12 files changed, 251 insertions(+), 440 deletions(-) rename lib/compiler/include/compiler/{cost_estimate.h => cost_estimator.h} (100%) delete mode 100644 lib/compiler/include/compiler/graph_utils.h delete mode 100644 lib/compiler/src/graph_utils.cc diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimator.h similarity index 100% rename from lib/compiler/include/compiler/cost_estimate.h rename to lib/compiler/include/compiler/cost_estimator.h diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h deleted file mode 100644 index 1370357837..0000000000 --- a/lib/compiler/include/compiler/graph_utils.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_GRAPH_UTILS_H -#define _FLEXFLOW_COMPILER_GRAPH_UTILS_H - -#include "compiler/unity_algorithm.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const 
&pcg); - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); - -// NOTE(@wmdi): I think we should have the following interfaces in the graph -// library eventually. - -template -void minimize(T &t, T const &v) { - if (v < t) { - t = v; - } -} - -template -void minimize(T &t, T const &v, Compare comp) { - if (comp(v, t)) { - t = v; - } -} - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 5d17cbb373..eab86ed7a9 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -3,7 +3,7 @@ #include "compiler/machine_mapping.dtg.h" #include "compiler/optimal_cost_state.dtg.h" -#include "cost_estimate.h" +#include "cost_estimator.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" @@ -29,10 +29,6 @@ struct OptimalCostResult { }; FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); -struct OptimalCostRuntimeCmp { - bool operator()(OptimalCostResult const &, OptimalCostResult const &); -}; - class OptimalCostCache { public: OptimalCostCache() = default; @@ -55,15 +51,4 @@ OptimalCostResult optimal_cost( } // namespace FlexFlow -// namespace std { -// -// template <> -// struct hash> { -// size_t operator()( -// std::unordered_map const &g) -// const; -// }; - -// }; // namespace std - #endif diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml index 50496f661b..9c8fabbf47 100644 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ b/lib/compiler/include/compiler/optimal_cost_state.struct.toml @@ -17,6 +17,7 @@ includes = [ "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", "utils/fmt/unordered_map.h", "utils/hash/unordered_map.h", + 
"utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", ] [[fields]] @@ -28,9 +29,5 @@ name = "resource" type = "::FlexFlow::MachineSpecification" [[fields]] -name = "given_machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" - -[[fields]] -name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" \ No newline at end of file +name = "fixed_machine_views" +type = "std::unordered_map<::FlexFlow::OpenDataflowValue, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index abddef37ed..65ca694f78 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #include "compiler/machine_mapping.h" -#include "cost_estimate.h" +#include "cost_estimator.h" #include "machine_mapping.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc deleted file mode 100644 index 08db219a21..0000000000 --- a/lib/compiler/src/graph_utils.cc +++ /dev/null @@ -1,153 +0,0 @@ -#include "compiler/graph_utils.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/containers/without_order.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -namespace FlexFlow { - -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return get_serial_parallel_decomposition(pcg.raw_graph); -} - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { - NOT_IMPLEMENTED(); -} - 
-SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return view_output_labelled_as_output_labelled_open(pcg.raw_graph); -} - -// std::vector -// get_sorted_node_input_edges(ParallelComputationGraph const &pcg, -// Node const &n) { -// std::unordered_map> -// incoming_edges = -// get_incoming_edges_by_idx(pcg, n); - -// std::vector result; -// for (auto const &p_id_edge_set : incoming_edges) { -// result.push_back(get_only(p_id_edge_set.second)); -// } - -// return result; -// } - -// std::unordered_map -// infer_tensor_shapes(ParallelComputationGraph const &pcg) { -// std::unordered_map result; -// for (Node const &n : get_topological_ordering(pcg)) { -// PCGOperatorAttrs op = pcg.raw_graph.at(n); - -// std::vector input_tensor_shapes = -// vector_transform([&](MultiDiEdge const &e) { return result.at(e); }, -// get_sorted_node_input_edges(pcg, n)); - -// std::vector output_tensor_shapes = -// get_output_shapes(op, input_tensor_shapes); - -// auto outgoing_edges = get_outgoing_edges_by_idx(pcg, n); - -// int i = 0; - -// for (auto const &[node_port, edges] : outgoing_edges) { -// for (MultiDiEdge const &e : edges) { -// result.insert({e, output_tensor_shapes[i++]}); -// } -// } -// } - -// assert(result.size() == get_edges(pcg.raw_graph).size()); - -// return result; -// } - -/* template */ -/* LabelledOpenMultiDiGraph */ -/* get_subgraph(LabelledOpenMultiDiGraph const &g, */ -/* std::unordered_set const &nodes, */ -/* InputSettings input_settings, */ -/* OutputSettings output_settings) { */ - -/* auto iview = LabelledOpenMultiDiGraphView(g) */ -/* .unsafe(); */ - -/* if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledOpenMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view(subgraph_view); */ -/* } else if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* 
output_settings == OutputSettings::EXCLUDE_OUTPUTS) { */ -/* LabelledUpwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else if (input_settings == InputSettings::EXCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledDownwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else { */ -/* LabelledMultiDiSubgraphView subgraph_view(*iview, - */ -/* nodes); - */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); - */ -/* } */ -/* } */ - -// struct GetNodes { -// template -// std::unordered_set operator()(T const &t) { -// return get_nodes(t); -// } -// }; - -// std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { -// return std::visit(GetNodes{}, sp.raw_variant); -// } - -// std::unordered_set get_nodes(SerialSplit const &serial) { -// return set_union( -// transform(serial.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(ParallelSplit const ¶llel) { -// return set_union( -// transform(parallel.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(Node const &node) { -// return {node}; -// } - -} // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 12eacb2a30..dd2be6bafc 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -1,6 +1,5 @@ #include "compiler/machine_mapping.h" -#include "compiler/cost_estimate.h" -#include "compiler/graph_utils.h" +#include "compiler/cost_estimator.h" #include 
"pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.dtg.h" @@ -12,6 +11,10 @@ #include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" +#include "utils/containers/values.h" +#include "utils/containers/get_first.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/restrict_keys.h" #include "utils/exception.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" @@ -20,6 +23,8 @@ #include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" #include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -31,6 +36,12 @@ bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } +void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { + if (m2.runtime < m1.runtime) { + m1 = m2; + } +} + OptimalCostResult OptimalCostResult::sequential_combine(OptimalCostResult const &s1, OptimalCostResult const &s2) { @@ -50,11 +61,6 @@ OptimalCostResult OptimalCostResult::infinity() { MachineMapping{std::unordered_map{}}}; } -bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, - OptimalCostResult const &rhs) { - return lhs.runtime < rhs.runtime; -} - std::optional OptimalCostCache::load(OptimalCostState const &state) const { if (contains_key(cache, state)) { @@ -82,7 +88,6 @@ std::vector> return result; } -// We may replace this by having unflattened AST std::pair decompose(SerialSplit const &serial) { if (serial.children.size() == 2) { @@ -120,11 +125,10 @@ GraphSplit float estimate_cost(SubParallelComputationGraph const &g, CostEstimator const &estimator, - MachineMapping 
const &device_mapping, - std::unordered_map const - &frontier_machine_views) { + std::unordered_map const + &machine_views) { // TODO: Consider parallelism - float cost = 0; + float cost = 1.; // for (Node const &node : get_nodes(g.raw_graph)) { // std::vector incoming_edges = // get_incoming_edges(g.raw_graph, node); @@ -140,21 +144,20 @@ float estimate_cost(SubParallelComputationGraph const &g, return cost; } -void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { - minimize(m1, m2, OptimalCostRuntimeCmp{}); -} - struct MachineMappingSearcher { MachineMappingSearcher( - CostEstimator cost_estimator, + ParallelComputationGraph const &pcg, + CostEstimator const &cost_estimator, std::function( ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views, OptimalCostCache &cached_subgraph_costs) - : cost_estimator(cost_estimator), + : pcg(pcg), + cost_estimator(cost_estimator), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} + ParallelComputationGraph pcg; CostEstimator cost_estimator; std::function(ParallelLayerAttrs const &, MachineSpecification const &)> @@ -164,27 +167,20 @@ struct MachineMappingSearcher { struct OptimalCostFunctor { OptimalCostFunctor( MachineMappingSearcher *searcher, - SubParallelComputationGraph const &g, MachineSpecification resource, - std::unordered_map given_machine_views, - std::unordered_map - frontier_machine_views) - : searcher(searcher), g(g), resource(resource), - given_machine_views(given_machine_views), - frontier_machine_views(frontier_machine_views) {} + std::unordered_map fixed_machine_views) + : searcher(searcher), resource(resource), + fixed_machine_views(fixed_machine_views) {} MachineMappingSearcher *searcher; - SubParallelComputationGraph const &g; MachineSpecification resource; - std::unordered_map given_machine_views; - std::unordered_map frontier_machine_views; + std::unordered_map fixed_machine_views; template 
OptimalCostResult operator()(T const &t) { - OptimalCostState state{SerialParallelDecomposition{t}, + OptimalCostState state(SerialParallelDecomposition{t}, resource, - given_machine_views, - frontier_machine_views}; + fixed_machine_views); std::optional cached_result = searcher->cached_subgraph_costs.load(state); @@ -192,7 +188,7 @@ struct MachineMappingSearcher { return cached_result.value(); } OptimalCostResult result = searcher->optimal_cost( - t, g, resource, given_machine_views, frontier_machine_views); + t, resource, fixed_machine_views); searcher->cached_subgraph_costs.save(state, result); return result; @@ -200,146 +196,162 @@ struct MachineMappingSearcher { }; OptimalCostResult - optimal_cost(SubParallelComputationGraph const &g, - MachineSpecification resource, - SerialParallelDecomposition const &sp_decomposition) { - return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), - sp_decomposition.raw_variant); + optimal_cost(MachineSpecification resource) { + return std::visit(OptimalCostFunctor(this, resource, {}), + get_serial_parallel_decomposition(pcg.raw_graph).value().raw_variant); } OptimalCostResult optimal_cost( SerialSplit const &serial, - SubParallelComputationGraph const &g, MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - NOT_IMPLEMENTED(); - // OptimalCostResult optimal_result = OptimalCostResult::infinity(); - - // auto decomposed = decompose(serial); - // SerialParallelDecomposition pre_decompn = decomposed.first; - // SerialParallelDecomposition post_decompn = decomposed.second; - - // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); - // SubParallelComputationGraph pre_graph = - // get_subgraph(g, graph_split.first); - // SubParallelComputationGraph post_graph = - // get_subgraph(g, graph_split.second); - - // std::unordered_set post_graph_sources = - // get_closed_sources(post_graph); - - // 
assert(post_graph_sources.size() == 1); // assume perfect SP - - // Node split_point = get_only(post_graph_sources); - // OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); - - // for (MachineView const &mv : - // allowed_machine_views(g.raw_graph.at(split_point), resource)) { - // std::unordered_map new_given_machine_views = - // given_machine_views; - // new_given_machine_views.emplace(split_point, mv); - // std::unordered_map - // new_frontier_machine_views = frontier_machine_views; - // new_frontier_machine_views.emplace(split_edge, mv); - // minimize_runtime( - // optimal_result, - // OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // pre_graph, - // resource, - // given_machine_views, - // new_frontier_machine_views), - // pre_decompn.raw_variant), - // std::visit(OptimalCostFunctor(this, - // post_graph, - // resource, - // new_given_machine_views, - // frontier_machine_views), - // post_decompn.raw_variant))); - // } - - // return optimal_result; + std::unordered_map const &fixed_machine_views) { + OptimalCostResult optimal_result = OptimalCostResult::infinity(); + + auto [decompn1, decompn2] = decompose(serial); + + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + + OpenDataflowSubgraphResult subgraph_res1 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); + OpenDataflowSubgraphResult subgraph_res2 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); + + std::unordered_set split_outputs; + for (auto const &[value, _] : subgraph_res2.full_graph_values_to_subgraph_inputs) { + assert(value.has()); + split_outputs.insert(value.get()); + } + + for (std::unordered_map const &split_machine_views : enumerate_machine_views(split_outputs, resource)) { + std::unordered_map fixed_machine_views1 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res1.graph)); + std::unordered_map fixed_machine_views2 = restrict_keys(fixed_machine_views, 
get_open_dataflow_values(subgraph_res2.graph)); + + for (auto const &[split_value, split_input] : subgraph_res2.full_graph_values_to_subgraph_inputs) { + MachineView mv = split_machine_views.at(split_value.get()); + fixed_machine_views1.emplace(split_value, mv); + fixed_machine_views2.emplace(OpenDataflowValue(split_input), mv); + } + + minimize_runtime( + optimal_result, + OptimalCostResult::sequential_combine( + std::visit(OptimalCostFunctor(this, + resource, + fixed_machine_views1), + decompn1.raw_variant), + std::visit(OptimalCostFunctor(this, + resource, + fixed_machine_views2), + decompn2.raw_variant))); + } + + return optimal_result; } OptimalCostResult optimal_cost( ParallelSplit const ¶llel, - SubParallelComputationGraph const &g, MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - - NOT_IMPLEMENTED(); - // auto decomposed = decompose(parallel); - // SerialParallelDecomposition decompn1 = decomposed.first; - // SerialParallelDecomposition decompn2 = decomposed.second; - - // GraphSplit graph_split = get_graph_split(decompn1, decompn2); - // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), - // g2 = get_subgraph(g, graph_split.second); - - // OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant)); - - // for (auto const &resource_split : get_resource_split(resource)) { - // minimize_runtime( - // optimal_result, - // OptimalCostResult::parallel_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource_split.first, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // 
std::visit(OptimalCostFunctor(this, - // g2, - // resource_split.second, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant))); - // } - - // return optimal_result; + std::unordered_map const &fixed_machine_views) { + auto [decompn1, decompn2] = decompose(parallel); + + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + + OpenDataflowSubgraphResult subgraph_res1 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); + OpenDataflowSubgraphResult subgraph_res2 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); + + std::unordered_map fixed_machine_views1 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res1.graph)); + std::unordered_map fixed_machine_views2 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res2.graph)); + + OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( + std::visit(OptimalCostFunctor(this, + resource, + fixed_machine_views1), + decompn1.raw_variant), + std::visit(OptimalCostFunctor(this, + resource, + fixed_machine_views1), + decompn2.raw_variant)); + + for (auto const &resource_split : get_resource_split(resource)) { + minimize_runtime( + optimal_result, + OptimalCostResult::parallel_combine( + std::visit(OptimalCostFunctor(this, + resource_split.first, + fixed_machine_views1), + decompn1.raw_variant), + std::visit(OptimalCostFunctor(this, + resource_split.second, + fixed_machine_views1), + decompn2.raw_variant))); + } + + return optimal_result; } OptimalCostResult optimal_cost( Node const &node, - SubParallelComputationGraph const &g, MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - if (contains_key(given_machine_views, node)) { - assert(contains(allowed_machine_views(g.raw_graph.at(node), resource), - given_machine_views.at(node))); - MachineMapping mv_map{given_machine_views}; - return 
{estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}; + std::unordered_map const &fixed_machine_views) { + SubParallelComputationGraph subgraph = sub_pcg_from_partial_pcg(pcg, {node}); + + OpenDataflowValue any_output = OpenDataflowValue(get_outputs(pcg.raw_graph, node)[0]); + if (contains_key(fixed_machine_views, any_output)) { + assert(contains(allowed_machine_views(pcg.raw_graph.at(node), resource), + fixed_machine_views.at(any_output))); + MachineView mv = fixed_machine_views.at(any_output); + MachineMapping mv_map{{{node, mv}}}; + return {estimate_cost(subgraph, cost_estimator, fixed_machine_views), mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); - for (auto mv : allowed_machine_views(g.raw_graph.at(node), resource)) { - MachineMapping mv_map{{{node, mv}}}; + for (std::unordered_map node_machine_views : enumerate_machine_views({node}, resource)) { + MachineMapping mv_map{{{node, node_machine_views.at(node)}}}; + std::unordered_map machine_views = fixed_machine_views; + for (DataflowOutput o : get_outputs(pcg.raw_graph, node)) { + machine_views.emplace(o, node_machine_views.at(node)); + } minimize_runtime( optimal_result, - {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + {estimate_cost(subgraph, cost_estimator, machine_views), mv_map}); } return optimal_result; } } + + std::vector> enumerate_machine_views(std::unordered_set const &nodes, MachineSpecification const &resource) { + if (nodes.empty()) { + return {{}}; + } + Node node = get_first(nodes); + std::vector> partial_enumeration = enumerate_machine_views(set_minus(nodes, {node}), resource); + std::unordered_set allowed_machine_views_for_node = this->allowed_machine_views(pcg.raw_graph.at(node), resource); + std::vector> enumeration; + for (MachineView const &mv : allowed_machine_views_for_node) { + for (std::unordered_map const &partial : partial_enumeration) { + enumeration.push_back(merge_maps(partial, 
std::unordered_map{{node, mv}})); + } + } + return enumeration; + } + + std::vector> enumerate_machine_views(std::unordered_set const &values, MachineSpecification const &resource) { + std::unordered_set nodes; + for (DataflowOutput const &v : values) { + nodes.insert(v.node); + } + + std::vector> node_enumeration = enumerate_machine_views(nodes, resource); + std::vector> enumeration; + + for (std::unordered_map _node_enumeration : node_enumeration) { + std::unordered_map _emumeration; + for (DataflowOutput const &v : values) { + _emumeration.emplace(v, _node_enumeration.at(v.node)); + } + enumeration.push_back(_emumeration); + } + + return enumeration; + } }; OptimalCostResult optimal_cost( @@ -350,12 +362,9 @@ OptimalCostResult optimal_cost( CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = - get_serial_parallel_decomposition(g); - SubParallelComputationGraph subpcg = pcg_to_subpcg(g); - MachineMappingSearcher searcher( + MachineMappingSearcher searcher(g, cost_estimator, allowed_machine_views, cached_subgraph_costs); - return searcher.optimal_cost(subpcg, resources, sp_decomposition); + return searcher.optimal_cost(resources); } } // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ba6ef28daa..116a238dc2 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,5 +1,4 @@ #include "compiler/unity_algorithm.h" -#include "compiler/graph_utils.h" #include "compiler/machine_mapping.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/substitution.h" diff --git a/lib/compiler/test/src/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h index 9417b863e4..322f80352f 100644 --- a/lib/compiler/test/src/test_cost_estimator.h +++ b/lib/compiler/test/src/test_cost_estimator.h @@ -1,7 +1,7 @@ #ifndef 
_FLEXFLOW_TEST_COST_ESTIMATOR_H #define _FLEXFLOW_TEST_COST_ESTIMATOR_H -#include "compiler/cost_estimate.h" +#include "compiler/cost_estimator.h" namespace FlexFlow { diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 133558f83a..7f4a26766d 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -1,72 +1,56 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "test_cost_estimator.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// // Rapidcheck infrastructures for graphs does not work for now -// /* -// Tests whether optimal_cost can give a valid result given random PCG, -// trivial allowed machine views, trivial cost estimator and random machine -// specification. -// */ -// // TEST_CASE("optimal_cost") { -// // auto test_allowed_machine_views = [](Operator const &, -// // MachineSpecification const &) { -// // return std::unordered_set{make_1d_machine_view(0, 1, -// 1)}; -// // }; -// // RC_SUBCASE([](ParallelComputationGraph const &g, -// // MachineSpecification const &machine_spec) { -// // OptimalCostCache cached_subgraph_costs; -// // OptimalCostResult result = optimal_cost(g, -// // test_allowed_machine_views, -// // TestCostEstimator{}, -// // machine_spec, -// // cached_subgraph_costs); -// // RC_ASSERT(result.runtime > 0); -// // RC_ASSERT(keys(result.machine_mapping.machine_views) == -// get_nodes(g)); -// // }); -// // } - -// TEST_CASE("optimal_cost_0") { -// auto pcg = -// OutputLabelledMultiDiGraph::template -// create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); -// Node n1 = pcg.add_node(Operator{ -// LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, -// std::nullopt}, "linear"}); - -// MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; -// pcg.add_edge(e); -// ParallelDim dim = {2, 1, false}; -// 
ParallelTensorDims dims = {FFOrdered{dim}}; -// pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, -// CreateGrad::YES)); - -// auto test_allowed_machine_views = [](Operator const &, -// MachineSpecification const &) { -// return std::unordered_set{ -// make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; -// }; - -// CostEstimator estimator = CostEstimator::create(); - -// MachineSpecification machine_spec{1, 1, 1, 1, 1}; - -// OptimalCostCache cached_results; - -// OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), -// test_allowed_machine_views, -// estimator, -// machine_spec, -// cached_results); - -// CHECK(bool(result.runtime > 0)); -// } -// } +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "test_cost_estimator.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("optimal_cost_0") { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape, true, "input0"); + parallel_tensor_guid_t dense0 = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + ParallelComputationGraph pcg = builder.pcg; + + auto test_allowed_machine_views = [](ParallelLayerAttrs const &, + MachineSpecification const &) { + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + }; + + CostEstimator estimator = CostEstimator::create(); + MachineSpecification machine_spec{1, 1, 1, 1, 1}; + OptimalCostCache cached_results; + OptimalCostResult result = optimal_cost( + pcg, + test_allowed_machine_views, + estimator, + machine_spec, + 
cached_results); + + CHECK(bool(result.runtime > 0)); + } +} \ No newline at end of file diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 42d85dc549..c91055c530 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -21,6 +21,9 @@ SubParallelComputationGraph ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs(SubParallelComputationGraph const &); +SubParallelComputationGraph + sub_pcg_from_partial_pcg(ParallelComputationGraph const &, std::unordered_set const &); + parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name); diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 2f050ce45e..ca273b9ad0 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -4,6 +4,9 @@ #include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" namespace FlexFlow { @@ -51,6 +54,27 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( // }; } +SubParallelComputationGraph + sub_pcg_from_partial_pcg(ParallelComputationGraph const &pcg, std::unordered_set const &nodes) { + auto as_open = view_as_labelled_open_dataflow_graph(pcg.raw_graph); + OpenDataflowSubgraphResult subgraph_result = 
get_subgraph(as_open, nodes); + return SubParallelComputationGraph{ + with_labelling( + subgraph_result.graph, + generate_map(nodes, [&](Node const &node) { return as_open.at(node); }), + generate_map( + get_open_dataflow_values(subgraph_result.graph), + [&](OpenDataflowValue const &value) { + if (value.has()) { + return as_open.at(subgraph_result.full_graph_values_to_subgraph_inputs.at_r(value.get())); + } else { + return as_open.at(value); + } + }) + ) + }; +} + parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name) { From d36d1ea0fac9e413db2c4dfec01ec57b971f77ec Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 7 Aug 2024 15:33:08 -0400 Subject: [PATCH 02/29] unity algorithm builds --- .../graph_optimize_result.struct.toml | 15 +++ .../include/compiler/graph_optimize_state.h | 29 ++++++ .../include/compiler/unity_algorithm.h | 23 +---- lib/compiler/src/graph_optimize_state.cc | 48 ++++++++++ lib/compiler/src/unity_algorithm.cc | 94 ++++++++----------- 5 files changed, 137 insertions(+), 72 deletions(-) create mode 100644 lib/compiler/include/compiler/graph_optimize_result.struct.toml create mode 100644 lib/compiler/include/compiler/graph_optimize_state.h create mode 100644 lib/compiler/src/graph_optimize_state.cc diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml new file mode 100644 index 0000000000..c2a9a4ab39 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_result.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "GraphOptimizeResult" +features = [ ] + +includes = [ + "compiler/machine_mapping.h" +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h new file mode 
100644 index 0000000000..4192179ddd --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H +#define _FLEXFLOW_COMPILER_MCMC_STATE_H + +#include "compiler/machine_mapping.h" +#include "compiler/graph_optimize_result.dtg.h" + +namespace FlexFlow { + +struct GraphOptimizeState { + GraphOptimizeResult graph_optimize_result; + float runtime; + + bool operator==(GraphOptimizeState const &other) const; + bool operator!=(GraphOptimizeState const &other) const; + bool operator<(GraphOptimizeState const &other) const; +}; + +} + +namespace std { + +template <> +struct hash<::FlexFlow::GraphOptimizeState> { + size_t operator()(::FlexFlow::GraphOptimizeState const &) const; +}; + +} // namespace std + +#endif \ No newline at end of file diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 65ca694f78..ed77e28b39 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -3,27 +3,12 @@ #include "compiler/machine_mapping.h" #include "cost_estimator.h" -#include "machine_mapping.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" -namespace FlexFlow { - -struct Strategy { - ParallelComputationGraph pcg; - MachineMapping machine_mapping; - req runtime; - friend bool operator!=(Strategy const &lhs, Strategy const &rhs) { - return (lhs.machine_mapping != rhs.machine_mapping) || - (lhs.runtime != rhs.runtime); - } -}; +#include "compiler/graph_optimize_result.dtg.h" -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); - -struct StrategyRuntimeCmp { - bool operator()(Strategy const &, Strategy const &); -}; +namespace FlexFlow { struct OptimizerConfig { float alpha; @@ -32,8 +17,8 @@ struct OptimizerConfig { int max_num_ops; }; -Strategy graph_optimize( - ComputationGraph &cg, 
+GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc new file mode 100644 index 0000000000..6826aa5e6c --- /dev/null +++ b/lib/compiler/src/graph_optimize_state.cc @@ -0,0 +1,48 @@ +#include "compiler/graph_optimize_state.h" + +namespace FlexFlow { + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + auto layers1 = topological_ordering(graph_optimize_result.pcg); + auto layers2 = topological_ordering(other.graph_optimize_result.pcg); + if (layers1.size() != layers2.size()) { + return false; + } + for (size_t i = 0; i < layers1.size(); ++i) { + auto inputs1 = get_layer_inputs(graph_optimize_result.pcg, layers1[i]); + auto inputs2 = get_layer_inputs(other.graph_optimize_result.pcg, layers2[i]); + if (inputs1.size() != inputs2.size()) { + return false; + } + for (size_t j = 0; j < inputs1.size(); ++j) { + if (inputs1[j] != inputs2[j]) { + return false; + } + } + } + return true; +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +} + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()(::FlexFlow::GraphOptimizeState const &state) const { + size_t seed = 0; + auto layers = topological_ordering(state.graph_optimize_result.pcg); + ::FlexFlow::hash_combine(seed, layers.size()); + for (auto layer : layers) { + auto inputs = get_layer_inputs(state.graph_optimize_result.pcg, layer); + ::FlexFlow::hash_combine(seed, inputs.size()); + for (auto input : inputs) { + ::FlexFlow::hash_combine(seed, input); + } + } + return seed; +} + +} \ No newline at end of file diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 116a238dc2..f39b903c1f 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ 
b/lib/compiler/src/unity_algorithm.cc @@ -4,16 +4,13 @@ #include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" +#include "compiler/graph_optimize_state.h" namespace FlexFlow { -bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { - return lhs.runtime < rhs.runtime; -} - /* * Gets all substitutions applicable to a PCG */ -std::unordered_set +std::vector get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } @@ -21,73 +18,64 @@ std::unordered_set /* * Applies a substitution to all possible positions in PCG */ -std::unordered_set +std::vector apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { NOT_IMPLEMENTED(); } -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views, OptimizerConfig const &opt_config) { - NOT_IMPLEMENTED(); - // ParallelComputationGraph pcg = cg_to_pcg(cg); + std::vector substitutions = get_all_applicable_substitutions(pcg); - // std::unordered_set subs = - // get_all_applicable_substitutions(pcg); + OptimalCostCache cached_subgraph_costs; + DeduplicatedPriorityQueue candidates; - // OptimalCostCache cached_subgraph_costs; - // DeduplicatedPriorityQueue, - // StrategyRuntimeCmp> - // candidates; + OptimalCostResult original_pcg_cost = optimal_cost(pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); - // OptimalCostResult initial_pcg_result = optimal_cost(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy initial_result{ - // pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; + GraphOptimizeState initial_state = 
{GraphOptimizeResult(pcg,original_pcg_cost.machine_mapping), original_pcg_cost.runtime}; - // Strategy best_result = initial_result; - // candidates.push(initial_result); + GraphOptimizeState best_state = initial_state; + candidates.push(initial_state); - // for (int iteration = 0; !candidates.empty() && iteration < - // opt_config.budget; - // ++iteration) { - // Strategy const &current_result = candidates.top(); - // candidates.pop(); + for (int iteration = 0; !candidates.empty() && iteration < opt_config.budget; ++iteration) { + GraphOptimizeState current_state = candidates.top(); + candidates.pop(); - // if (current_result.runtime < best_result.runtime) { - // best_result = current_result; - // } else if (current_result.runtime > - // best_result.runtime * opt_config.alpha) { - // continue; - // } + if (current_state.runtime < best_state.runtime) { + best_state = current_state; + } else if (current_state.runtime > + best_state.runtime * opt_config.alpha) { + continue; + } - // for (auto const &sub : subs) { - // for (auto const &new_pcg : apply_substitution(current_result.pcg, sub)) - // { - // OptimalCostResult c = optimal_cost(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; - // if (new_result.runtime <= opt_config.threshold && - // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_result); - // } - // } - // } - // } + for (Substitution const &substitution : substitutions) { + for (ParallelComputationGraph const &new_pcg : apply_substitution(current_state.graph_optimize_result.pcg, substitution)) { + OptimalCostResult new_pcg_cost = optimal_cost(new_pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + GraphOptimizeState new_state{GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), new_pcg_cost.runtime}; + if (new_pcg_cost.runtime <= opt_config.threshold && + 
get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { + candidates.push(new_state); + } + } + } + } - // return best_result; + return best_state.graph_optimize_result; } } // namespace FlexFlow From c71b773c136606cbb80d830f7f99eacbf1ec7f87 Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 7 Aug 2024 16:11:38 -0400 Subject: [PATCH 03/29] fmt --- .../include/compiler/graph_optimize_state.h | 6 +- .../include/compiler/unity_algorithm.h | 2 +- lib/compiler/src/graph_optimize_state.cc | 10 +- lib/compiler/src/machine_mapping.cc | 197 ++++++++++-------- lib/compiler/src/unity_algorithm.cc | 40 ++-- lib/compiler/test/src/test_optimal_cost.cc | 15 +- .../sub_parallel_computation_graph.h | 3 +- .../sub_parallel_computation_graph.cc | 31 ++- 8 files changed, 167 insertions(+), 137 deletions(-) diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h index 4192179ddd..7d3608808f 100644 --- a/lib/compiler/include/compiler/graph_optimize_state.h +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H #define _FLEXFLOW_COMPILER_MCMC_STATE_H -#include "compiler/machine_mapping.h" #include "compiler/graph_optimize_result.dtg.h" +#include "compiler/machine_mapping.h" namespace FlexFlow { @@ -15,7 +15,7 @@ struct GraphOptimizeState { bool operator<(GraphOptimizeState const &other) const; }; -} +} // namespace FlexFlow namespace std { @@ -26,4 +26,4 @@ struct hash<::FlexFlow::GraphOptimizeState> { } // namespace std -#endif \ No newline at end of file +#endif diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index ed77e28b39..9ef85fe639 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H +#include 
"compiler/graph_optimize_result.dtg.h" #include "compiler/machine_mapping.h" #include "cost_estimator.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" -#include "compiler/graph_optimize_result.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc index 6826aa5e6c..a4176e3f3d 100644 --- a/lib/compiler/src/graph_optimize_state.cc +++ b/lib/compiler/src/graph_optimize_state.cc @@ -10,7 +10,8 @@ bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { } for (size_t i = 0; i < layers1.size(); ++i) { auto inputs1 = get_layer_inputs(graph_optimize_result.pcg, layers1[i]); - auto inputs2 = get_layer_inputs(other.graph_optimize_result.pcg, layers2[i]); + auto inputs2 = + get_layer_inputs(other.graph_optimize_result.pcg, layers2[i]); if (inputs1.size() != inputs2.size()) { return false; } @@ -27,11 +28,12 @@ bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { return !(*this == other); } -} +} // namespace FlexFlow namespace std { -size_t hash<::FlexFlow::GraphOptimizeState>::operator()(::FlexFlow::GraphOptimizeState const &state) const { +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { size_t seed = 0; auto layers = topological_ordering(state.graph_optimize_result.pcg); ::FlexFlow::hash_combine(seed, layers.size()); @@ -45,4 +47,4 @@ size_t hash<::FlexFlow::GraphOptimizeState>::operator()(::FlexFlow::GraphOptimiz return seed; } -} \ No newline at end of file +} // namespace std diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index dd2be6bafc..dd12bb23c4 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -9,22 +9,22 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/as_vector.h" #include 
"utils/containers/contains_key.h" +#include "utils/containers/get_first.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" -#include "utils/containers/values.h" -#include "utils/containers/get_first.h" -#include "utils/containers/set_minus.h" #include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/values.h" #include "utils/exception.h" +#include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" #include "utils/graph/serial_parallel/serial_parallel_splits.h" -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "utils/graph/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -123,10 +123,10 @@ GraphSplit get_nodes(post_decomposition)}; } -float estimate_cost(SubParallelComputationGraph const &g, - CostEstimator const &estimator, - std::unordered_map const - &machine_views) { +float estimate_cost( + SubParallelComputationGraph const &g, + CostEstimator const &estimator, + std::unordered_map const &machine_views) { // TODO: Consider parallelism float cost = 1.; // for (Node const &node : get_nodes(g.raw_graph)) { @@ -152,8 +152,7 @@ struct MachineMappingSearcher { ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views, OptimalCostCache &cached_subgraph_costs) - : pcg(pcg), - cost_estimator(cost_estimator), + : pcg(pcg), cost_estimator(cost_estimator), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} @@ -178,171 +177,195 @@ struct MachineMappingSearcher { template 
OptimalCostResult operator()(T const &t) { - OptimalCostState state(SerialParallelDecomposition{t}, - resource, - fixed_machine_views); + OptimalCostState state( + SerialParallelDecomposition{t}, resource, fixed_machine_views); std::optional cached_result = searcher->cached_subgraph_costs.load(state); if (cached_result) { return cached_result.value(); } - OptimalCostResult result = searcher->optimal_cost( - t, resource, fixed_machine_views); + OptimalCostResult result = + searcher->optimal_cost(t, resource, fixed_machine_views); searcher->cached_subgraph_costs.save(state, result); return result; } }; - OptimalCostResult - optimal_cost(MachineSpecification resource) { - return std::visit(OptimalCostFunctor(this, resource, {}), - get_serial_parallel_decomposition(pcg.raw_graph).value().raw_variant); + OptimalCostResult optimal_cost(MachineSpecification resource) { + return std::visit( + OptimalCostFunctor(this, resource, {}), + get_serial_parallel_decomposition(pcg.raw_graph).value().raw_variant); } - OptimalCostResult optimal_cost( - SerialSplit const &serial, - MachineSpecification const &resource, - std::unordered_map const &fixed_machine_views) { + OptimalCostResult + optimal_cost(SerialSplit const &serial, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { OptimalCostResult optimal_result = OptimalCostResult::infinity(); auto [decompn1, decompn2] = decompose(serial); GraphSplit graph_split = get_graph_split(decompn1, decompn2); - OpenDataflowSubgraphResult subgraph_res1 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); - OpenDataflowSubgraphResult subgraph_res2 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); + OpenDataflowSubgraphResult subgraph_res1 = + get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); + OpenDataflowSubgraphResult subgraph_res2 = + get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); std::unordered_set 
split_outputs; - for (auto const &[value, _] : subgraph_res2.full_graph_values_to_subgraph_inputs) { + for (auto const &[value, _] : + subgraph_res2.full_graph_values_to_subgraph_inputs) { assert(value.has()); split_outputs.insert(value.get()); } - for (std::unordered_map const &split_machine_views : enumerate_machine_views(split_outputs, resource)) { - std::unordered_map fixed_machine_views1 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res1.graph)); - std::unordered_map fixed_machine_views2 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res2.graph)); - - for (auto const &[split_value, split_input] : subgraph_res2.full_graph_values_to_subgraph_inputs) { - MachineView mv = split_machine_views.at(split_value.get()); + for (std::unordered_map const + &split_machine_views : + enumerate_machine_views(split_outputs, resource)) { + std::unordered_map fixed_machine_views1 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res1.graph)); + std::unordered_map fixed_machine_views2 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res2.graph)); + + for (auto const &[split_value, split_input] : + subgraph_res2.full_graph_values_to_subgraph_inputs) { + MachineView mv = + split_machine_views.at(split_value.get()); fixed_machine_views1.emplace(split_value, mv); fixed_machine_views2.emplace(OpenDataflowValue(split_input), mv); } - minimize_runtime( - optimal_result, - OptimalCostResult::sequential_combine( - std::visit(OptimalCostFunctor(this, - resource, - fixed_machine_views1), - decompn1.raw_variant), - std::visit(OptimalCostFunctor(this, - resource, - fixed_machine_views2), - decompn2.raw_variant))); + minimize_runtime(optimal_result, + OptimalCostResult::sequential_combine( + std::visit(OptimalCostFunctor( + this, resource, fixed_machine_views1), + decompn1.raw_variant), + std::visit(OptimalCostFunctor( + this, resource, fixed_machine_views2), + decompn2.raw_variant))); } return 
optimal_result; } - OptimalCostResult optimal_cost( - ParallelSplit const &parallel, - MachineSpecification const &resource, - std::unordered_map const &fixed_machine_views) { + OptimalCostResult + optimal_cost(ParallelSplit const &parallel, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { auto [decompn1, decompn2] = decompose(parallel); GraphSplit graph_split = get_graph_split(decompn1, decompn2); - OpenDataflowSubgraphResult subgraph_res1 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); - OpenDataflowSubgraphResult subgraph_res2 = get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); + OpenDataflowSubgraphResult subgraph_res1 = + get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); + OpenDataflowSubgraphResult subgraph_res2 = + get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); - std::unordered_map fixed_machine_views1 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res1.graph)); - std::unordered_map fixed_machine_views2 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res2.graph)); + std::unordered_map fixed_machine_views1 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res1.graph)); + std::unordered_map fixed_machine_views2 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res2.graph)); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - std::visit(OptimalCostFunctor(this, - resource, - fixed_machine_views1), + std::visit(OptimalCostFunctor(this, resource, fixed_machine_views1), decompn1.raw_variant), - std::visit(OptimalCostFunctor(this, - resource, - fixed_machine_views1), + std::visit(OptimalCostFunctor(this, resource, fixed_machine_views1), decompn2.raw_variant)); for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime( optimal_result, OptimalCostResult::parallel_combine( - 
std::visit(OptimalCostFunctor(this, - resource_split.first, - fixed_machine_views1), + std::visit(OptimalCostFunctor( + this, resource_split.first, fixed_machine_views1), decompn1.raw_variant), - std::visit(OptimalCostFunctor(this, - resource_split.second, - fixed_machine_views1), + std::visit(OptimalCostFunctor( + this, resource_split.second, fixed_machine_views1), decompn2.raw_variant))); } return optimal_result; } - OptimalCostResult optimal_cost( - Node const &node, - MachineSpecification const &resource, - std::unordered_map const &fixed_machine_views) { - SubParallelComputationGraph subgraph = sub_pcg_from_partial_pcg(pcg, {node}); - - OpenDataflowValue any_output = OpenDataflowValue(get_outputs(pcg.raw_graph, node)[0]); + OptimalCostResult + optimal_cost(Node const &node, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { + SubParallelComputationGraph subgraph = + sub_pcg_from_partial_pcg(pcg, {node}); + + OpenDataflowValue any_output = + OpenDataflowValue(get_outputs(pcg.raw_graph, node)[0]); if (contains_key(fixed_machine_views, any_output)) { assert(contains(allowed_machine_views(pcg.raw_graph.at(node), resource), fixed_machine_views.at(any_output))); MachineView mv = fixed_machine_views.at(any_output); MachineMapping mv_map{{{node, mv}}}; - return {estimate_cost(subgraph, cost_estimator, fixed_machine_views), mv_map}; + return {estimate_cost(subgraph, cost_estimator, fixed_machine_views), + mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); - for (std::unordered_map node_machine_views : enumerate_machine_views({node}, resource)) { + for (std::unordered_map node_machine_views : + enumerate_machine_views({node}, resource)) { MachineMapping mv_map{{{node, node_machine_views.at(node)}}}; - std::unordered_map machine_views = fixed_machine_views; + std::unordered_map machine_views = + fixed_machine_views; for (DataflowOutput o : get_outputs(pcg.raw_graph, node)) { 
machine_views.emplace(o, node_machine_views.at(node)); } minimize_runtime( optimal_result, - {estimate_cost(subgraph, cost_estimator, machine_views), - mv_map}); + {estimate_cost(subgraph, cost_estimator, machine_views), mv_map}); } return optimal_result; } } - - std::vector> enumerate_machine_views(std::unordered_set const &nodes, MachineSpecification const &resource) { + + std::vector> + enumerate_machine_views(std::unordered_set const &nodes, + MachineSpecification const &resource) { if (nodes.empty()) { return {{}}; } Node node = get_first(nodes); - std::vector> partial_enumeration = enumerate_machine_views(set_minus(nodes, {node}), resource); - std::unordered_set allowed_machine_views_for_node = this->allowed_machine_views(pcg.raw_graph.at(node), resource); + std::vector> partial_enumeration = + enumerate_machine_views(set_minus(nodes, {node}), resource); + std::unordered_set allowed_machine_views_for_node = + this->allowed_machine_views(pcg.raw_graph.at(node), resource); std::vector> enumeration; for (MachineView const &mv : allowed_machine_views_for_node) { - for (std::unordered_map const &partial : partial_enumeration) { - enumeration.push_back(merge_maps(partial, std::unordered_map{{node, mv}})); + for (std::unordered_map const &partial : + partial_enumeration) { + enumeration.push_back(merge_maps( + partial, std::unordered_map{{node, mv}})); } } return enumeration; } - std::vector> enumerate_machine_views(std::unordered_set const &values, MachineSpecification const &resource) { + std::vector> + enumerate_machine_views(std::unordered_set const &values, + MachineSpecification const &resource) { std::unordered_set nodes; for (DataflowOutput const &v : values) { nodes.insert(v.node); } - std::vector> node_enumeration = enumerate_machine_views(nodes, resource); + std::vector> node_enumeration = + enumerate_machine_views(nodes, resource); std::vector> enumeration; - for (std::unordered_map _node_enumeration : node_enumeration) { + for (std::unordered_map 
_node_enumeration : + node_enumeration) { std::unordered_map _emumeration; for (DataflowOutput const &v : values) { _emumeration.emplace(v, _node_enumeration.at(v.node)); @@ -362,8 +385,8 @@ OptimalCostResult optimal_cost( CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - MachineMappingSearcher searcher(g, - cost_estimator, allowed_machine_views, cached_subgraph_costs); + MachineMappingSearcher searcher( + g, cost_estimator, allowed_machine_views, cached_subgraph_costs); return searcher.optimal_cost(resources); } diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index f39b903c1f..d85ad08d24 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,10 +1,10 @@ #include "compiler/unity_algorithm.h" +#include "compiler/graph_optimize_state.h" #include "compiler/machine_mapping.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" -#include "compiler/graph_optimize_state.h" namespace FlexFlow { /* @@ -32,44 +32,50 @@ GraphOptimizeResult graph_optimize( ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views, OptimizerConfig const &opt_config) { - std::vector substitutions = get_all_applicable_substitutions(pcg); + std::vector substitutions = + get_all_applicable_substitutions(pcg); OptimalCostCache cached_subgraph_costs; DeduplicatedPriorityQueue candidates; OptimalCostResult original_pcg_cost = optimal_cost(pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs); + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); - GraphOptimizeState initial_state = {GraphOptimizeResult(pcg,original_pcg_cost.machine_mapping), original_pcg_cost.runtime}; + GraphOptimizeState initial_state = { + GraphOptimizeResult(pcg, 
original_pcg_cost.machine_mapping), + original_pcg_cost.runtime}; GraphOptimizeState best_state = initial_state; candidates.push(initial_state); - for (int iteration = 0; !candidates.empty() && iteration < opt_config.budget; ++iteration) { + for (int iteration = 0; !candidates.empty() && iteration < opt_config.budget; + ++iteration) { GraphOptimizeState current_state = candidates.top(); candidates.pop(); if (current_state.runtime < best_state.runtime) { best_state = current_state; - } else if (current_state.runtime > - best_state.runtime * opt_config.alpha) { + } else if (current_state.runtime > best_state.runtime * opt_config.alpha) { continue; } for (Substitution const &substitution : substitutions) { - for (ParallelComputationGraph const &new_pcg : apply_substitution(current_state.graph_optimize_result.pcg, substitution)) { + for (ParallelComputationGraph const &new_pcg : apply_substitution( + current_state.graph_optimize_result.pcg, substitution)) { OptimalCostResult new_pcg_cost = optimal_cost(new_pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs); - GraphOptimizeState new_state{GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), new_pcg_cost.runtime}; + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + GraphOptimizeState new_state{ + GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), + new_pcg_cost.runtime}; if (new_pcg_cost.runtime <= opt_config.threshold && get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - candidates.push(new_state); + candidates.push(new_state); } } } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 7f4a26766d..3c2cdee068 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -1,7 +1,7 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" -#include "test_cost_estimator.h" #include 
"pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "test_cost_estimator.h" using namespace FlexFlow; @@ -44,13 +44,12 @@ TEST_SUITE(FF_TEST_SUITE) { CostEstimator estimator = CostEstimator::create(); MachineSpecification machine_spec{1, 1, 1, 1, 1}; OptimalCostCache cached_results; - OptimalCostResult result = optimal_cost( - pcg, - test_allowed_machine_views, - estimator, - machine_spec, - cached_results); + OptimalCostResult result = optimal_cost(pcg, + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); CHECK(bool(result.runtime > 0)); } -} \ No newline at end of file +} diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index c91055c530..ba09d3a166 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -22,7 +22,8 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs(SubParallelComputationGraph const &); SubParallelComputationGraph - sub_pcg_from_partial_pcg(ParallelComputationGraph const &, std::unordered_set const &); + sub_pcg_from_partial_pcg(ParallelComputationGraph const &, + std::unordered_set const &); parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index ca273b9ad0..923638be59 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -3,10 +3,10 @@ #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" #include 
"utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms.h" -#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" namespace FlexFlow { @@ -55,24 +55,23 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( } SubParallelComputationGraph - sub_pcg_from_partial_pcg(ParallelComputationGraph const &pcg, std::unordered_set const &nodes) { + sub_pcg_from_partial_pcg(ParallelComputationGraph const &pcg, + std::unordered_set const &nodes) { auto as_open = view_as_labelled_open_dataflow_graph(pcg.raw_graph); OpenDataflowSubgraphResult subgraph_result = get_subgraph(as_open, nodes); - return SubParallelComputationGraph{ - with_labelling( + return SubParallelComputationGraph{with_labelling( subgraph_result.graph, generate_map(nodes, [&](Node const &node) { return as_open.at(node); }), - generate_map( - get_open_dataflow_values(subgraph_result.graph), - [&](OpenDataflowValue const &value) { - if (value.has()) { - return as_open.at(subgraph_result.full_graph_values_to_subgraph_inputs.at_r(value.get())); - } else { - return as_open.at(value); - } - }) - ) - }; + generate_map(get_open_dataflow_values(subgraph_result.graph), + [&](OpenDataflowValue const &value) { + if (value.has()) { + return as_open.at( + subgraph_result.full_graph_values_to_subgraph_inputs + .at_r(value.get())); + } else { + return as_open.at(value); + } + }))}; } parallel_layer_guid_t From 678e9909b089c30e40d096978b3cf7fbf96b755f Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 21 Aug 2024 19:10:36 -0400 Subject: [PATCH 04/29] fix --- .../graph_optimize_result.struct.toml | 3 +- .../include/compiler/graph_optimize_state.h | 2 + 
.../compiler/optimal_cost_state.struct.toml | 5 +- .../compiler/optimizer_config.struct.toml | 26 +++++++ .../include/compiler/unity_algorithm.h | 8 +- lib/compiler/src/graph_optimize_state.cc | 34 +++++++- lib/compiler/src/machine_mapping.cc | 78 ++++++++----------- .../test/src/test_graph_optimize_state.cc | 72 +++++++++++++++++ lib/compiler/test/src/test_optimal_cost.cc | 2 +- 9 files changed, 172 insertions(+), 58 deletions(-) create mode 100644 lib/compiler/include/compiler/optimizer_config.struct.toml create mode 100644 lib/compiler/test/src/test_graph_optimize_state.cc diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml index c2a9a4ab39..a028981c4c 100644 --- a/lib/compiler/include/compiler/graph_optimize_result.struct.toml +++ b/lib/compiler/include/compiler/graph_optimize_result.struct.toml @@ -3,7 +3,8 @@ name = "GraphOptimizeResult" features = [ ] includes = [ - "compiler/machine_mapping.h" + "compiler/machine_mapping.dtg.h", + "pcg/parallel_computation_graph/parallel_computation_graph.h" ] [[fields]] diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h index 7d3608808f..7fdad576ac 100644 --- a/lib/compiler/include/compiler/graph_optimize_state.h +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -7,6 +7,8 @@ namespace FlexFlow { struct GraphOptimizeState { + GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, float runtime); + GraphOptimizeResult graph_optimize_result; float runtime; diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml index 9c8fabbf47..7d06aa1d87 100644 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ b/lib/compiler/include/compiler/optimal_cost_state.struct.toml @@ -13,11 +13,10 @@ includes = [ 
"utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", "pcg/machine_specification.dtg.h", "pcg/machine_view.dtg.h", - "utils/graph/node/node.dtg.h", "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", ] [[fields]] diff --git a/lib/compiler/include/compiler/optimizer_config.struct.toml b/lib/compiler/include/compiler/optimizer_config.struct.toml new file mode 100644 index 0000000000..b7f4f71e9c --- /dev/null +++ b/lib/compiler/include/compiler/optimizer_config.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OptimizerConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "alpha" +type = "float" + +[[fields]] +name = "budget" +type = "int" + +[[fields]] +name = "threshold" +type = "float" + +[[fields]] +name = "max_num_ops" +type = "int" diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 9ef85fe639..b30f01d8f0 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -4,19 +4,13 @@ #include "compiler/graph_optimize_result.dtg.h" #include "compiler/machine_mapping.h" #include "cost_estimator.h" +#include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { -struct OptimizerConfig { - float alpha; - int budget; - float threshold; - int max_num_ops; -}; - GraphOptimizeResult graph_optimize( ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc index a4176e3f3d..cdba17b2c9 100644 --- a/lib/compiler/src/graph_optimize_state.cc +++ 
b/lib/compiler/src/graph_optimize_state.cc @@ -2,13 +2,22 @@ namespace FlexFlow { +GraphOptimizeState::GraphOptimizeState( + GraphOptimizeResult const &graph_optimize_result, float runtime) + : graph_optimize_result(graph_optimize_result), runtime(runtime) {} + bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { auto layers1 = topological_ordering(graph_optimize_result.pcg); auto layers2 = topological_ordering(other.graph_optimize_result.pcg); if (layers1.size() != layers2.size()) { return false; } + std::unordered_map mapping; for (size_t i = 0; i < layers1.size(); ++i) { + if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != + get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { + return false; + } auto inputs1 = get_layer_inputs(graph_optimize_result.pcg, layers1[i]); auto inputs2 = get_layer_inputs(other.graph_optimize_result.pcg, layers2[i]); @@ -16,10 +25,19 @@ bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { return false; } for (size_t j = 0; j < inputs1.size(); ++j) { - if (inputs1[j] != inputs2[j]) { + if (inputs1[j] != mapping.at(inputs2[j])) { return false; } } + auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); + auto outputs2 = + get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); + if (outputs1.size() != outputs2.size()) { + return false; + } + for (size_t j = 0; j < outputs1.size(); ++j) { + mapping.emplace(outputs2[j], outputs1[j]); + } } return true; } @@ -28,6 +46,10 @@ bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { return !(*this == other); } +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime < other.runtime; +} + } // namespace FlexFlow namespace std { @@ -38,10 +60,18 @@ size_t hash<::FlexFlow::GraphOptimizeState>::operator()( auto layers = topological_ordering(state.graph_optimize_result.pcg); ::FlexFlow::hash_combine(seed, layers.size()); for 
(auto layer : layers) { + ::FlexFlow::hash_combine( + seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); auto inputs = get_layer_inputs(state.graph_optimize_result.pcg, layer); ::FlexFlow::hash_combine(seed, inputs.size()); for (auto input : inputs) { - ::FlexFlow::hash_combine(seed, input); + for (size_t i = 0; i < layers.size(); ++i) { + if (get_source_layer(state.graph_optimize_result.pcg, input) == + layers[i]) { + ::FlexFlow::hash_combine(seed, i); + break; + } + } } } return seed; diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index dd12bb23c4..4f53037ab2 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -123,24 +123,13 @@ GraphSplit get_nodes(post_decomposition)}; } -float estimate_cost( +float base_case_estimate_cost( SubParallelComputationGraph const &g, CostEstimator const &estimator, std::unordered_map const &machine_views) { - // TODO: Consider parallelism - float cost = 1.; - // for (Node const &node : get_nodes(g.raw_graph)) { - // std::vector incoming_edges = - // get_incoming_edges(g.raw_graph, node); - // std::vector inputs = - // transform(incoming_edges, - // [&](OpenDataflowEdge const &input_edge) { - // return g.raw_graph.at(input_edge).get_shape(); - // }); - // cost += estimator.estimate_cost( - // g.raw_graph.at(node).op_attrs, inputs, - // device_mapping.machine_views.at(node)); - // } + // In the base case, all the operators are executed sequentially. 
+ float cost = 0.1; + // TODO(@wmdi) return cost; } @@ -218,13 +207,12 @@ struct MachineMappingSearcher { std::unordered_set split_outputs; for (auto const &[value, _] : subgraph_res2.full_graph_values_to_subgraph_inputs) { - assert(value.has()); split_outputs.insert(value.get()); } for (std::unordered_map const &split_machine_views : - enumerate_machine_views(split_outputs, resource)) { + allowed_machine_mappings(split_outputs, resource)) { std::unordered_map fixed_machine_views1 = restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res1.graph)); @@ -240,14 +228,11 @@ struct MachineMappingSearcher { fixed_machine_views2.emplace(OpenDataflowValue(split_input), mv); } - minimize_runtime(optimal_result, - OptimalCostResult::sequential_combine( - std::visit(OptimalCostFunctor( - this, resource, fixed_machine_views1), - decompn1.raw_variant), - std::visit(OptimalCostFunctor( - this, resource, fixed_machine_views2), - decompn2.raw_variant))); + minimize_runtime( + optimal_result, + OptimalCostResult::sequential_combine( + optimal_cost(decompn1, resource, fixed_machine_views1), + optimal_cost(decompn2, resource, fixed_machine_views2))); } return optimal_result; @@ -275,26 +260,31 @@ struct MachineMappingSearcher { get_open_dataflow_values(subgraph_res2.graph)); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - std::visit(OptimalCostFunctor(this, resource, fixed_machine_views1), - decompn1.raw_variant), - std::visit(OptimalCostFunctor(this, resource, fixed_machine_views1), - decompn2.raw_variant)); + optimal_cost(decompn1, resource, fixed_machine_views1), + optimal_cost(decompn2, resource, fixed_machine_views2)); for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime( optimal_result, OptimalCostResult::parallel_combine( - std::visit(OptimalCostFunctor( - this, resource_split.first, fixed_machine_views1), - decompn1.raw_variant), - std::visit(OptimalCostFunctor( - this, resource_split.second, 
fixed_machine_views1), - decompn2.raw_variant))); + optimal_cost( + decompn1, resource_split.first, fixed_machine_views1), + optimal_cost( + decompn2, resource_split.second, fixed_machine_views2))); } return optimal_result; } + OptimalCostResult + optimal_cost(SerialParallelDecomposition const &decompn, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { + return std::visit(OptimalCostFunctor(this, resource, fixed_machine_views), + decompn.raw_variant); + } + OptimalCostResult optimal_cost(Node const &node, MachineSpecification const &resource, @@ -310,12 +300,12 @@ struct MachineMappingSearcher { fixed_machine_views.at(any_output))); MachineView mv = fixed_machine_views.at(any_output); MachineMapping mv_map{{{node, mv}}}; - return {estimate_cost(subgraph, cost_estimator, fixed_machine_views), + return {base_case_estimate_cost(subgraph, cost_estimator, fixed_machine_views), mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (std::unordered_map node_machine_views : - enumerate_machine_views({node}, resource)) { + allowed_machine_mappings({node}, resource)) { MachineMapping mv_map{{{node, node_machine_views.at(node)}}}; std::unordered_map machine_views = fixed_machine_views; @@ -324,21 +314,21 @@ struct MachineMappingSearcher { } minimize_runtime( optimal_result, - {estimate_cost(subgraph, cost_estimator, machine_views), mv_map}); + {base_case_estimate_cost(subgraph, cost_estimator, machine_views), mv_map}); } return optimal_result; } } std::vector> - enumerate_machine_views(std::unordered_set const &nodes, - MachineSpecification const &resource) { + allowed_machine_mappings(std::unordered_set const &nodes, + MachineSpecification const &resource) { if (nodes.empty()) { return {{}}; } Node node = get_first(nodes); std::vector> partial_enumeration = - enumerate_machine_views(set_minus(nodes, {node}), resource); + allowed_machine_mappings(set_minus(nodes, {node}), resource); std::unordered_set 
allowed_machine_views_for_node = this->allowed_machine_views(pcg.raw_graph.at(node), resource); std::vector> enumeration; @@ -353,15 +343,15 @@ struct MachineMappingSearcher { } std::vector> - enumerate_machine_views(std::unordered_set const &values, - MachineSpecification const &resource) { + allowed_machine_mappings(std::unordered_set const &values, + MachineSpecification const &resource) { std::unordered_set nodes; for (DataflowOutput const &v : values) { nodes.insert(v.node); } std::vector> node_enumeration = - enumerate_machine_views(nodes, resource); + allowed_machine_mappings(nodes, resource); std::vector> enumeration; for (std::unordered_map _node_enumeration : diff --git a/lib/compiler/test/src/test_graph_optimize_state.cc b/lib/compiler/test/src/test_graph_optimize_state.cc new file mode 100644 index 0000000000..6625f294bc --- /dev/null +++ b/lib/compiler/test/src/test_graph_optimize_state.cc @@ -0,0 +1,72 @@ +#include "compiler/graph_optimize_state.h" +#include "doctest/doctest.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("graph_optimize_state:equality") { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape, true, "input0"); + parallel_tensor_guid_t dense0 = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + parallel_tensor_guid_t dense1 = builder.dense(dense0, + 4, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense1"); + + ParallelComputationGraph pcg = builder.pcg; + + // `machine_mapping` is determined by the PCG and the 
device mapping + // algorithm, and `runtime` is determined by the PCG and the device mapping, + // so their values here do not matter. + std::unordered_map empty_machine_views; + MachineMapping empty_machine_mapping(empty_machine_views); + CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0) == + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0)); + + ParallelComputationGraphBuilder builder_; + + parallel_tensor_guid_t input0_ = + builder.create_input_tensor(input_shape, true, "input0"); + parallel_tensor_guid_t dense0_ = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + ParallelComputationGraph pcg_ = builder.pcg; + + CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0) != + GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), 0)); + } +} diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 3c2cdee068..fe9b21df4a 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -6,7 +6,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("optimal_cost_0") { + TEST_CASE("optimal_cost does not crash on minimal inputs") { ParallelComputationGraphBuilder builder; ParallelTensorShape input_shape = From 8e27f2a106018eb1ca0e9df042acb45a98be6f1e Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Tue, 27 Aug 2024 17:29:50 -0400 Subject: [PATCH 05/29] refactor machine mapping --- .github/workflows/per-lib-check.yml | 6 +- .../graph_optimize_result.struct.toml | 2 +- .../include/compiler/graph_optimize_state.h | 4 +- .../include/compiler/machine_mapping.h | 54 --- .../get_optimal_machine_mapping.h | 56 +++ .../machine_mapping/machine_mapping.h | 14 + .../machine_mapping.struct.toml | 0 .../machine_mapping/machine_mapping_cache.h | 23 ++ .../machine_mapping_context.struct.toml | 29 ++ 
.../machine_mapping/machine_mapping_result.h | 18 + .../machine_mapping_result.struct.toml | 18 + .../machine_mapping_state.struct.toml} | 10 +- .../include/compiler/unity_algorithm.h | 1 - .../get_optimal_machine_mapping.cc | 336 +++++++++++++++ .../machine_mapping/machine_mapping.cc | 16 + .../machine_mapping/machine_mapping_cache.cc | 21 + .../machine_mapping/machine_mapping_result.cc | 31 ++ lib/compiler/src/graph_optimize_state.cc | 5 + lib/compiler/src/machine_mapping.cc | 383 ------------------ lib/compiler/src/unity_algorithm.cc | 26 +- lib/compiler/test/src/test_generator.h | 1 - .../test/src/test_graph_optimize_state.cc | 14 +- lib/compiler/test/src/test_optimal_cost.cc | 15 +- .../sub_parallel_computation_graph.h | 5 +- .../sub_parallel_computation_graph.cc | 4 +- lib/utils/include/utils/containers.h | 1 + 26 files changed, 614 insertions(+), 479 deletions(-) delete mode 100644 lib/compiler/include/compiler/machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping.h rename lib/compiler/include/compiler/{ => machine_mapping}/machine_mapping.struct.toml (100%) create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml rename lib/compiler/include/compiler/{optimal_cost_state.struct.toml => machine_mapping/machine_mapping_state.struct.toml} (86%) create mode 100644 lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc create 
mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc delete mode 100644 lib/compiler/src/machine_mapping.cc diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 38556a3c0e..f7b73198d5 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -108,9 +108,9 @@ jobs: run: | test_libs.sh substitutions - # - name: Test compiler - # run: | - # test_libs.sh compiler + - name: Test compiler + run: | + test_libs.sh compiler - name: Test substitution-generator run: | diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml index a028981c4c..22f29cbd59 100644 --- a/lib/compiler/include/compiler/graph_optimize_result.struct.toml +++ b/lib/compiler/include/compiler/graph_optimize_result.struct.toml @@ -3,7 +3,7 @@ name = "GraphOptimizeResult" features = [ ] includes = [ - "compiler/machine_mapping.dtg.h", + "compiler/machine_mapping/machine_mapping.dtg.h", "pcg/parallel_computation_graph/parallel_computation_graph.h" ] diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h index 7fdad576ac..2de2321ba6 100644 --- a/lib/compiler/include/compiler/graph_optimize_state.h +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -2,12 +2,12 @@ #define _FLEXFLOW_COMPILER_MCMC_STATE_H #include "compiler/graph_optimize_result.dtg.h" -#include "compiler/machine_mapping.h" namespace FlexFlow { struct GraphOptimizeState { - GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, float runtime); + GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, + float runtime); GraphOptimizeResult graph_optimize_result; float runtime; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h deleted file mode 100644 index eab86ed7a9..0000000000 --- 
a/lib/compiler/include/compiler/machine_mapping.h +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H - -#include "compiler/machine_mapping.dtg.h" -#include "compiler/optimal_cost_state.dtg.h" -#include "cost_estimator.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &, MachineMapping const &); - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - -struct OptimalCostResult { - static OptimalCostResult sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult infinity(); - - float runtime; - req machine_mapping; -}; -FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); - -class OptimalCostCache { -public: - OptimalCostCache() = default; - - std::optional load(OptimalCostState const &) const; - void save(OptimalCostState const &, OptimalCostResult const &); - -private: - std::unordered_map cache; -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h new file mode 100644 index 0000000000..68199e2524 --- /dev/null 
+++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -0,0 +1,56 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H + +#include "machine_mapping.h" +#include "machine_mapping_cache.h" +#include "machine_mapping_context.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" + +namespace FlexFlow { + +MachineMappingResult get_optimal_machine_mapping( + ParallelComputationGraph const &pcg, + std::function( + ParallelLayerAttrs const &, MachineSpecification const &)> const + &allowed_machine_views, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + MachineMappingCache &cached_subgraph_results); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingContext &context, + MachineSpecification const &resources); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + SerialParallelDecomposition const &decompn, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + SerialSplit const &serial, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + ParallelSplit const ¶llel, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + Node const &node, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h new file mode 100644 index 0000000000..441f9a6034 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H + +#include "machine_mapping.dtg.h" + +namespace FlexFlow { + +MachineMapping combine(MachineMapping const &, MachineMapping const &); + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml similarity index 100% rename from lib/compiler/include/compiler/machine_mapping.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h new file mode 100644 index 0000000000..e2824777ff --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H + +#include "machine_mapping_result.dtg.h" +#include "machine_mapping_state.dtg.h" +#include "utils/optional.h" + +namespace FlexFlow { + +class MachineMappingCache { +public: + MachineMappingCache() = default; + + std::optional load(MachineMappingState const &) const; + void save(MachineMappingState const &, MachineMappingResult const &); + +private: + std::unordered_map cache; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml new file mode 100644 index 0000000000..eea1da6ca1 --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "MachineMappingContext" +features = [ +] + +includes = [ + "machine_mapping.dtg.h", + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "compiler/cost_estimator.h", + "pcg/machine_view.h", + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_cache.h" +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "cost_estimator" +type = "::FlexFlow::CostEstimator" + +[[fields]] +name = "allowed_machine_views" +type = "std::function(::FlexFlow::ParallelLayerAttrs const &, ::FlexFlow::MachineSpecification const &)>" + +[[fields]] +name = "cached_subgraph_results" +type = "::FlexFlow::MachineMappingCache" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h new file mode 100644 index 0000000000..912b604e4c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H + +#include "machine_mapping_result.dtg.h" + +namespace FlexFlow { + +MachineMappingResult sequential_combine(MachineMappingResult const &s1, + MachineMappingResult const &s2); +MachineMappingResult parallel_combine(MachineMappingResult const &s1, + MachineMappingResult const &s2); +MachineMappingResult get_infinity_machine_mapping_result(); + +void minimize_runtime(MachineMappingResult &m1, MachineMappingResult const &m2); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml new file mode 100644 index 0000000000..e8f3a8562e --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingResult" +features = [ + "eq", + "fmt", +] + +includes = [ + "machine_mapping.dtg.h", +] + +[[fields]] +name = "runtime" +type = "float" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml similarity index 86% rename from lib/compiler/include/compiler/optimal_cost_state.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml index 7d06aa1d87..9a1661bcd6 100644 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -1,11 +1,8 @@ namespace = "FlexFlow" -name = "OptimalCostState" +name = "MachineMappingState" features = [ "eq", - # "ord", "hash", - # "json", - # "rapidcheck", "fmt", ] @@ -15,8 +12,11 @@ includes = [ "pcg/machine_view.dtg.h", "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", +] + +src_includes = [ "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/fmt/unordered_map.h", ] [[fields]] diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index b30f01d8f0..9eeb9fe563 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #include "compiler/graph_optimize_result.dtg.h" -#include "compiler/machine_mapping.h" #include "cost_estimator.h" #include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc 
b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..5a1415c008 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,336 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/cost_estimator.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "utils/containers.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_first.h" +#include "utils/containers/get_only.h" +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/values.h" +#include "utils/exception.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/graph_split.dtg.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::vector> + allowed_machine_mappings(MachineMappingContext const &context, + std::unordered_set const &nodes, + MachineSpecification const &resource) { + if (nodes.empty()) { + return {{}}; + } + Node node = get_first(nodes); + std::vector> partial_enumeration = + 
allowed_machine_mappings(context, set_minus(nodes, {node}), resource); + std::unordered_set allowed_machine_views_for_node = + context.allowed_machine_views(context.pcg.raw_graph.at(node), resource); + std::vector> enumeration; + for (MachineView const &mv : allowed_machine_views_for_node) { + for (std::unordered_map const &partial : + partial_enumeration) { + enumeration.push_back(merge_maps( + partial, std::unordered_map{{node, mv}})); + } + } + return enumeration; +} + +std::vector> + allowed_machine_mappings(MachineMappingContext const &context, + std::unordered_set const &values, + MachineSpecification const &resource) { + std::unordered_set nodes; + for (DataflowOutput const &v : values) { + nodes.insert(v.node); + } + + std::vector> node_enumeration = + allowed_machine_mappings(context, nodes, resource); + std::vector> enumeration; + + for (std::unordered_map _node_enumeration : + node_enumeration) { + std::unordered_map _emumeration; + for (DataflowOutput const &v : values) { + _emumeration.emplace(v, _node_enumeration.at(v.node)); + } + enumeration.push_back(_emumeration); + } + + return enumeration; +} + +std::vector> + get_resource_split(MachineSpecification const &resource) { + std::vector> result; + for (int i = 1; i < resource.num_nodes; ++i) { + MachineSpecification sub_resource1 = resource, sub_resource2 = resource; + sub_resource1.num_nodes = i; + sub_resource2.num_nodes = resource.num_nodes - i; + result.push_back(std::make_pair(sub_resource1, sub_resource2)); + } + return result; +} + +std::pair + decompose(SerialSplit const &serial) { + if (serial.children.size() == 2) { + return {widen(serial.children[0]), + widen(serial.children[1])}; + } + SerialSplit decompn1 = serial; + decompn1.children.pop_back(); + return {SerialParallelDecomposition(decompn1), + widen(serial.children.back())}; +} + +std::pair + decompose(ParallelSplit const ¶llel) { + if (parallel.children.size() == 2) { + std::vector children = + transform(as_vector(parallel.children), 
[&](auto const &child) { + return widen(child); + }); + return {children[0], children[1]}; + } + ParallelSplit decompn1 = parallel; + std::variant child = *parallel.children.begin(); + decompn1.children.erase(child); + return {SerialParallelDecomposition(decompn1), + widen(child)}; +} + +GraphSplit + get_graph_split(SerialParallelDecomposition const &pre_decomposition, + SerialParallelDecomposition const &post_decomposition) { + return GraphSplit{get_nodes(pre_decomposition), + get_nodes(post_decomposition)}; +} + +float base_case_estimate_cost( + SubParallelComputationGraph const &g, + CostEstimator const &estimator, + std::unordered_map const &machine_views) { + // In the base case, all the operators are executed sequentially. + float cost = 0.1; + // TODO(@wmdi) + return cost; +} + +MachineMappingResult get_optimal_machine_mapping( + ParallelComputationGraph const &pcg, + std::function( + ParallelLayerAttrs const &, MachineSpecification const &)> const + &allowed_machine_views, + CostEstimator const &cost_estimator, + MachineSpecification const &resources, + MachineMappingCache &cached_subgraph_results) { + + MachineMappingContext context( + pcg, cost_estimator, allowed_machine_views, cached_subgraph_results); + MachineMappingResult result = get_optimal_machine_mapping(context, resources); + cached_subgraph_results = context.cached_subgraph_results; + return result; +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingContext &context, + MachineSpecification const &resources) { + SerialParallelDecomposition decompn = + get_serial_parallel_decomposition(context.pcg.raw_graph).value(); + return get_optimal_machine_mapping(context, decompn, resources, {}); +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + SerialParallelDecomposition const &decompn, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { + + MachineMappingState state(decompn, resource, 
fixed_machine_views); + std::optional cached_result = + context.cached_subgraph_results.load(state); + if (cached_result) { + return cached_result.value(); + } + + MachineMappingResult result = decompn.visit( + overload{[&](SerialSplit const &serial) { + return get_optimal_machine_mapping( + context, serial, resource, fixed_machine_views); + }, + [&](ParallelSplit const ¶llel) { + return get_optimal_machine_mapping( + context, parallel, resource, fixed_machine_views); + }, + [&](Node const &node) { + return get_optimal_machine_mapping( + context, node, resource, fixed_machine_views); + }}); + + context.cached_subgraph_results.save(state, result); + return result; +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + SerialSplit const &serial, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { + MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + + auto [decompn1, decompn2] = decompose(serial); + + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + + OpenDataflowSubgraphResult subgraph_res1 = get_subgraph( + sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.first); + OpenDataflowSubgraphResult subgraph_res2 = get_subgraph( + sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.second); + + std::unordered_set split_outputs = transform( + keys(subgraph_res2.full_graph_values_to_subgraph_inputs), + [](OpenDataflowValue const &v) { return v.get(); }); + + for (std::unordered_map const + &split_machine_views : + allowed_machine_mappings(context, split_outputs, resource)) { + std::unordered_map fixed_machine_views1 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res1.graph)); + std::unordered_map fixed_machine_views2 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res2.graph)); + + for (auto const &[split_value, split_input] : + subgraph_res2.full_graph_values_to_subgraph_inputs) { + 
MachineView mv = + split_machine_views.at(split_value.get()); + fixed_machine_views1.emplace(split_value, mv); + fixed_machine_views2.emplace(OpenDataflowValue(split_input), mv); + } + + minimize_runtime( + optimal_result, + sequential_combine( + get_optimal_machine_mapping( + context, decompn1, resource, fixed_machine_views1), + get_optimal_machine_mapping( + context, decompn2, resource, fixed_machine_views2))); + } + + return optimal_result; +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + ParallelSplit const ¶llel, + MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { + auto [decompn1, decompn2] = decompose(parallel); + + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + + OpenDataflowSubgraphResult subgraph_res1 = get_subgraph( + sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.first); + OpenDataflowSubgraphResult subgraph_res2 = get_subgraph( + sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.second); + + std::unordered_map fixed_machine_views1 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res1.graph)); + std::unordered_map fixed_machine_views2 = + restrict_keys(fixed_machine_views, + get_open_dataflow_values(subgraph_res2.graph)); + + MachineMappingResult optimal_result = sequential_combine( + get_optimal_machine_mapping( + context, decompn1, resource, fixed_machine_views1), + get_optimal_machine_mapping( + context, decompn2, resource, fixed_machine_views2)); + + for (auto const &resource_split : get_resource_split(resource)) { + minimize_runtime( + optimal_result, + parallel_combine( + get_optimal_machine_mapping( + context, decompn1, resource_split.first, fixed_machine_views1), + get_optimal_machine_mapping(context, + decompn2, + resource_split.second, + fixed_machine_views2))); + } + + return optimal_result; +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingContext &context, + Node const &node, 
+ MachineSpecification const &resource, + std::unordered_map const + &fixed_machine_views) { + + SubParallelComputationGraph subgraph = get_pcg_subgraph(context.pcg, {node}); + + OpenDataflowValue any_output = + OpenDataflowValue(get_outputs(context.pcg.raw_graph, node)[0]); + if (contains_key(fixed_machine_views, any_output)) { + assert(contains( + context.allowed_machine_views(context.pcg.raw_graph.at(node), resource), + fixed_machine_views.at(any_output))); + MachineView mv = fixed_machine_views.at(any_output); + MachineMapping mv_map{{{node, mv}}}; + return MachineMappingResult(base_case_estimate_cost(subgraph, + context.cost_estimator, + fixed_machine_views), + mv_map); + } else { + MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + for (std::unordered_map node_machine_views : + allowed_machine_mappings(context, {node}, resource)) { + MachineView mv = node_machine_views.at(node); + MachineMapping mv_map{{{node, mv}}}; + + std::unordered_map output_mv_map = + generate_map(transform(get_outputs(context.pcg.raw_graph, node), + [](DataflowOutput const &o) { + return OpenDataflowValue(o); + }), + [&](OpenDataflowValue const &o) { return mv; }); + + std::unordered_map machine_views = + merge_maps(fixed_machine_views, output_mv_map); + minimize_runtime(optimal_result, + MachineMappingResult( + base_case_estimate_cost( + subgraph, context.cost_estimator, machine_views), + mv_map)); + } + return optimal_result; + } +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..d739f3fecb --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,16 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "utils/containers.h" +#include "utils/containers/are_disjoint.h" +#include "utils/containers/keys.h" + +namespace FlexFlow { + +MachineMapping 
combine(MachineMapping const &s1, MachineMapping const &s2) { + return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; +} + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc new file mode 100644 index 0000000000..232598b98d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -0,0 +1,21 @@ +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "utils/containers/contains_key.h" + +namespace FlexFlow { + +std::optional + MachineMappingCache::load(MachineMappingState const &state) const { + if (contains_key(cache, state)) { + MachineMappingResult result = cache.at(state); + return std::make_optional(result); + } + return std::nullopt; +} + +void MachineMappingCache::save(MachineMappingState const &state, + MachineMappingResult const &result) { + assert(!contains_key(cache, state)); + cache.emplace(state, result); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..c33b8776df --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,31 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_mapping.h" + +namespace FlexFlow { + +MachineMappingResult sequential_combine(MachineMappingResult const &s1, + MachineMappingResult const &s2) { + return MachineMappingResult{s1.runtime + s2.runtime, + combine(s1.machine_mapping, s2.machine_mapping)}; +} + +MachineMappingResult parallel_combine(MachineMappingResult const &s1, + MachineMappingResult const &s2) { + return 
MachineMappingResult{std::max(s1.runtime, s2.runtime), + combine(s1.machine_mapping, s2.machine_mapping)}; +} + +MachineMappingResult get_infinity_machine_mapping_result() { + return MachineMappingResult( + std::numeric_limits::infinity(), + MachineMapping(std::unordered_map{})); +} + +void minimize_runtime(MachineMappingResult &m1, + MachineMappingResult const &m2) { + if (m2.runtime < m1.runtime) { + m1 = m2; + } +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc index cdba17b2c9..063d8e3ee4 100644 --- a/lib/compiler/src/graph_optimize_state.cc +++ b/lib/compiler/src/graph_optimize_state.cc @@ -7,6 +7,9 @@ GraphOptimizeState::GraphOptimizeState( : graph_optimize_result(graph_optimize_result), runtime(runtime) {} bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + // Note(@wmdi): This is a hack to implement a partially correct homomorphism + // check. Switch to the homomorphism check used in substitutions right after + // https://github.com/flexflow/FlexFlow/pull/1471 is merged. 
auto layers1 = topological_ordering(graph_optimize_result.pcg); auto layers2 = topological_ordering(other.graph_optimize_result.pcg); if (layers1.size() != layers2.size()) { @@ -56,6 +59,8 @@ namespace std { size_t hash<::FlexFlow::GraphOptimizeState>::operator()( ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph hash like + // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash size_t seed = 0; auto layers = topological_ordering(state.graph_optimize_result.pcg); ::FlexFlow::hash_combine(seed, layers.size()); diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc deleted file mode 100644 index 4f53037ab2..0000000000 --- a/lib/compiler/src/machine_mapping.cc +++ /dev/null @@ -1,383 +0,0 @@ -#include "compiler/machine_mapping.h" -#include "compiler/cost_estimator.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers.h" -#include "utils/containers/are_disjoint.h" -#include "utils/containers/as_vector.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/get_first.h" -#include "utils/containers/get_only.h" -#include "utils/containers/keys.h" -#include "utils/containers/restrict_keys.h" -#include "utils/containers/set_minus.h" -#include "utils/containers/values.h" -#include "utils/exception.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/graph_split.dtg.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include 
"utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &s1, MachineMapping const &s2) { - return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; -} - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); -} - -void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { - if (m2.runtime < m1.runtime) { - m1 = m2; - } -} - -OptimalCostResult - OptimalCostResult::sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{s1.runtime + s2.runtime, - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult - OptimalCostResult::parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{std::max(s1.runtime, s2.runtime), - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult OptimalCostResult::infinity() { - return {std::numeric_limits::infinity(), - MachineMapping{std::unordered_map{}}}; -} - -std::optional - OptimalCostCache::load(OptimalCostState const &state) const { - if (contains_key(cache, state)) { - OptimalCostResult result = cache.at(state); - return std::make_optional(result); - } - return std::nullopt; -} - -void OptimalCostCache::save(OptimalCostState const &state, - OptimalCostResult const &result) { - assert(!contains_key(cache, state)); - cache.emplace(state, result); -} - -std::vector> - get_resource_split(MachineSpecification const &resource) { - std::vector> result; - for (int i = 1; i < resource.num_nodes; ++i) { - MachineSpecification sub_resource1 = resource, sub_resource2 = resource; - 
sub_resource1.num_nodes = i; - sub_resource2.num_nodes = resource.num_nodes - i; - result.push_back(std::make_pair(sub_resource1, sub_resource2)); - } - return result; -} - -std::pair - decompose(SerialSplit const &serial) { - if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; - } - SerialSplit decompn1 = serial; - decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; -} - -std::pair - decompose(ParallelSplit const ¶llel) { - if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); - }); - return {children[0], children[1]}; - } - ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); - decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; -} - -GraphSplit - get_graph_split(SerialParallelDecomposition const &pre_decomposition, - SerialParallelDecomposition const &post_decomposition) { - return GraphSplit{get_nodes(pre_decomposition), - get_nodes(post_decomposition)}; -} - -float base_case_estimate_cost( - SubParallelComputationGraph const &g, - CostEstimator const &estimator, - std::unordered_map const &machine_views) { - // In the base case, all the operators are executed sequentially. 
- float cost = 0.1; - // TODO(@wmdi) - return cost; -} - -struct MachineMappingSearcher { - MachineMappingSearcher( - ParallelComputationGraph const &pcg, - CostEstimator const &cost_estimator, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) - : pcg(pcg), cost_estimator(cost_estimator), - allowed_machine_views(allowed_machine_views), - cached_subgraph_costs(cached_subgraph_costs) {} - - ParallelComputationGraph pcg; - CostEstimator cost_estimator; - std::function(ParallelLayerAttrs const &, - MachineSpecification const &)> - allowed_machine_views; - OptimalCostCache &cached_subgraph_costs; - - struct OptimalCostFunctor { - OptimalCostFunctor( - MachineMappingSearcher *searcher, - MachineSpecification resource, - std::unordered_map fixed_machine_views) - : searcher(searcher), resource(resource), - fixed_machine_views(fixed_machine_views) {} - - MachineMappingSearcher *searcher; - MachineSpecification resource; - std::unordered_map fixed_machine_views; - - template - OptimalCostResult operator()(T const &t) { - OptimalCostState state( - SerialParallelDecomposition{t}, resource, fixed_machine_views); - std::optional cached_result = - searcher->cached_subgraph_costs.load(state); - - if (cached_result) { - return cached_result.value(); - } - OptimalCostResult result = - searcher->optimal_cost(t, resource, fixed_machine_views); - - searcher->cached_subgraph_costs.save(state, result); - return result; - } - }; - - OptimalCostResult optimal_cost(MachineSpecification resource) { - return std::visit( - OptimalCostFunctor(this, resource, {}), - get_serial_parallel_decomposition(pcg.raw_graph).value().raw_variant); - } - - OptimalCostResult - optimal_cost(SerialSplit const &serial, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - OptimalCostResult optimal_result = OptimalCostResult::infinity(); - - auto [decompn1, decompn2] = 
decompose(serial); - - GraphSplit graph_split = get_graph_split(decompn1, decompn2); - - OpenDataflowSubgraphResult subgraph_res1 = - get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); - OpenDataflowSubgraphResult subgraph_res2 = - get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); - - std::unordered_set split_outputs; - for (auto const &[value, _] : - subgraph_res2.full_graph_values_to_subgraph_inputs) { - split_outputs.insert(value.get()); - } - - for (std::unordered_map const - &split_machine_views : - allowed_machine_mappings(split_outputs, resource)) { - std::unordered_map fixed_machine_views1 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res1.graph)); - std::unordered_map fixed_machine_views2 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res2.graph)); - - for (auto const &[split_value, split_input] : - subgraph_res2.full_graph_values_to_subgraph_inputs) { - MachineView mv = - split_machine_views.at(split_value.get()); - fixed_machine_views1.emplace(split_value, mv); - fixed_machine_views2.emplace(OpenDataflowValue(split_input), mv); - } - - minimize_runtime( - optimal_result, - OptimalCostResult::sequential_combine( - optimal_cost(decompn1, resource, fixed_machine_views1), - optimal_cost(decompn2, resource, fixed_machine_views2))); - } - - return optimal_result; - } - - OptimalCostResult - optimal_cost(ParallelSplit const ¶llel, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - auto [decompn1, decompn2] = decompose(parallel); - - GraphSplit graph_split = get_graph_split(decompn1, decompn2); - - OpenDataflowSubgraphResult subgraph_res1 = - get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.first); - OpenDataflowSubgraphResult subgraph_res2 = - get_subgraph(sub_pcg_from_full_pcg(pcg).raw_graph, graph_split.second); - - std::unordered_map fixed_machine_views1 = - restrict_keys(fixed_machine_views, - 
get_open_dataflow_values(subgraph_res1.graph)); - std::unordered_map fixed_machine_views2 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res2.graph)); - - OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - optimal_cost(decompn1, resource, fixed_machine_views1), - optimal_cost(decompn2, resource, fixed_machine_views2)); - - for (auto const &resource_split : get_resource_split(resource)) { - minimize_runtime( - optimal_result, - OptimalCostResult::parallel_combine( - optimal_cost( - decompn1, resource_split.first, fixed_machine_views1), - optimal_cost( - decompn2, resource_split.second, fixed_machine_views2))); - } - - return optimal_result; - } - - OptimalCostResult - optimal_cost(SerialParallelDecomposition const &decompn, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - return std::visit(OptimalCostFunctor(this, resource, fixed_machine_views), - decompn.raw_variant); - } - - OptimalCostResult - optimal_cost(Node const &node, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - SubParallelComputationGraph subgraph = - sub_pcg_from_partial_pcg(pcg, {node}); - - OpenDataflowValue any_output = - OpenDataflowValue(get_outputs(pcg.raw_graph, node)[0]); - if (contains_key(fixed_machine_views, any_output)) { - assert(contains(allowed_machine_views(pcg.raw_graph.at(node), resource), - fixed_machine_views.at(any_output))); - MachineView mv = fixed_machine_views.at(any_output); - MachineMapping mv_map{{{node, mv}}}; - return {base_case_estimate_cost(subgraph, cost_estimator, fixed_machine_views), - mv_map}; - } else { - OptimalCostResult optimal_result = OptimalCostResult::infinity(); - for (std::unordered_map node_machine_views : - allowed_machine_mappings({node}, resource)) { - MachineMapping mv_map{{{node, node_machine_views.at(node)}}}; - std::unordered_map machine_views = - fixed_machine_views; - for (DataflowOutput o : 
get_outputs(pcg.raw_graph, node)) { - machine_views.emplace(o, node_machine_views.at(node)); - } - minimize_runtime( - optimal_result, - {base_case_estimate_cost(subgraph, cost_estimator, machine_views), mv_map}); - } - return optimal_result; - } - } - - std::vector> - allowed_machine_mappings(std::unordered_set const &nodes, - MachineSpecification const &resource) { - if (nodes.empty()) { - return {{}}; - } - Node node = get_first(nodes); - std::vector> partial_enumeration = - allowed_machine_mappings(set_minus(nodes, {node}), resource); - std::unordered_set allowed_machine_views_for_node = - this->allowed_machine_views(pcg.raw_graph.at(node), resource); - std::vector> enumeration; - for (MachineView const &mv : allowed_machine_views_for_node) { - for (std::unordered_map const &partial : - partial_enumeration) { - enumeration.push_back(merge_maps( - partial, std::unordered_map{{node, mv}})); - } - } - return enumeration; - } - - std::vector> - allowed_machine_mappings(std::unordered_set const &values, - MachineSpecification const &resource) { - std::unordered_set nodes; - for (DataflowOutput const &v : values) { - nodes.insert(v.node); - } - - std::vector> node_enumeration = - allowed_machine_mappings(nodes, resource); - std::vector> enumeration; - - for (std::unordered_map _node_enumeration : - node_enumeration) { - std::unordered_map _emumeration; - for (DataflowOutput const &v : values) { - _emumeration.emplace(v, _node_enumeration.at(v.node)); - } - enumeration.push_back(_emumeration); - } - - return enumeration; - } -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs) { - MachineMappingSearcher searcher( - g, cost_estimator, allowed_machine_views, cached_subgraph_costs); - return 
searcher.optimal_cost(resources); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index d85ad08d24..caf072fdbc 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,6 +1,6 @@ #include "compiler/unity_algorithm.h" #include "compiler/graph_optimize_state.h" -#include "compiler/machine_mapping.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" @@ -35,14 +35,15 @@ GraphOptimizeResult graph_optimize( std::vector substitutions = get_all_applicable_substitutions(pcg); - OptimalCostCache cached_subgraph_costs; + MachineMappingCache cached_subgraph_costs; DeduplicatedPriorityQueue candidates; - OptimalCostResult original_pcg_cost = optimal_cost(pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs); + MachineMappingResult original_pcg_cost = + get_optimal_machine_mapping(pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); GraphOptimizeState initial_state = { GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), @@ -65,11 +66,12 @@ GraphOptimizeResult graph_optimize( for (Substitution const &substitution : substitutions) { for (ParallelComputationGraph const &new_pcg : apply_substitution( current_state.graph_optimize_result.pcg, substitution)) { - OptimalCostResult new_pcg_cost = optimal_cost(new_pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs); + MachineMappingResult new_pcg_cost = + get_optimal_machine_mapping(new_pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); GraphOptimizeState new_state{ GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), new_pcg_cost.runtime}; diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index 
d6b8222968..39d0cc710c 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_TEST_GENERATOR_H #define _FLEXFLOW_TEST_GENERATOR_H -#include "compiler/machine_mapping.h" #include "pcg/computation_graph.h" #include "rapidcheck.h" #include "substitutions/sub_parallel_computation_graph.h" diff --git a/lib/compiler/test/src/test_graph_optimize_state.cc b/lib/compiler/test/src/test_graph_optimize_state.cc index 6625f294bc..49c4f9958f 100644 --- a/lib/compiler/test/src/test_graph_optimize_state.cc +++ b/lib/compiler/test/src/test_graph_optimize_state.cc @@ -5,7 +5,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("graph_optimize_state:equality") { + TEST_CASE("GraphOptimizeState::operator==") { ParallelComputationGraphBuilder builder; ParallelTensorShape input_shape = @@ -48,8 +48,10 @@ TEST_SUITE(FF_TEST_SUITE) { // so their values here do not matter. std::unordered_map empty_machine_views; MachineMapping empty_machine_mapping(empty_machine_views); - CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0) == - GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0)); + CHECK( + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0)); ParallelComputationGraphBuilder builder_; @@ -66,7 +68,9 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg_ = builder.pcg; - CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0) != - GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), 0)); + CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) != + GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), + 0)); } } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index fe9b21df4a..9c82585de4 100644 --- 
a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -1,4 +1,4 @@ -#include "compiler/unity_algorithm.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "doctest/doctest.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test_cost_estimator.h" @@ -43,12 +43,13 @@ TEST_SUITE(FF_TEST_SUITE) { CostEstimator estimator = CostEstimator::create(); MachineSpecification machine_spec{1, 1, 1, 1, 1}; - OptimalCostCache cached_results; - OptimalCostResult result = optimal_cost(pcg, - test_allowed_machine_views, - estimator, - machine_spec, - cached_results); + MachineMappingCache cached_results; + MachineMappingResult result = + get_optimal_machine_mapping(pcg, + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); CHECK(bool(result.runtime > 0)); } diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index ba09d3a166..37c105eead 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -21,9 +21,8 @@ SubParallelComputationGraph ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs(SubParallelComputationGraph const &); -SubParallelComputationGraph - sub_pcg_from_partial_pcg(ParallelComputationGraph const &, - std::unordered_set const &); +SubParallelComputationGraph get_pcg_subgraph(ParallelComputationGraph const &, + std::unordered_set const &); parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 923638be59..72f710e93b 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ 
b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -55,8 +55,8 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( } SubParallelComputationGraph - sub_pcg_from_partial_pcg(ParallelComputationGraph const &pcg, - std::unordered_set const &nodes) { + get_pcg_subgraph(ParallelComputationGraph const &pcg, + std::unordered_set const &nodes) { auto as_open = view_as_labelled_open_dataflow_graph(pcg.raw_graph); OpenDataflowSubgraphResult subgraph_result = get_subgraph(as_open, nodes); return SubParallelComputationGraph{with_labelling( diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 6164699f2e..cddc8c7ef0 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -5,6 +5,7 @@ #include "required_core.h" #include "type_traits_core.h" #include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" #include "utils/containers/contains.h" #include "utils/containers/extend.h" #include "utils/containers/extend_vector.h" From 9b9f529e0ee29b7c5975e23d19c96279087d870a Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 28 Aug 2024 20:54:56 -0400 Subject: [PATCH 06/29] add unit tests --- .../get_optimal_machine_mapping.h | 10 +- .../machine_mapping/split_sp_decomposition.h | 16 ++ .../get_optimal_machine_mapping.cc | 76 +++---- .../machine_mapping/split_sp_decomposition.cc | 36 ++++ .../cost_estimator_for_test.h} | 10 +- .../test_get_optimal_machine_mapping.cc | 189 ++++++++++++++++++ .../machine_mapping/test_machine_mapping.cc | 0 .../test_machine_mapping_cache.cc | 81 ++++++++ .../test_machine_mapping_result.cc | 0 lib/compiler/test/src/test_generator.h | 173 ---------------- .../test/src/test_labelled_open_graph.cc | 132 ------------ lib/compiler/test/src/test_machine_mapping.cc | 23 --- lib/compiler/test/src/test_open_graph.cc | 81 -------- lib/compiler/test/src/test_optimal_cost.cc | 56 ------ lib/compiler/test/src/test_unity_algorithm.cc | 2 - 15 
files changed, 360 insertions(+), 525 deletions(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h create mode 100644 lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc rename lib/compiler/test/src/{test_cost_estimator.h => compiler/machine_mapping/cost_estimator_for_test.h} (73%) create mode 100644 lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc delete mode 100644 lib/compiler/test/src/test_generator.h delete mode 100644 lib/compiler/test/src/test_labelled_open_graph.cc delete mode 100644 lib/compiler/test/src/test_machine_mapping.cc delete mode 100644 lib/compiler/test/src/test_open_graph.cc delete mode 100644 lib/compiler/test/src/test_optimal_cost.cc diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 68199e2524..466b96dcab 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -20,31 +20,31 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &cached_subgraph_results); MachineMappingResult - get_optimal_machine_mapping(MachineMappingContext &context, + get_optimal_machine_mapping_internal(MachineMappingContext &context, MachineSpecification const &resources); -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialParallelDecomposition const &decompn, MachineSpecification const &resource, std::unordered_map const &fixed_machine_views); 
-MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialSplit const &serial, MachineSpecification const &resource, std::unordered_map const &fixed_machine_views); -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, ParallelSplit const ¶llel, MachineSpecification const &resource, std::unordered_map const &fixed_machine_views); -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, Node const &node, MachineSpecification const &resource, diff --git a/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h b/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h new file mode 100644 index 0000000000..acafcfcb0e --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_SPLIT_SP_DECOMPOSITION_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_SPLIT_SP_DECOMPOSITION_H + +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::pair + split_sp_decomposition(SerialSplit const &serial); + +std::pair + split_sp_decomposition(ParallelSplit const ¶llel); + +} + +#endif \ No newline at end of file diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 5a1415c008..dcb51c97a9 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "compiler/cost_estimator.h" #include "compiler/machine_mapping/machine_mapping_result.h" +#include 
"compiler/machine_mapping/split_sp_decomposition.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.dtg.h" @@ -91,34 +92,6 @@ std::vector> return result; } -std::pair - decompose(SerialSplit const &serial) { - if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; - } - SerialSplit decompn1 = serial; - decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; -} - -std::pair - decompose(ParallelSplit const ¶llel) { - if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); - }); - return {children[0], children[1]}; - } - ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); - decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; -} - GraphSplit get_graph_split(SerialParallelDecomposition const &pre_decomposition, SerialParallelDecomposition const &post_decomposition) { @@ -147,20 +120,27 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingContext context( pcg, cost_estimator, allowed_machine_views, cached_subgraph_results); - MachineMappingResult result = get_optimal_machine_mapping(context, resources); + MachineMappingResult result = get_optimal_machine_mapping_internal(context, resources); cached_subgraph_results = context.cached_subgraph_results; return result; } MachineMappingResult - get_optimal_machine_mapping(MachineMappingContext &context, + get_optimal_machine_mapping_internal(MachineMappingContext &context, MachineSpecification const &resources) { - SerialParallelDecomposition decompn = - get_serial_parallel_decomposition(context.pcg.raw_graph).value(); - return get_optimal_machine_mapping(context, decompn, resources, {}); + std::optional decompn_optional = + 
get_serial_parallel_decomposition(context.pcg.raw_graph); + + if (!decompn_optional) { + throw mk_runtime_error("Failed to get serial parallel decomposition"); + } + + SerialParallelDecomposition decompn = decompn_optional.value(); + + return get_optimal_machine_mapping_internal(context, decompn, resources, {}); } -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialParallelDecomposition const &decompn, MachineSpecification const &resource, @@ -176,15 +156,15 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingResult result = decompn.visit( overload{[&](SerialSplit const &serial) { - return get_optimal_machine_mapping( + return get_optimal_machine_mapping_internal( context, serial, resource, fixed_machine_views); }, [&](ParallelSplit const ¶llel) { - return get_optimal_machine_mapping( + return get_optimal_machine_mapping_internal( context, parallel, resource, fixed_machine_views); }, [&](Node const &node) { - return get_optimal_machine_mapping( + return get_optimal_machine_mapping_internal( context, node, resource, fixed_machine_views); }}); @@ -192,7 +172,7 @@ MachineMappingResult get_optimal_machine_mapping( return result; } -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialSplit const &serial, MachineSpecification const &resource, @@ -200,7 +180,7 @@ MachineMappingResult get_optimal_machine_mapping( &fixed_machine_views) { MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); - auto [decompn1, decompn2] = decompose(serial); + auto [decompn1, decompn2] = split_sp_decomposition(serial); GraphSplit graph_split = get_graph_split(decompn1, decompn2); @@ -234,22 +214,22 @@ MachineMappingResult get_optimal_machine_mapping( minimize_runtime( optimal_result, sequential_combine( - get_optimal_machine_mapping( + 
get_optimal_machine_mapping_internal( context, decompn1, resource, fixed_machine_views1), - get_optimal_machine_mapping( + get_optimal_machine_mapping_internal( context, decompn2, resource, fixed_machine_views2))); } return optimal_result; } -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, ParallelSplit const ¶llel, MachineSpecification const &resource, std::unordered_map const &fixed_machine_views) { - auto [decompn1, decompn2] = decompose(parallel); + auto [decompn1, decompn2] = split_sp_decomposition(parallel); GraphSplit graph_split = get_graph_split(decompn1, decompn2); @@ -266,18 +246,18 @@ MachineMappingResult get_optimal_machine_mapping( get_open_dataflow_values(subgraph_res2.graph)); MachineMappingResult optimal_result = sequential_combine( - get_optimal_machine_mapping( + get_optimal_machine_mapping_internal( context, decompn1, resource, fixed_machine_views1), - get_optimal_machine_mapping( + get_optimal_machine_mapping_internal( context, decompn2, resource, fixed_machine_views2)); for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime( optimal_result, parallel_combine( - get_optimal_machine_mapping( + get_optimal_machine_mapping_internal( context, decompn1, resource_split.first, fixed_machine_views1), - get_optimal_machine_mapping(context, + get_optimal_machine_mapping_internal(context, decompn2, resource_split.second, fixed_machine_views2))); @@ -286,7 +266,7 @@ MachineMappingResult get_optimal_machine_mapping( return optimal_result; } -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, Node const &node, MachineSpecification const &resource, diff --git a/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc b/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc new file mode 100644 index 0000000000..f9ab119171 --- 
/dev/null +++ b/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc @@ -0,0 +1,36 @@ +#include "compiler/machine_mapping/split_sp_decomposition.h" +#include "utils/variant.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::pair + split_sp_decomposition(SerialSplit const &serial) { + if (serial.children.size() == 2) { + return {widen(serial.children[0]), + widen(serial.children[1])}; + } + SerialSplit decompn1 = serial; + decompn1.children.pop_back(); + return {SerialParallelDecomposition(decompn1), + widen(serial.children.back())}; +} + +std::pair + split_sp_decomposition(ParallelSplit const ¶llel) { + if (parallel.children.size() == 2) { + std::vector children = + transform(as_vector(parallel.children), [&](auto const &child) { + return widen(child); + }); + return {children[0], children[1]}; + } + ParallelSplit decompn1 = parallel; + std::variant child = *parallel.children.begin(); + decompn1.children.erase(child); + return {SerialParallelDecomposition(decompn1), + widen(child)}; +} + +} diff --git a/lib/compiler/test/src/test_cost_estimator.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h similarity index 73% rename from lib/compiler/test/src/test_cost_estimator.h rename to lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index 322f80352f..90218bb322 100644 --- a/lib/compiler/test/src/test_cost_estimator.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -5,18 +5,18 @@ namespace FlexFlow { -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, +struct CostEstimatorForTest : public ICostEstimator { + inline float estimate_cost(PCGOperatorAttrs const &op, std::vector const &inputs, std::vector const &weights, std::vector const &outputs, MachineView const &mv) const override { - return 0.1; + return 1; } - float estimate_cost(ParallelTensorShape const 
&tensor_shape, + inline float estimate_cost(ParallelTensorShape const &tensor_shape, MachineView const &src, MachineView const &dst) const override { - return 0.1; + return 1; } }; diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..ce84331ab7 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc @@ -0,0 +1,189 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "doctest/doctest.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "cost_estimator_for_test.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_optimal_machine_mapping") { + + ParallelComputationGraph pcg_simple = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, + 8); + + return builder.pcg; + }(); + + ParallelComputationGraph pcg_chain = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + parallel_tensor_guid_t dense1 = builder.dense(dense0, 8); + parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); + parallel_tensor_guid_t dense3 = 
builder.dense(dense2, 8); + parallel_tensor_guid_t dense4 = builder.dense(dense3, 8); + parallel_tensor_guid_t dense5 = builder.dense(dense4, 8); + + return builder.pcg; + }(); + + ParallelComputationGraph pcg_multiple_chains = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape0 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{32, 1}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + ParallelTensorShape input_shape1 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{8, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape0); + parallel_tensor_guid_t input1 = builder.create_input_tensor(input_shape1); + parallel_tensor_guid_t relu0 = builder.relu(input0); + parallel_tensor_guid_t relu1 = builder.relu(input1); + parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); + + return builder.pcg; + }(); + + ParallelComputationGraph pcg_non_sp = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + parallel_tensor_guid_t dense1 = builder.dense(input0, 4); + parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); + parallel_tensor_guid_t add0 = builder.add(dense0, dense2); + + return builder.pcg; + }(); + + auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, + MachineSpecification const &) { + // 
TODO(@Mengdi Wu): Replace it with actual allowed machine views when https://github.com/flexflow/FlexFlow/pull/1458 is merged + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + }; + CostEstimator estimator1 = CostEstimator::create(); + MachineSpecification machine_spec1(1, 1, 1, 1, 1); + MachineMappingCache cached_results1; + + SUBCASE("simple PCG") { + MachineMappingResult result = + get_optimal_machine_mapping(pcg_simple, + allowed_machine_views1, + estimator1, + machine_spec1, + cached_results1); + + CHECK(bool(result.runtime > 0)); + // TODO(@Mengdi Wu): fill it with actual cost + // CHECK(result.runtime == xx); + } + + SUBCASE("PCG is a chain") { + MachineMappingResult result = + get_optimal_machine_mapping(pcg_chain, + allowed_machine_views1, + estimator1, + machine_spec1, + cached_results1); + CHECK(bool(result.runtime > 0)); + // CHECK(result.runtime == xx); + } + + SUBCASE("PCG has multiple chains") { + MachineMappingResult result = + get_optimal_machine_mapping(pcg_multiple_chains, + allowed_machine_views1, + estimator1, + machine_spec1, + cached_results1); + CHECK(bool(result.runtime > 0)); + // CHECK(result.runtime == xx); + } + + SUBCASE("PCG is not sp-izable due to multiple inputs") { + // TODO: Handle this case in compiler + if (false) { + MachineMappingResult result = + get_optimal_machine_mapping(pcg_non_sp, + allowed_machine_views1, + estimator1, + machine_spec1, + cached_results1); + CHECK(bool(result.runtime > 0)); + // CHECK(result.runtime == xx); + } + } + + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc new file mode 100644 index 0000000000..a4d4cef7f6 --- /dev/null +++ 
b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc @@ -0,0 +1,81 @@ +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "doctest/doctest.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "cost_estimator_for_test.h" +#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "compiler/machine_mapping/split_sp_decomposition.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("machine_mapping_cache") { + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape0 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{32, 1}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + ParallelTensorShape input_shape1 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{8, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape0); + parallel_tensor_guid_t input1 = builder.create_input_tensor(input_shape1); + parallel_tensor_guid_t relu0 = builder.relu(input0); + parallel_tensor_guid_t relu1 = builder.relu(input1); + parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); + + return builder.pcg; + }(); + + SerialParallelDecomposition subgraph0 = get_serial_parallel_decomposition(pcg.raw_graph).value(); + auto [subgraph1, subgraph2] = split_sp_decomposition(subgraph0.get()); + + MachineSpecification machine_spec(1, 1, 1, 1, 1); + MachineMappingState state0(subgraph0, machine_spec, {}); + MachineMappingState state1(subgraph1, machine_spec, {}); + MachineMappingState state2(subgraph2, machine_spec, {}); + + MachineMappingResult result0(2, 
MachineMapping(std::unordered_map{})); + MachineMappingResult result1(1, MachineMapping(std::unordered_map{})); + MachineMappingResult result2(1, MachineMapping(std::unordered_map{})); + + MachineMappingCache cache; + + cache.save(state0, result0); + CHECK(cache.load(state0).value() == result0); + CHECK(cache.load(state1) == std::nullopt); + CHECK(cache.load(state2) == std::nullopt); + + cache.save(state1, result1); + CHECK(cache.load(state0).value() == result0); + CHECK(cache.load(state1).value() == result1); + CHECK(cache.load(state2) == std::nullopt); + + cache.save(state2, result2); + CHECK(cache.load(state0).value() == result0); + CHECK(cache.load(state1).value() == result1); + CHECK(cache.load(state2).value() == result2); + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h deleted file mode 100644 index 39d0cc710c..0000000000 --- a/lib/compiler/test/src/test_generator.h +++ /dev/null @@ -1,173 +0,0 @@ -#ifndef _FLEXFLOW_TEST_GENERATOR_H -#define _FLEXFLOW_TEST_GENERATOR_H - -#include "pcg/computation_graph.h" -#include "rapidcheck.h" -#include "substitutions/sub_parallel_computation_graph.h" - -using namespace FlexFlow; - -// Rapidcheck does not work for now -// /* -// Generates computation graphs with trivial layers and tensors, which are -// used for tests focusing on graph structures. 
-// */ -// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, -// [](Tensor(MultiDiOutput const &)) { -// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; -// })); -// } - -// /* -// Generates parallel computation graphs with trivial layers and tensors, -// which are used for tests focusing on graph structures. -// */ -// ParallelComputationGraph -// test_parallel_computation_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, -// [](Operator(MultiDiOutput const &)) { -// return ParallelTensor(ParallelTensorDims(TensorDims({})), -// DataType::FLOAT); -// })); -// } - -// rc::Gen small_integer_generator() { -// return rc::gen::inRange(1, 4); -// } - -// namespace rc { - -// Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), -// multidigraph_from_sp_decomposition); -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_computataion_graph); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_parallel_computation_graph); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>(gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? 
gen::cast>(gen::arbitrary()) -// : gen::cast>( -// gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Serial::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Parallel::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? gen::construct( -// gen::arbitrary()) -// : gen::construct( -// gen::arbitrary()); -// }); -// } -// }; - -// template -// struct Arbitrary { -// static Gen< -// std::enable_if, -// Tag>::value>::type> arbitrary() { -// return gen::construct(gen::arbitrary()); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::apply(make_1d_machine_view, -// gen::arbitrary, -// gen::arbitrary, -// small_integer_generator()); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineMapping::machine_views, -// gen::container>( -// gen::arbitrary(), -// gen::arbitrary()))); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, -// 64)), gen::set(&MachineSpecification::num_gpus_per_node, -// gen::inRange(1, 16)), -// gen::set(&MachineSpecification::inter_node_bandwidth, -// gen::nonZero()), -// gen::set(&MachineSpecification::intra_node_bandwidth, -// gen::nonZero())); -// } -// } - -// } // namespace rc - -#endif diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc deleted file mode 100644 index 
59fa0f1e5e..0000000000 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ /dev/null @@ -1,132 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// // #include "rapidcheck.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { -// auto g = OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); -// Node n2 = g.add_node(); -// Node n3 = g.add_node(); -// Node n4 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); -// NodePort p2 = g.add_node_port(); -// NodePort p3 = g.add_node_port(); -// NodePort p4 = g.add_node_port(); -// NodePort p5 = g.add_node_port(); -// NodePort p6 = g.add_node_port(); -// NodePort p7 = g.add_node_port(); -// NodePort p8 = g.add_node_port(); -// NodePort p9 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; -// MultiDiEdge e1{n2, p2, n0, p0}; -// MultiDiEdge e2{n3, p5, n1, p3}; -// MultiDiEdge e3{n3, p6, n2, p4}; -// MultiDiEdge e4{n4, p8, n3, p7}; -// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// std::unordered_set node_set0{n3, n4}; - -// auto subgraph0 = get_subgraph(g, node_set0); -// auto subgraph1 = get_subgraph(g, -// node_set0); auto subgraph2 = -// get_subgraph(g, node_set0); -// auto subgraph3 = get_subgraph(g, node_set0); - -// CHECK(bool(get_nodes(subgraph0) == node_set0)); -// CHECK(bool(get_nodes(subgraph1) == node_set0)); -// CHECK(bool(get_nodes(subgraph2) == node_set0)); -// CHECK(bool(get_nodes(subgraph3) == node_set0)); - -// std::unordered_set input_set{split_edge(e2).second, -// split_edge(e3).second}; -// std::unordered_set output_set{e5}; - -// CHECK(bool(get_open_inputs(subgraph0) == input_set)); -// CHECK(bool(get_open_inputs(subgraph1) == input_set)); -// 
CHECK(bool(get_open_inputs(subgraph2).empty())); -// CHECK(bool(get_open_inputs(subgraph3).empty())); - -// CHECK(bool(get_open_outputs(subgraph0) == output_set)); -// CHECK(bool(get_open_outputs(subgraph1).empty())); -// CHECK(bool(get_open_outputs(subgraph2) == output_set)); -// CHECK(bool(get_open_outputs(subgraph3).empty())); - -// CHECK(bool(get_edges(subgraph0) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4, e5})); -// CHECK(bool(get_edges(subgraph1) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4})); -// CHECK(bool(get_edges(subgraph2) == -// std::unordered_set{e4, e5})); -// CHECK( -// bool(get_edges(subgraph3) == -// std::unordered_set{e4})); - -// CHECK(bool(get_closed_sources(subgraph2) == -// std::unordered_set{n3})); -// } - -// TEST_CASE("view OutputLabelledMultiDiGraph as open") { -// OutputLabelledMultiDiGraph g = -// OutputLabelledMultiDiGraph::create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_output(e0, 2); - -// CHECK(bool(get_edges(g).size() == 1)); - -// OutputLabelledOpenMultiDiGraphView open_graph = -// view_output_labelled_as_output_labelled_open(g); - -// CHECK(bool(open_graph.at(n0) == 0)); -// CHECK(bool(open_graph.at(n1) == 1)); -// CHECK(bool(open_graph.at(e0) == 2)); - -// CHECK(get_edges(open_graph).size() == 1); -// } - -// TEST_CASE("OutputLabelledOpenMultiDiGraph") { -// OutputLabelledOpenMultiDiGraph g = -// OutputLabelledOpenMultiDiGraph::create< -// UnorderedOutputLabelledOpenMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_label(e0, 2); - -// 
CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); -// CHECK(bool(get_edges(g).size() == 1)); -// } -// } diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc deleted file mode 100644 index 4f9b879574..0000000000 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "doctest/doctest.h" -#include "test_generator.h" - -TEST_SUITE(FF_TEST_SUITE) { - // TEST_CASE("MachineMapping::combine") { - // RC_SUBCASE([](MachineMapping const &m0, MachineMapping const &m1) { - // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - - // MachineMapping comb = MachineMapping::combine(m0, m1); - - // RC_ASSERT(comb.machine_views.size() == - // m0.machine_views.size() + m1.machine_views.size()); - // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); - // }); - // } - - // TEST_CASE("OptimalCostResult::infinity") { - // RC_SUBCASE([](OptimalCostResult const &c) { - // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - // }); - // } -} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc deleted file mode 100644 index e3426aa293..0000000000 --- a/lib/compiler/test/src/test_open_graph.cc +++ /dev/null @@ -1,81 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "utils/graph/algorithms.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_source_sink_open_graph") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// NodePort p0 = g.add_node_port(); -// InputMultiDiEdge e0{ -// n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; -// g.add_edge(e0); - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == 
std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{})); -// } - -// TEST_CASE("get_source_sink_open_graph:unconnected") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; -// OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; -// g.add_edge(e0); -// g.add_edge(e1); - -// /* -// g: ->n0 -// n1-> -// */ - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); -// } - -// TEST_CASE("get_cut") { -// auto g = OpenMultiDiGraph::create(); - -// std::vector ns = add_nodes(g, 5); - -// MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; -// MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; -// MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; -// OutputMultiDiEdge e5{ -// ns[4], g.add_node_port(), std::make_pair(ns[4].value(), -// ns[4].value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; -// CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, -// e2})); - -// GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; -// CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, -// e4})); -// } -// } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc deleted file mode 100644 index 9c82585de4..0000000000 --- 
a/lib/compiler/test/src/test_optimal_cost.cc +++ /dev/null @@ -1,56 +0,0 @@ -#include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "doctest/doctest.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "test_cost_estimator.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("optimal_cost does not crash on minimal inputs") { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape, true, "input0"); - parallel_tensor_guid_t dense0 = builder.dense(input0, - 8, - Activation::RELU, - true, - DataType::FLOAT, - std::nullopt, - std::nullopt, - "dense0"); - - ParallelComputationGraph pcg = builder.pcg; - - auto test_allowed_machine_views = [](ParallelLayerAttrs const &, - MachineSpecification const &) { - return std::unordered_set{ - make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; - }; - - CostEstimator estimator = CostEstimator::create(); - MachineSpecification machine_spec{1, 1, 1, 1, 1}; - MachineMappingCache cached_results; - MachineMappingResult result = - get_optimal_machine_mapping(pcg, - test_allowed_machine_views, - estimator, - machine_spec, - cached_results); - - CHECK(bool(result.runtime > 0)); - } -} diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc index ed5e895a75..8ff0978ea5 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -1,7 +1,5 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" -#include "test_cost_estimator.h" -#include "test_generator.h" TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now From 
150ca5e1a2bd6d16e2c6b61034f64d8b11f5fefe Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 28 Aug 2024 21:04:18 -0400 Subject: [PATCH 07/29] fmt --- .../get_optimal_machine_mapping.h | 2 +- .../machine_mapping/split_sp_decomposition.h | 4 ++-- .../get_optimal_machine_mapping.cc | 16 +++++++------- .../machine_mapping/split_sp_decomposition.cc | 4 ++-- .../machine_mapping/cost_estimator_for_test.h | 12 +++++----- .../test_get_optimal_machine_mapping.cc | 22 ++++++++----------- .../test_machine_mapping_cache.cc | 19 ++++++++++------ 7 files changed, 40 insertions(+), 39 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 466b96dcab..1c75dcaedc 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -21,7 +21,7 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingResult get_optimal_machine_mapping_internal(MachineMappingContext &context, - MachineSpecification const &resources); + MachineSpecification const &resources); MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, diff --git a/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h b/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h index acafcfcb0e..cab8d8d988 100644 --- a/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h +++ b/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h @@ -11,6 +11,6 @@ std::pair std::pair split_sp_decomposition(ParallelSplit const ¶llel); -} +} // namespace FlexFlow -#endif \ No newline at end of file +#endif diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index dcb51c97a9..b90531ea59 100644 --- 
a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -120,17 +120,17 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingContext context( pcg, cost_estimator, allowed_machine_views, cached_subgraph_results); - MachineMappingResult result = get_optimal_machine_mapping_internal(context, resources); + MachineMappingResult result = + get_optimal_machine_mapping_internal(context, resources); cached_subgraph_results = context.cached_subgraph_results; return result; } -MachineMappingResult - get_optimal_machine_mapping_internal(MachineMappingContext &context, - MachineSpecification const &resources) { +MachineMappingResult get_optimal_machine_mapping_internal( + MachineMappingContext &context, MachineSpecification const &resources) { std::optional decompn_optional = get_serial_parallel_decomposition(context.pcg.raw_graph); - + if (!decompn_optional) { throw mk_runtime_error("Failed to get serial parallel decomposition"); } @@ -258,9 +258,9 @@ MachineMappingResult get_optimal_machine_mapping_internal( get_optimal_machine_mapping_internal( context, decompn1, resource_split.first, fixed_machine_views1), get_optimal_machine_mapping_internal(context, - decompn2, - resource_split.second, - fixed_machine_views2))); + decompn2, + resource_split.second, + fixed_machine_views2))); } return optimal_result; diff --git a/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc b/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc index f9ab119171..b5abe383d3 100644 --- a/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc +++ b/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc @@ -1,7 +1,7 @@ #include "compiler/machine_mapping/split_sp_decomposition.h" -#include "utils/variant.h" #include "utils/containers/as_vector.h" #include "utils/containers/transform.h" +#include "utils/variant.h" namespace FlexFlow { 
@@ -33,4 +33,4 @@ std::pair widen(child)}; } -} +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index 90218bb322..86eb824dd3 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -7,15 +7,15 @@ namespace FlexFlow { struct CostEstimatorForTest : public ICostEstimator { inline float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const override { + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs, + MachineView const &mv) const override { return 1; } inline float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { + MachineView const &src, + MachineView const &dst) const override { return 1; } }; diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc index ce84331ab7..12ec4a11b4 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc @@ -1,7 +1,7 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "cost_estimator_for_test.h" #include "doctest/doctest.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "cost_estimator_for_test.h" using namespace FlexFlow; @@ -24,10 +24,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT}; - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, - 8); + parallel_tensor_guid_t input0 = 
builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, 8); return builder.pcg; }(); @@ -48,8 +46,7 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT}; - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape); + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); parallel_tensor_guid_t dense0 = builder.dense(input0, 8); parallel_tensor_guid_t dense1 = builder.dense(dense0, 8); parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); @@ -115,9 +112,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, }, DataType::FLOAT}; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape); + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); parallel_tensor_guid_t dense0 = builder.dense(input0, 8); parallel_tensor_guid_t dense1 = builder.dense(input0, 4); parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); @@ -127,8 +123,9 @@ TEST_SUITE(FF_TEST_SUITE) { }(); auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, - MachineSpecification const &) { - // TODO(@Mengdi Wu): Replace it with actual allowed machine views when https://github.com/flexflow/FlexFlow/pull/1458 is merged + MachineSpecification const &) { + // TODO(@Mengdi Wu): Replace it with actual allowed machine views when + // https://github.com/flexflow/FlexFlow/pull/1458 is merged return std::unordered_set{ make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; @@ -184,6 +181,5 @@ TEST_SUITE(FF_TEST_SUITE) { // CHECK(result.runtime == xx); } } - } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc index a4d4cef7f6..8999185dee 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc @@ -1,9 +1,9 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" +#include 
"compiler/machine_mapping/split_sp_decomposition.h" +#include "cost_estimator_for_test.h" #include "doctest/doctest.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "cost_estimator_for_test.h" #include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "compiler/machine_mapping/split_sp_decomposition.h" using namespace FlexFlow; @@ -49,17 +49,22 @@ TEST_SUITE(FF_TEST_SUITE) { return builder.pcg; }(); - SerialParallelDecomposition subgraph0 = get_serial_parallel_decomposition(pcg.raw_graph).value(); - auto [subgraph1, subgraph2] = split_sp_decomposition(subgraph0.get()); + SerialParallelDecomposition subgraph0 = + get_serial_parallel_decomposition(pcg.raw_graph).value(); + auto [subgraph1, subgraph2] = + split_sp_decomposition(subgraph0.get()); MachineSpecification machine_spec(1, 1, 1, 1, 1); MachineMappingState state0(subgraph0, machine_spec, {}); MachineMappingState state1(subgraph1, machine_spec, {}); MachineMappingState state2(subgraph2, machine_spec, {}); - MachineMappingResult result0(2, MachineMapping(std::unordered_map{})); - MachineMappingResult result1(1, MachineMapping(std::unordered_map{})); - MachineMappingResult result2(1, MachineMapping(std::unordered_map{})); + MachineMappingResult result0( + 2, MachineMapping(std::unordered_map{})); + MachineMappingResult result1( + 1, MachineMapping(std::unordered_map{})); + MachineMappingResult result2( + 1, MachineMapping(std::unordered_map{})); MachineMappingCache cache; From fc388ce0c6ab50f196297f55f5a81b5006581a67 Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Mon, 2 Sep 2024 14:41:21 -0400 Subject: [PATCH 08/29] add more tests --- .../machine_mapping/test_machine_mapping.cc | 28 ++++++ .../test_machine_mapping_result.cc | 99 +++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc index 
e69de29bb2..181f44ff16 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc @@ -0,0 +1,28 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "cost_estimator_for_test.h" +#include "doctest/doctest.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); + MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping result = combine(machine_mapping_0, machine_mapping_1); + CHECK(result == combined); + } + + TEST_CASE("nodes_are_disjoint") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); + MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + CHECK(nodes_are_disjoint(machine_mapping_0, machine_mapping_1)); + CHECK_FALSE(nodes_are_disjoint(machine_mapping_0, combined)); + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc index e69de29bb2..e1bba56e7f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc @@ -0,0 +1,99 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "cost_estimator_for_test.h" +#include "doctest/doctest.h" + +using namespace FlexFlow; + 
+TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sequential_combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); + MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMappingResult s0(0, machine_mapping_empty); + MachineMappingResult s1(1, machine_mapping_0); + MachineMappingResult s2(2, machine_mapping_1); + + MachineMappingResult result0 = sequential_combine(s0, s1); + CHECK(result0.runtime == 1); + CHECK(result0.machine_mapping == machine_mapping_0); + + MachineMappingResult result1 = sequential_combine(s0, s2); + CHECK(result1.runtime == 2); + CHECK(result1.machine_mapping == machine_mapping_1); + + MachineMappingResult result2 = sequential_combine(s1, s2); + CHECK(result2.runtime == 3); + CHECK(result2.machine_mapping == combined); + } + + TEST_CASE("parallel_combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); + MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMappingResult s0(0, machine_mapping_empty); + MachineMappingResult s1(1, machine_mapping_0); + MachineMappingResult s2(2, machine_mapping_1); + + MachineMappingResult result0 = parallel_combine(s0, s1); + CHECK(result0.runtime == 1); + CHECK(result0.machine_mapping == machine_mapping_0); + + MachineMappingResult result1 = parallel_combine(s0, s2); + CHECK(result1.runtime == 2); + CHECK(result1.machine_mapping == 
machine_mapping_1); + + MachineMappingResult result2 = parallel_combine(s1, s2); + CHECK(result2.runtime == 2); + CHECK(result2.machine_mapping == combined); + } + + TEST_CASE("get_infinity_machine_mapping_result") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); + MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMappingResult s0(0, machine_mapping_empty); + MachineMappingResult s1(1, machine_mapping_0); + MachineMappingResult s2(2, machine_mapping_1); + + MachineMappingResult inf = get_infinity_machine_mapping_result(); + CHECK(s0.runtime < inf.runtime); + CHECK(s1.runtime < inf.runtime); + CHECK(s2.runtime < inf.runtime); + } + + TEST_CASE("minimize_runtime") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); + MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMappingResult s0(0, machine_mapping_empty); + MachineMappingResult s1(1, machine_mapping_0); + MachineMappingResult s2(2, machine_mapping_1); + + MachineMappingResult _s0 = s0; + MachineMappingResult _s1 = s1; + MachineMappingResult _s2 = s2; + + minimize_runtime(_s0, _s1); + CHECK(_s0 == s0); + minimize_runtime(_s1, _s2); + CHECK(_s1 == s1); + + minimize_runtime(_s1, _s0); + CHECK(_s1 == s0); + + minimize_runtime(_s2, get_infinity_machine_mapping_result()); + CHECK(_s2 == s2); + } +} From e628c72a07792363d7a25cccd2f93f67291d7e08 Mon 
Sep 17 00:00:00 2001 From: Mengdi Wu Date: Mon, 2 Sep 2024 14:42:55 -0400 Subject: [PATCH 09/29] fmt --- .../machine_mapping/test_machine_mapping.cc | 6 +++-- .../test_machine_mapping_result.cc | 24 ++++++++++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc index 181f44ff16..436c723403 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc @@ -11,7 +11,8 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping combined( + {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); MachineMapping result = combine(machine_mapping_0, machine_mapping_1); CHECK(result == combined); } @@ -21,7 +22,8 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping combined( + {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); CHECK(nodes_are_disjoint(machine_mapping_0, machine_mapping_1)); CHECK_FALSE(nodes_are_disjoint(machine_mapping_0, combined)); } diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc index e1bba56e7f..91f7d3334f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc +++ 
b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc @@ -8,10 +8,12 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("sequential_combine") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_empty( + std::unordered_map{}); MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping combined( + {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); @@ -32,10 +34,12 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("parallel_combine") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_empty( + std::unordered_map{}); MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping combined( + {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); @@ -56,10 +60,12 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_infinity_machine_mapping_result") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_empty(std::unordered_map{}); + 
MachineMapping machine_mapping_empty( + std::unordered_map{}); MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping combined( + {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); @@ -73,10 +79,12 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("minimize_runtime") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_empty(std::unordered_map{}); + MachineMapping machine_mapping_empty( + std::unordered_map{}); MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined({{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + MachineMapping combined( + {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); From 8eff2b91039b5b014bf01f8062ec194aa8ed586a Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 4 Sep 2024 16:42:17 -0400 Subject: [PATCH 10/29] fix --- .../allowed_machine_mappings.h | 21 ++ .../get_optimal_machine_mapping.h | 6 +- .../machine_mapping/machine_mapping.h | 5 +- .../machine_mapping.struct.toml | 7 +- .../machine_mapping/machine_mapping_cache.h | 4 +- .../machine_mapping_context.struct.toml | 2 +- .../machine_mapping/machine_mapping_result.h | 2 +- .../machine_mapping_result.struct.toml | 2 +- .../allowed_machine_mappings.cc | 57 +++++ .../get_optimal_machine_mapping.cc | 77 ++---- .../machine_mapping/machine_mapping.cc | 3 +- .../machine_mapping/machine_mapping_cache.cc 
| 2 +- .../machine_mapping/machine_mapping_result.cc | 10 +- .../test_get_optimal_machine_mapping.cc | 231 +++++++++--------- .../machine_mapping/test_machine_mapping.cc | 52 ++-- .../test_machine_mapping_cache.cc | 2 +- .../test_machine_mapping_result.cc | 19 -- 17 files changed, 274 insertions(+), 228 deletions(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h create mode 100644 lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc diff --git a/lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h b/lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h new file mode 100644 index 0000000000..4787abece6 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_ALLOWED_MACHINE_MAPPINGS_H_ +#define _FLEXFLOW_ALLOWED_MACHINE_MAPPINGS_H_ + +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +std::vector> + allowed_machine_mappings(MachineMappingContext const &context, + std::unordered_set const &nodes, + MachineSpecification const &resource); + +std::vector> + allowed_machine_mappings(MachineMappingContext const &context, + std::unordered_set const &values, + MachineSpecification const &resource); + +} + +#endif \ No newline at end of file diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 1c75dcaedc..ec514b4626 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H -#include "machine_mapping.h" -#include 
"machine_mapping_cache.h" -#include "machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h index 441f9a6034..06cbbf942d 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -1,11 +1,12 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#include "machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" namespace FlexFlow { -MachineMapping combine(MachineMapping const &, MachineMapping const &); +MachineMapping combine_disjoint_mappings(MachineMapping const &, + MachineMapping const &); bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml index 4c4912a3fd..ed5b89c3da 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml @@ -9,11 +9,14 @@ features = [ "fmt", ] -includes = [ +includes = [ "utils/graph/node/node.dtg.h", "pcg/machine_view.dtg.h", +] + +src_includes = [ "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/fmt/unordered_map.h", ] [[fields]] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index e2824777ff..a721ea29ed 100644 --- 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H -#include "machine_mapping_result.dtg.h" -#include "machine_mapping_state.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/machine_mapping_state.dtg.h" #include "utils/optional.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index eea1da6ca1..a5f0b8f1f2 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -4,7 +4,7 @@ features = [ ] includes = [ - "machine_mapping.dtg.h", + "compiler/machine_mapping/machine_mapping.dtg.h", "pcg/parallel_computation_graph/parallel_computation_graph.h", "compiler/cost_estimator.h", "pcg/machine_view.h", diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 912b604e4c..62d6b7dfbb 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H -#include "machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml index e8f3a8562e..9436b9bf47 
100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -6,7 +6,7 @@ features = [ ] includes = [ - "machine_mapping.dtg.h", + "compiler/machine_mapping/machine_mapping.dtg.h", ] [[fields]] diff --git a/lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc b/lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc new file mode 100644 index 0000000000..16cd145451 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc @@ -0,0 +1,57 @@ +#include "compiler/machine_mapping/allowed_machine_mappings.h" +#include "utils/containers/get_first.h" +#include "utils/containers/set_minus.h" +#include "utils/containers.h" +#include "utils/containers/keys.h" + +namespace FlexFlow { + +std::vector> + allowed_machine_mappings(MachineMappingContext const &context, + std::unordered_set const &nodes, + MachineSpecification const &resource) { + if (nodes.empty()) { + return {{}}; + } + Node node = get_first(nodes); + std::vector> partial_enumeration = + allowed_machine_mappings(context, set_minus(nodes, {node}), resource); + std::unordered_set allowed_machine_views_for_node = + context.allowed_machine_views(context.pcg.raw_graph.at(node), resource); + std::vector> enumeration; + for (MachineView const &mv : allowed_machine_views_for_node) { + for (std::unordered_map const &partial : + partial_enumeration) { + enumeration.push_back(merge_maps( + partial, std::unordered_map{{node, mv}})); + } + } + return enumeration; +} + +std::vector> + allowed_machine_mappings(MachineMappingContext const &context, + std::unordered_set const &values, + MachineSpecification const &resource) { + std::unordered_set nodes; + for (DataflowOutput const &v : values) { + nodes.insert(v.node); + } + + std::vector> node_enumeration = + allowed_machine_mappings(context, nodes, resource); + std::vector> enumeration; + + for 
(std::unordered_map _node_enumeration : + node_enumeration) { + std::unordered_map _emumeration; + for (DataflowOutput const &v : values) { + _emumeration.emplace(v, _node_enumeration.at(v.node)); + } + enumeration.push_back(_emumeration); + } + + return enumeration; +} + +} diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index b90531ea59..d48be97bc8 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -2,6 +2,7 @@ #include "compiler/cost_estimator.h" #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/split_sp_decomposition.h" +#include "compiler/machine_mapping/allowed_machine_mappings.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.dtg.h" @@ -12,11 +13,9 @@ #include "utils/containers/as_vector.h" #include "utils/containers/contains_key.h" #include "utils/containers/generate_map.h" -#include "utils/containers/get_first.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/restrict_keys.h" -#include "utils/containers/set_minus.h" #include "utils/containers/values.h" #include "utils/exception.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -32,54 +31,6 @@ namespace FlexFlow { -std::vector> - allowed_machine_mappings(MachineMappingContext const &context, - std::unordered_set const &nodes, - MachineSpecification const &resource) { - if (nodes.empty()) { - return {{}}; - } - Node node = get_first(nodes); - std::vector> partial_enumeration = - allowed_machine_mappings(context, set_minus(nodes, {node}), resource); - std::unordered_set allowed_machine_views_for_node = - context.allowed_machine_views(context.pcg.raw_graph.at(node), resource); - std::vector> enumeration; - for 
(MachineView const &mv : allowed_machine_views_for_node) { - for (std::unordered_map const &partial : - partial_enumeration) { - enumeration.push_back(merge_maps( - partial, std::unordered_map{{node, mv}})); - } - } - return enumeration; -} - -std::vector> - allowed_machine_mappings(MachineMappingContext const &context, - std::unordered_set const &values, - MachineSpecification const &resource) { - std::unordered_set nodes; - for (DataflowOutput const &v : values) { - nodes.insert(v.node); - } - - std::vector> node_enumeration = - allowed_machine_mappings(context, nodes, resource); - std::vector> enumeration; - - for (std::unordered_map _node_enumeration : - node_enumeration) { - std::unordered_map _emumeration; - for (DataflowOutput const &v : values) { - _emumeration.emplace(v, _node_enumeration.at(v.node)); - } - enumeration.push_back(_emumeration); - } - - return enumeration; -} - std::vector> get_resource_split(MachineSpecification const &resource) { std::vector> result; @@ -203,12 +154,12 @@ MachineMappingResult get_optimal_machine_mapping_internal( restrict_keys(fixed_machine_views, get_open_dataflow_values(subgraph_res2.graph)); - for (auto const &[split_value, split_input] : + for (auto const &[full_graph_value, subgraph_input] : subgraph_res2.full_graph_values_to_subgraph_inputs) { MachineView mv = - split_machine_views.at(split_value.get()); - fixed_machine_views1.emplace(split_value, mv); - fixed_machine_views2.emplace(OpenDataflowValue(split_input), mv); + split_machine_views.at(full_graph_value.get()); + fixed_machine_views1.emplace(full_graph_value, mv); + fixed_machine_views2.emplace(OpenDataflowValue(subgraph_input), mv); } minimize_runtime( @@ -278,9 +229,12 @@ MachineMappingResult get_optimal_machine_mapping_internal( OpenDataflowValue any_output = OpenDataflowValue(get_outputs(context.pcg.raw_graph, node)[0]); if (contains_key(fixed_machine_views, any_output)) { - assert(contains( - context.allowed_machine_views(context.pcg.raw_graph.at(node), 
resource), - fixed_machine_views.at(any_output))); + { + std::unordered_set allowed_machine_views_for_node = context.allowed_machine_views( + context.pcg.raw_graph.at(node), resource); + MachineView fixed_machine_view_for_node = fixed_machine_views.at(any_output); + assert(contains(allowed_machine_views_for_node, fixed_machine_view_for_node)); + } MachineView mv = fixed_machine_views.at(any_output); MachineMapping mv_map{{{node, mv}}}; return MachineMappingResult(base_case_estimate_cost(subgraph, @@ -294,11 +248,12 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineView mv = node_machine_views.at(node); MachineMapping mv_map{{{node, mv}}}; + std::vector outputs_of_node = transform( + get_outputs(context.pcg.raw_graph, node), + [](DataflowOutput const &o) { return OpenDataflowValue(o); }); + std::unordered_map output_mv_map = - generate_map(transform(get_outputs(context.pcg.raw_graph, node), - [](DataflowOutput const &o) { - return OpenDataflowValue(o); - }), + generate_map(outputs_of_node, [&](OpenDataflowValue const &o) { return mv; }); std::unordered_map machine_views = diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index d739f3fecb..1b02485e1f 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -5,7 +5,8 @@ namespace FlexFlow { -MachineMapping combine(MachineMapping const &s1, MachineMapping const &s2) { +MachineMapping combine_disjoint_mappings(MachineMapping const &s1, + MachineMapping const &s2) { return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc index 232598b98d..b2b3fbc8f5 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc +++ 
b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -7,7 +7,7 @@ std::optional MachineMappingCache::load(MachineMappingState const &state) const { if (contains_key(cache, state)) { MachineMappingResult result = cache.at(state); - return std::make_optional(result); + return result; } return std::nullopt; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index c33b8776df..50eb2d8b53 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -5,14 +5,16 @@ namespace FlexFlow { MachineMappingResult sequential_combine(MachineMappingResult const &s1, MachineMappingResult const &s2) { - return MachineMappingResult{s1.runtime + s2.runtime, - combine(s1.machine_mapping, s2.machine_mapping)}; + return MachineMappingResult{ + s1.runtime + s2.runtime, + combine_disjoint_mappings(s1.machine_mapping, s2.machine_mapping)}; } MachineMappingResult parallel_combine(MachineMappingResult const &s1, MachineMappingResult const &s2) { - return MachineMappingResult{std::max(s1.runtime, s2.runtime), - combine(s1.machine_mapping, s2.machine_mapping)}; + return MachineMappingResult{ + std::max(s1.runtime, s2.runtime), + combine_disjoint_mappings(s1.machine_mapping, s2.machine_mapping)}; } MachineMappingResult get_infinity_machine_mapping_result() { diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc index 12ec4a11b4..42b9edab05 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc @@ -7,121 +7,6 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping") { - - 
ParallelComputationGraph pcg_simple = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, 8); - - return builder.pcg; - }(); - - ParallelComputationGraph pcg_chain = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, 8); - parallel_tensor_guid_t dense1 = builder.dense(dense0, 8); - parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); - parallel_tensor_guid_t dense3 = builder.dense(dense2, 8); - parallel_tensor_guid_t dense4 = builder.dense(dense3, 8); - parallel_tensor_guid_t dense5 = builder.dense(dense4, 8); - - return builder.pcg; - }(); - - ParallelComputationGraph pcg_multiple_chains = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape0 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{32, 1}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - ParallelTensorShape input_shape1 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - ShardParallelDim{8, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = 
builder.create_input_tensor(input_shape0); - parallel_tensor_guid_t input1 = builder.create_input_tensor(input_shape1); - parallel_tensor_guid_t relu0 = builder.relu(input0); - parallel_tensor_guid_t relu1 = builder.relu(input1); - parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); - - return builder.pcg; - }(); - - ParallelComputationGraph pcg_non_sp = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, 8); - parallel_tensor_guid_t dense1 = builder.dense(input0, 4); - parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); - parallel_tensor_guid_t add0 = builder.add(dense0, dense2); - - return builder.pcg; - }(); - auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, MachineSpecification const &) { // TODO(@Mengdi Wu): Replace it with actual allowed machine views when @@ -134,6 +19,29 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingCache cached_results1; SUBCASE("simple PCG") { + + ParallelComputationGraph pcg_simple = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + + return builder.pcg; + }(); + MachineMappingResult result = get_optimal_machine_mapping(pcg_simple, allowed_machine_views1, @@ -147,6 +55,33 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("PCG is a chain") { + 
ParallelComputationGraph pcg_chain = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + parallel_tensor_guid_t dense1 = builder.dense(dense0, 8); + parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); + parallel_tensor_guid_t dense3 = builder.dense(dense2, 8); + parallel_tensor_guid_t dense4 = builder.dense(dense3, 8); + parallel_tensor_guid_t dense5 = builder.dense(dense4, 8); + + return builder.pcg; + }(); + MachineMappingResult result = get_optimal_machine_mapping(pcg_chain, allowed_machine_views1, @@ -158,6 +93,46 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("PCG has multiple chains") { + ParallelComputationGraph pcg_multiple_chains = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape0 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{32, 1}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + ParallelTensorShape input_shape1 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{8, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape0); + parallel_tensor_guid_t input1 = builder.create_input_tensor(input_shape1); + parallel_tensor_guid_t relu0 = builder.relu(input0); + parallel_tensor_guid_t relu1 = builder.relu(input1); + parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); + + return builder.pcg; + }(); 
+ MachineMappingResult result = get_optimal_machine_mapping(pcg_multiple_chains, allowed_machine_views1, @@ -169,7 +144,33 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("PCG is not sp-izable due to multiple inputs") { + ParallelComputationGraph pcg_non_sp = [&] { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); + parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + parallel_tensor_guid_t dense1 = builder.dense(input0, 4); + parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); + parallel_tensor_guid_t add0 = builder.add(dense0, dense2); + + return builder.pcg; + }(); + // TODO: Handle this case in compiler + // TODO: separate testcases for this too that actually check the graph manipulation if (false) { MachineMappingResult result = get_optimal_machine_mapping(pcg_non_sp, diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc index 436c723403..4adc6ee558 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc @@ -6,25 +6,49 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("combine") { + TEST_CASE("combine_disjoint_mappings(MachineMapping, MachineMappping)") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); - MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined( - {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); - 
MachineMapping result = combine(machine_mapping_0, machine_mapping_1); - CHECK(result == combined); + MachineMapping machine_mapping_0 = MachineMapping({ + {Node(0), machine_view_0}, + }); + MachineMapping machine_mapping_1 = MachineMapping({ + {Node(1), machine_view_1}, + }); + MachineMapping correct = MachineMapping({ + {Node(0), machine_view_0}, + {Node(1), machine_view_1}, + }); + MachineMapping result = + combine_disjoint_mappings(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); } - TEST_CASE("nodes_are_disjoint") { + TEST_CASE("nodes_are_disjoint(MachineMapping, MachineMappping)") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); - MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined( - {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); - CHECK(nodes_are_disjoint(machine_mapping_0, machine_mapping_1)); - CHECK_FALSE(nodes_are_disjoint(machine_mapping_0, combined)); + MachineMapping machine_mapping_0 = MachineMapping({ + {Node(0), machine_view_0}, + }); + + SUBCASE("nodes are disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {Node(1), machine_view_1}, + }); + + bool correct = true; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + SUBCASE("nodes are not disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {Node(0), machine_view_0}, + {Node(1), machine_view_1}, + }); + bool correct = false; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc index 8999185dee..cff2b1de50 100644 --- 
a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc @@ -8,7 +8,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("machine_mapping_cache") { + TEST_CASE("MachineMappingCache") { ParallelComputationGraph pcg = [&] { ParallelComputationGraphBuilder builder; diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc index 91f7d3334f..4ce650bf0a 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc @@ -57,25 +57,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result2.machine_mapping == combined); } - TEST_CASE("get_infinity_machine_mapping_result") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); - MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); - MachineMapping combined( - {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); - MachineMappingResult s0(0, machine_mapping_empty); - MachineMappingResult s1(1, machine_mapping_0); - MachineMappingResult s2(2, machine_mapping_1); - - MachineMappingResult inf = get_infinity_machine_mapping_result(); - CHECK(s0.runtime < inf.runtime); - CHECK(s1.runtime < inf.runtime); - CHECK(s2.runtime < inf.runtime); - } - TEST_CASE("minimize_runtime") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); From 7c03f240c51c137349834d262f3cc7f6a9adb483 Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 11 Sep 2024 20:26:22 -0400 Subject: 
[PATCH 11/29] refactor get_optimal_machine_mapping a bit and improve the tests --- .../allowed_machine_mappings.h | 21 -- .../get_allowed_machine_views_list.h | 18 +- .../get_optimal_machine_mapping.h | 9 +- .../machine_mapping.struct.toml | 4 +- .../machine_mapping_state.struct.toml | 5 +- .../allowed_machine_mappings.cc | 57 ----- .../get_allowed_machine_views_list.cc | 65 +++-- .../get_optimal_machine_mapping.cc | 230 ++++++++++-------- .../machine_mapping/machine_mapping.cc | 1 + .../machine_mapping/machine_mapping_result.cc | 2 +- lib/compiler/src/graph_optimize_state.cc | 4 +- ...ping.cc => get_optimal_machine_mapping.cc} | 69 ++++-- ..._machine_mapping.cc => machine_mapping.cc} | 16 +- ...ping_cache.cc => machine_mapping_cache.cc} | 18 +- ...ng_result.cc => machine_mapping_result.cc} | 30 ++- ...imize_state.cc => graph_optimize_state.cc} | 18 +- ..._unity_algorithm.cc => unity_algorithm.cc} | 0 .../sub_parallel_computation_graph.h | 3 - .../sub_parallel_computation_graph.cc | 20 -- 19 files changed, 290 insertions(+), 300 deletions(-) delete mode 100644 lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h delete mode 100644 lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc rename lib/compiler/test/src/compiler/machine_mapping/{test_get_optimal_machine_mapping.cc => get_optimal_machine_mapping.cc} (76%) rename lib/compiler/test/src/compiler/machine_mapping/{test_machine_mapping.cc => machine_mapping.cc} (76%) rename lib/compiler/test/src/compiler/machine_mapping/{test_machine_mapping_cache.cc => machine_mapping_cache.cc} (88%) rename lib/compiler/test/src/compiler/machine_mapping/{test_machine_mapping_result.cc => machine_mapping_result.cc} (70%) rename lib/compiler/test/src/{test_graph_optimize_state.cc => graph_optimize_state.cc} (87%) rename lib/compiler/test/src/{test_unity_algorithm.cc => unity_algorithm.cc} (100%) diff --git a/lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h 
b/lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h deleted file mode 100644 index 4787abece6..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/allowed_machine_mappings.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_ALLOWED_MACHINE_MAPPINGS_H_ -#define _FLEXFLOW_ALLOWED_MACHINE_MAPPINGS_H_ - -#include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "pcg/machine_specification.dtg.h" - -namespace FlexFlow { - -std::vector> - allowed_machine_mappings(MachineMappingContext const &context, - std::unordered_set const &nodes, - MachineSpecification const &resource); - -std::vector> - allowed_machine_mappings(MachineMappingContext const &context, - std::unordered_set const &values, - MachineSpecification const &resource); - -} - -#endif \ No newline at end of file diff --git a/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h b/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h index 2bf87e7e9b..8681852cbe 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h +++ b/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h @@ -8,15 +8,17 @@ namespace FlexFlow { std::vector> - get_allowed_machine_views_list(MachineMappingContext const &context, - std::unordered_set const &layers, - MachineSpecification const &resource); + get_allowed_machine_views_list( + MachineMappingContext const &context, + std::unordered_set const &layers, + MachineSpecification const &resource); std::vector> - get_allowed_src_machine_views_list(MachineMappingContext const &context, - std::unordered_set const &values, - MachineSpecification const &resource); + get_allowed_src_machine_views_list( + MachineMappingContext const &context, + std::unordered_set const &values, + MachineSpecification const &resource); -} +} // namespace FlexFlow -#endif \ No newline at end of file +#endif diff --git 
a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index ec514b4626..102f4204fd 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -7,6 +7,7 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" namespace FlexFlow { @@ -27,28 +28,28 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialParallelDecomposition const &decompn, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views); MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialSplit const &serial, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views); MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, ParallelSplit const ¶llel, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views); MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, Node const &node, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml index ed5b89c3da..8e3b94c891 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - 
"utils/graph/node/node.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", "pcg/machine_view.dtg.h", ] @@ -21,4 +21,4 @@ src_includes = [ [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" \ No newline at end of file diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml index 9a1661bcd6..9142c308be 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -10,8 +10,7 @@ includes = [ "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", "pcg/machine_specification.dtg.h", "pcg/machine_view.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", ] src_includes = [ @@ -29,4 +28,4 @@ type = "::FlexFlow::MachineSpecification" [[fields]] name = "fixed_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowValue, ::FlexFlow::MachineView>" +type = "std::unordered_map<::FlexFlow::parallel_tensor_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc b/lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc deleted file mode 100644 index 16cd145451..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/allowed_machine_mappings.cc +++ /dev/null @@ -1,57 +0,0 @@ -#include "compiler/machine_mapping/allowed_machine_mappings.h" -#include "utils/containers/get_first.h" -#include "utils/containers/set_minus.h" -#include "utils/containers.h" -#include "utils/containers/keys.h" - -namespace FlexFlow { - 
-std::vector> - allowed_machine_mappings(MachineMappingContext const &context, - std::unordered_set const &nodes, - MachineSpecification const &resource) { - if (nodes.empty()) { - return {{}}; - } - Node node = get_first(nodes); - std::vector> partial_enumeration = - allowed_machine_mappings(context, set_minus(nodes, {node}), resource); - std::unordered_set allowed_machine_views_for_node = - context.allowed_machine_views(context.pcg.raw_graph.at(node), resource); - std::vector> enumeration; - for (MachineView const &mv : allowed_machine_views_for_node) { - for (std::unordered_map const &partial : - partial_enumeration) { - enumeration.push_back(merge_maps( - partial, std::unordered_map{{node, mv}})); - } - } - return enumeration; -} - -std::vector> - allowed_machine_mappings(MachineMappingContext const &context, - std::unordered_set const &values, - MachineSpecification const &resource) { - std::unordered_set nodes; - for (DataflowOutput const &v : values) { - nodes.insert(v.node); - } - - std::vector> node_enumeration = - allowed_machine_mappings(context, nodes, resource); - std::vector> enumeration; - - for (std::unordered_map _node_enumeration : - node_enumeration) { - std::unordered_map _emumeration; - for (DataflowOutput const &v : values) { - _emumeration.emplace(v, _node_enumeration.at(v.node)); - } - enumeration.push_back(_emumeration); - } - - return enumeration; -} - -} diff --git a/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc b/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc index d2cdb41e46..3c80d75289 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc @@ -1,57 +1,72 @@ #include "compiler/machine_mapping/get_allowed_machine_views_list.h" -#include "utils/containers/get_first.h" -#include "utils/containers/set_minus.h" +#include 
"pcg/parallel_computation_graph/parallel_tensor_guid_t.h" #include "utils/containers.h" +#include "utils/containers/get_first.h" #include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/set_minus.h" namespace FlexFlow { std::vector> - get_allowed_machine_views_list(MachineMappingContext const &context, - std::unordered_set const &layers, - MachineSpecification const &resource) { + get_allowed_machine_views_list( + MachineMappingContext const &context, + std::unordered_set const &layers, + MachineSpecification const &resource) { if (layers.empty()) { return {{}}; } parallel_layer_guid_t curr_layer = get_first(layers); - std::unordered_set other_layers = set_minus(layers, {curr_layer}); + std::unordered_set other_layers = + set_minus(layers, {curr_layer}); - std::vector> other_machine_views_from_recursion = - allowed_machine_mappings(context, other_layers, resource); + std::vector> + other_machine_views_from_recursion = + get_allowed_machine_views_list(context, other_layers, resource); + ParallelLayerAttrs curr_layer_attrs = + get_parallel_layer_attrs(context.pcg, curr_layer); std::unordered_set allowed_machine_views_for_curr_layer = - context.allowed_machine_views(layer, resource); + context.allowed_machine_views(curr_layer_attrs, resource); std::vector> result; - for (MachineView const &for_curr_node : allowed_machine_views_for_curr_node) { - for (std::unordered_map const &for_other_nodes : - other_node_mappings_from_recursion) { - enumeration.push_back(merge_maps( - partial, std::unordered_map{{layer, mv}})); + for (MachineView const &for_curr_node : + allowed_machine_views_for_curr_layer) { + for (std::unordered_map const + &for_other_layers : other_machine_views_from_recursion) { + result.push_back( + merge_maps(for_other_layers, + std::unordered_map{ + {curr_layer, for_curr_node}})); } } return result; } std::vector> - get_allowed_src_machine_views_list(MachineMappingContext const &context, - std::unordered_set 
const &tensors, - MachineSpecification const &resource) { + get_allowed_src_machine_views_list( + MachineMappingContext const &context, + std::unordered_set const &tensors, + MachineSpecification const &resource) { std::unordered_set layers; for (parallel_tensor_guid_t const &tensor : tensors) { - layers.insert(get_source_layer(context.pcg, tensor)); + layers.insert(get_source_layer(tensor)); } - std::vector> machine_views_for_layers_list = - get_allowed_machine_views_list(context, layers, resource); + std::vector> + machine_views_for_layers_list = + get_allowed_machine_views_list(context, layers, resource); + std::vector> result; - for (std::unordered_map machine_views_for_layers : - machine_views_for_layers_list) { - std::unordered_map machine_views_for_tensors; + for (std::unordered_map + machine_views_for_layers : machine_views_for_layers_list) { + std::unordered_map + machine_views_for_tensors; for (parallel_tensor_guid_t const &tensor : tensors) { - machine_views_for_tensors.emplace(tensor, machine_views_for_layers.at(get_source_layer(context.pcg, v))); + machine_views_for_tensors.emplace( + tensor, machine_views_for_layers.at(get_source_layer(tensor))); } result.push_back(machine_views_for_tensors); } @@ -59,4 +74,4 @@ std::vector> return result; } -} +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index d48be97bc8..f99f7f94bd 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,8 +1,8 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "compiler/cost_estimator.h" +#include "compiler/machine_mapping/get_allowed_machine_views_list.h" #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/split_sp_decomposition.h" -#include 
"compiler/machine_mapping/allowed_machine_mappings.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.dtg.h" @@ -11,18 +11,21 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "utils/containers.h" #include "utils/containers/as_vector.h" +#include "utils/containers/contains.h" #include "utils/containers/contains_key.h" +#include "utils/containers/filter.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" #include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" #include "utils/exception.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" @@ -50,14 +53,46 @@ GraphSplit get_nodes(post_decomposition)}; } -float base_case_estimate_cost( - SubParallelComputationGraph const &g, +float singleton_subgraph_cost( + ParallelComputationGraph const &pcg, CostEstimator const &estimator, - std::unordered_map const &machine_views) { - // In the base case, all the operators are executed sequentially. - float cost = 0.1; - // TODO(@wmdi) - return cost; + parallel_layer_guid_t const &layer, + std::unordered_map const + &machine_views) { + // TODO: Replace it with the actual implementation. 
+ auto get_input_shapes = [&](parallel_layer_guid_t) { + return std::vector{}; + }; + auto get_weight_attrs = [&](parallel_layer_guid_t) { + return std::vector{}; + }; + auto get_output_attrss = [&](parallel_layer_guid_t) { + return std::vector{}; + }; + + assert(contains_key(machine_views, get_layer_outputs(pcg, layer)[0])); + MachineView layer_machine_view = + machine_views.at(get_layer_outputs(pcg, layer)[0]); + float computation_cost = + estimator.estimate_cost(get_parallel_layer_attrs(pcg, layer).op_attrs, + get_input_shapes(layer), + get_weight_attrs(layer), + get_output_attrss(layer), + layer_machine_view); + float communication_cost = 0; + for (parallel_tensor_guid_t const &input : get_layer_inputs(pcg, layer)) { + assert(contains_key(machine_views, input)); + communication_cost = std::max( + communication_cost, + estimator.estimate_cost(get_parallel_tensor_attrs(pcg, input).shape, + machine_views.at(input), + layer_machine_view)); + } + std::cerr << "layer inputs: " << get_layer_inputs(pcg, layer).size() + << std::endl; + std::cerr << "computation_cost: " << computation_cost + << " communication_cost: " << communication_cost << std::endl; + return computation_cost + communication_cost; } MachineMappingResult get_optimal_machine_mapping( @@ -95,7 +130,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialParallelDecomposition const &decompn, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views) { MachineMappingState state(decompn, resource, fixed_machine_views); @@ -127,7 +162,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, SerialSplit const &serial, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views) { MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); @@ -135,42 +170,54 @@ MachineMappingResult 
get_optimal_machine_mapping_internal( GraphSplit graph_split = get_graph_split(decompn1, decompn2); - OpenDataflowSubgraphResult subgraph_res1 = get_subgraph( - sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.first); - OpenDataflowSubgraphResult subgraph_res2 = get_subgraph( - sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.second); - - std::unordered_set split_outputs = transform( - keys(subgraph_res2.full_graph_values_to_subgraph_inputs), - [](OpenDataflowValue const &v) { return v.get(); }); - - for (std::unordered_map const - &split_machine_views : - allowed_machine_mappings(context, split_outputs, resource)) { - std::unordered_map fixed_machine_views1 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res1.graph)); - std::unordered_map fixed_machine_views2 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res2.graph)); - - for (auto const &[full_graph_value, subgraph_input] : - subgraph_res2.full_graph_values_to_subgraph_inputs) { - MachineView mv = - split_machine_views.at(full_graph_value.get()); - fixed_machine_views1.emplace(full_graph_value, mv); - fixed_machine_views2.emplace(OpenDataflowValue(subgraph_input), mv); - } - + auto is_subgraph_input = [&](std::unordered_set const &subgraph_nodes, + parallel_tensor_guid_t const &input_tensor) { + return !contains(subgraph_nodes, input_tensor.raw_graph_output.node); + }; + + std::unordered_set all_edges1 = + set_union(transform(graph_split.first, [&](Node const &node) { + return unordered_set_of( + get_layer_outputs(context.pcg, parallel_layer_guid_t(node))); + })); + std::unordered_set all_edges2 = + set_union(transform(graph_split.second, [&](Node const &node) { + return unordered_set_of( + get_layer_inputs(context.pcg, parallel_layer_guid_t(node))); + })); + std::unordered_set split_edges = + filter(all_edges2, [&](parallel_tensor_guid_t const &input_tensor) { + return is_subgraph_input(graph_split.second, input_tensor); + }); + + 
std::unordered_map fixed_machine_views1 = + restrict_keys(fixed_machine_views, all_edges1); + std::unordered_map fixed_machine_views2 = + restrict_keys(fixed_machine_views, all_edges2); + std::vector> + machine_views_list_for_split_edges = + get_allowed_src_machine_views_list(context, split_edges, resource); + + for (std::unordered_map const + &machine_views_for_split_edge : machine_views_list_for_split_edges) { minimize_runtime( optimal_result, sequential_combine( get_optimal_machine_mapping_internal( - context, decompn1, resource, fixed_machine_views1), + context, + decompn1, + resource, + merge_maps(fixed_machine_views1, machine_views_for_split_edge)), get_optimal_machine_mapping_internal( - context, decompn2, resource, fixed_machine_views2))); + context, + decompn2, + resource, + merge_maps(fixed_machine_views2, + machine_views_for_split_edge)))); } + std::cerr << "serial result: " << optimal_result.runtime << std::endl; + return optimal_result; } @@ -178,23 +225,26 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, ParallelSplit const &parallel, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views) { auto [decompn1, decompn2] = split_sp_decomposition(parallel); GraphSplit graph_split = get_graph_split(decompn1, decompn2); - OpenDataflowSubgraphResult subgraph_res1 = get_subgraph( - sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.first); - OpenDataflowSubgraphResult subgraph_res2 = get_subgraph( - sub_pcg_from_full_pcg(context.pcg).raw_graph, graph_split.second); - - std::unordered_map fixed_machine_views1 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res1.graph)); - std::unordered_map fixed_machine_views2 = - restrict_keys(fixed_machine_views, - get_open_dataflow_values(subgraph_res2.graph)); + std::unordered_set all_edges1 = + set_union(transform(graph_split.first, [&](Node const &node) { + return unordered_set_of( + 
get_layer_outputs(context.pcg, parallel_layer_guid_t(node))); + })); + std::unordered_set all_edges2 = + set_union(transform(graph_split.second, [&](Node const &node) { + return unordered_set_of( + get_layer_inputs(context.pcg, parallel_layer_guid_t(node))); + })); + std::unordered_map fixed_machine_views1 = + restrict_keys(fixed_machine_views, all_edges1); + std::unordered_map fixed_machine_views2 = + restrict_keys(fixed_machine_views, all_edges2); MachineMappingResult optimal_result = sequential_combine( get_optimal_machine_mapping_internal( @@ -214,6 +264,8 @@ MachineMappingResult get_optimal_machine_mapping_internal( fixed_machine_views2))); } + std::cerr << "parallel result: " << optimal_result.runtime << std::endl; + return optimal_result; } @@ -221,51 +273,39 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingContext &context, Node const &node, MachineSpecification const &resource, - std::unordered_map const + std::unordered_map const &fixed_machine_views) { - SubParallelComputationGraph subgraph = get_pcg_subgraph(context.pcg, {node}); - - OpenDataflowValue any_output = - OpenDataflowValue(get_outputs(context.pcg.raw_graph, node)[0]); - if (contains_key(fixed_machine_views, any_output)) { - { - std::unordered_set allowed_machine_views_for_node = context.allowed_machine_views( - context.pcg.raw_graph.at(node), resource); - MachineView fixed_machine_view_for_node = fixed_machine_views.at(any_output); - assert(contains(allowed_machine_views_for_node, fixed_machine_view_for_node)); - } - MachineView mv = fixed_machine_views.at(any_output); - MachineMapping mv_map{{{node, mv}}}; - return MachineMappingResult(base_case_estimate_cost(subgraph, - context.cost_estimator, - fixed_machine_views), - mv_map); - } else { - MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); - for (std::unordered_map node_machine_views : - allowed_machine_mappings(context, {node}, resource)) { - MachineView mv = 
node_machine_views.at(node); - MachineMapping mv_map{{{node, mv}}}; - - std::vector outputs_of_node = transform( - get_outputs(context.pcg.raw_graph, node), - [](DataflowOutput const &o) { return OpenDataflowValue(o); }); - - std::unordered_map output_mv_map = - generate_map(outputs_of_node, - [&](OpenDataflowValue const &o) { return mv; }); - - std::unordered_map machine_views = - merge_maps(fixed_machine_views, output_mv_map); - minimize_runtime(optimal_result, - MachineMappingResult( - base_case_estimate_cost( - subgraph, context.cost_estimator, machine_views), - mv_map)); - } - return optimal_result; + parallel_layer_guid_t layer = parallel_layer_guid_t(node); + std::unordered_set machine_views_not_fixed = + set_minus(unordered_set_of(get_layer_outputs(context.pcg, layer)), + keys(fixed_machine_views)); + + std::vector> + machine_views_list_for_not_fixed = get_allowed_src_machine_views_list( + context, machine_views_not_fixed, resource); + + MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + + for (std::unordered_map const + &machine_views_for_not_fixed : machine_views_list_for_not_fixed) { + std::unordered_map full_machine_views = + merge_maps(fixed_machine_views, machine_views_for_not_fixed); + float runtime = singleton_subgraph_cost( + context.pcg, context.cost_estimator, layer, full_machine_views); + MachineMapping machine_mapping = + MachineMapping{std::unordered_map{ + {layer, + full_machine_views.at(get_layer_outputs(context.pcg, layer)[0])}, + }}; + MachineMappingResult curr_result = + MachineMappingResult(runtime, machine_mapping); + minimize_runtime(optimal_result, curr_result); } + + std::cerr << "node result: " << optimal_result.runtime << std::endl; + + return optimal_result; } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 1b02485e1f..6f350d8773 100644 --- 
a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -2,6 +2,7 @@ #include "utils/containers.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 50eb2d8b53..721a3bd27f 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -20,7 +20,7 @@ MachineMappingResult parallel_combine(MachineMappingResult const &s1, MachineMappingResult get_infinity_machine_mapping_result() { return MachineMappingResult( std::numeric_limits::infinity(), - MachineMapping(std::unordered_map{})); + MachineMapping(std::unordered_map{})); } void minimize_runtime(MachineMappingResult &m1, diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc index 063d8e3ee4..71dd7f0ec1 100644 --- a/lib/compiler/src/graph_optimize_state.cc +++ b/lib/compiler/src/graph_optimize_state.cc @@ -1,4 +1,5 @@ #include "compiler/graph_optimize_state.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" namespace FlexFlow { @@ -71,8 +72,7 @@ size_t hash<::FlexFlow::GraphOptimizeState>::operator()( ::FlexFlow::hash_combine(seed, inputs.size()); for (auto input : inputs) { for (size_t i = 0; i < layers.size(); ++i) { - if (get_source_layer(state.graph_optimize_result.pcg, input) == - layers[i]) { + if (get_source_layer(input) == layers[i]) { ::FlexFlow::hash_combine(seed, i); break; } diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc similarity index 76% rename from 
lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc rename to lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 42b9edab05..e77b6f76cf 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -9,24 +9,23 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping") { auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, MachineSpecification const &) { - // TODO(@Mengdi Wu): Replace it with actual allowed machine views when - // https://github.com/flexflow/FlexFlow/pull/1458 is merged return std::unordered_set{ make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; CostEstimator estimator1 = CostEstimator::create(); - MachineSpecification machine_spec1(1, 1, 1, 1, 1); + MachineSpecification machine_spec1(2, 1, 1, 1, 1); MachineMappingCache cached_results1; SUBCASE("simple PCG") { - + ParallelComputationGraph pcg_simple = [&] { ParallelComputationGraphBuilder builder; - ParallelTensorShape input_shape = + ParallelTensorShape input_shape0 = ParallelTensorShape{ParallelTensorDims{ FFOrdered{ ShardParallelDim{32, 2}, + ShardParallelDim{32, 1}, ShardParallelDim{16, 1}, }, ReplicaParallelDimSet{ @@ -36,8 +35,25 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT}; - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + ParallelTensorShape input_shape1 = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{8, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape0); + parallel_tensor_guid_t input1 = + builder.create_input_tensor(input_shape1); + parallel_tensor_guid_t dense0 = 
builder.batch_matmul(input0, input1); return builder.pcg; }(); @@ -49,9 +65,7 @@ TEST_SUITE(FF_TEST_SUITE) { machine_spec1, cached_results1); - CHECK(bool(result.runtime > 0)); - // TODO(@Mengdi Wu): fill it with actual cost - // CHECK(result.runtime == xx); + CHECK(result.runtime == 3); } SUBCASE("PCG is a chain") { @@ -71,13 +85,14 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT}; - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, 8); - parallel_tensor_guid_t dense1 = builder.dense(dense0, 8); - parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); - parallel_tensor_guid_t dense3 = builder.dense(dense2, 8); - parallel_tensor_guid_t dense4 = builder.dense(dense3, 8); - parallel_tensor_guid_t dense5 = builder.dense(dense4, 8); + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape); + parallel_tensor_guid_t layer1 = builder.identity(input0); + parallel_tensor_guid_t layer2 = builder.identity(layer1); + parallel_tensor_guid_t layer3 = builder.identity(layer2); + parallel_tensor_guid_t layer4 = builder.identity(layer3); + parallel_tensor_guid_t layer5 = builder.identity(layer4); + parallel_tensor_guid_t layer6 = builder.identity(layer5); return builder.pcg; }(); @@ -88,8 +103,7 @@ TEST_SUITE(FF_TEST_SUITE) { estimator1, machine_spec1, cached_results1); - CHECK(bool(result.runtime > 0)); - // CHECK(result.runtime == xx); + CHECK(result.runtime == 13); } SUBCASE("PCG has multiple chains") { @@ -124,8 +138,10 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT}; - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape0); - parallel_tensor_guid_t input1 = builder.create_input_tensor(input_shape1); + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape0); + parallel_tensor_guid_t input1 = + builder.create_input_tensor(input_shape1); parallel_tensor_guid_t relu0 = builder.relu(input0); parallel_tensor_guid_t relu1 = 
builder.relu(input1); parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); @@ -139,8 +155,7 @@ TEST_SUITE(FF_TEST_SUITE) { estimator1, machine_spec1, cached_results1); - CHECK(bool(result.runtime > 0)); - // CHECK(result.runtime == xx); + CHECK(result.runtime == 5); } SUBCASE("PCG is not sp-izable due to multiple inputs") { @@ -160,7 +175,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT}; - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape); + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape); parallel_tensor_guid_t dense0 = builder.dense(input0, 8); parallel_tensor_guid_t dense1 = builder.dense(input0, 4); parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); @@ -170,7 +186,8 @@ TEST_SUITE(FF_TEST_SUITE) { }(); // TODO: Handle this case in compiler - // TODO: separate testcases for this too that actually check the graph manipulation + // TODO: separate testcases for this too that actually check the graph + // manipulation if (false) { MachineMappingResult result = get_optimal_machine_mapping(pcg_non_sp, @@ -179,7 +196,7 @@ TEST_SUITE(FF_TEST_SUITE) { machine_spec1, cached_results1); CHECK(bool(result.runtime > 0)); - // CHECK(result.runtime == xx); + CHECK(result.runtime == 7); } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc similarity index 76% rename from lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc rename to lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc index 4adc6ee558..ffd20c429a 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -10,14 +10,14 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); 
MachineMapping machine_mapping_0 = MachineMapping({ - {Node(0), machine_view_0}, + {parallel_layer_guid_t(Node(0)), machine_view_0}, }); MachineMapping machine_mapping_1 = MachineMapping({ - {Node(1), machine_view_1}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, }); MachineMapping correct = MachineMapping({ - {Node(0), machine_view_0}, - {Node(1), machine_view_1}, + {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, }); MachineMapping result = combine_disjoint_mappings(machine_mapping_0, machine_mapping_1); @@ -28,12 +28,12 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); MachineMapping machine_mapping_0 = MachineMapping({ - {Node(0), machine_view_0}, + {parallel_layer_guid_t(Node(0)), machine_view_0}, }); SUBCASE("nodes are disjoint") { MachineMapping machine_mapping_1 = MachineMapping({ - {Node(1), machine_view_1}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, }); bool correct = true; @@ -43,8 +43,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nodes are not disjoint") { MachineMapping machine_mapping_1 = MachineMapping({ - {Node(0), machine_view_0}, - {Node(1), machine_view_1}, + {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, }); bool correct = false; bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc similarity index 88% rename from lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc rename to lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc index cff2b1de50..9c09efe3fa 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_cache.cc +++ 
b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -60,23 +60,29 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingState state2(subgraph2, machine_spec, {}); MachineMappingResult result0( - 2, MachineMapping(std::unordered_map{})); + 2, + MachineMapping( + std::unordered_map{})); MachineMappingResult result1( - 1, MachineMapping(std::unordered_map{})); + 1, + MachineMapping( + std::unordered_map{})); MachineMappingResult result2( - 1, MachineMapping(std::unordered_map{})); + 1, + MachineMapping( + std::unordered_map{})); MachineMappingCache cache; cache.save(state0, result0); CHECK(cache.load(state0).value() == result0); - CHECK(cache.load(state1) == std::nullopt); - CHECK(cache.load(state2) == std::nullopt); + CHECK(!cache.load(state1)); + CHECK(!cache.load(state2)); cache.save(state1, result1); CHECK(cache.load(state0).value() == result0); CHECK(cache.load(state1).value() == result1); - CHECK(cache.load(state2) == std::nullopt); + CHECK(!cache.load(state2)); cache.save(state2, result2); CHECK(cache.load(state0).value() == result0); diff --git a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc similarity index 70% rename from lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc rename to lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index 4ce650bf0a..0157f73ef3 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/test_machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -8,12 +8,14 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("sequential_combine") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); + parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); 
MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); - MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + std::unordered_map{}); + MachineMapping machine_mapping_0({{layer0, machine_view_0}}); + MachineMapping machine_mapping_1({{layer1, machine_view_1}}); MachineMapping combined( - {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + {{layer0, machine_view_0}, {layer1, machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); @@ -34,12 +36,14 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("parallel_combine") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); + parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); - MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + std::unordered_map{}); + MachineMapping machine_mapping_0({{layer0, machine_view_0}}); + MachineMapping machine_mapping_1({{layer1, machine_view_1}}); MachineMapping combined( - {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + {{layer0, machine_view_0}, {layer1, machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); @@ -60,12 +64,14 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("minimize_runtime") { MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); + parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); 
MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{Node(0), machine_view_0}}); - MachineMapping machine_mapping_1({{Node(1), machine_view_1}}); + std::unordered_map{}); + MachineMapping machine_mapping_0({{layer0, machine_view_0}}); + MachineMapping machine_mapping_1({{layer1, machine_view_1}}); MachineMapping combined( - {{Node(0), machine_view_0}, {Node(1), machine_view_1}}); + {{layer0, machine_view_0}, {layer1, machine_view_1}}); MachineMappingResult s0(0, machine_mapping_empty); MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); diff --git a/lib/compiler/test/src/test_graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc similarity index 87% rename from lib/compiler/test/src/test_graph_optimize_state.cc rename to lib/compiler/test/src/graph_optimize_state.cc index 49c4f9958f..fa8385e560 100644 --- a/lib/compiler/test/src/test_graph_optimize_state.cc +++ b/lib/compiler/test/src/graph_optimize_state.cc @@ -46,12 +46,14 @@ TEST_SUITE(FF_TEST_SUITE) { // `machine_mapping` is determined by the PCG and the device mapping // algorithm, and `runtime` is determined by the PCG and the device mapping, // so their values here do not matter. 
- std::unordered_map empty_machine_views; + std::unordered_map empty_machine_views; MachineMapping empty_machine_mapping(empty_machine_views); - CHECK( + bool result1 = GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0) == - GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0)); + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0); + bool correct1 = true; + CHECK(result1 == correct1); ParallelComputationGraphBuilder builder_; @@ -68,9 +70,11 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg_ = builder.pcg; - CHECK(GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), - 0) != - GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), - 0)); + bool result2 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), 0); + bool correct2 = false; + CHECK(result2 == correct2); } } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc similarity index 100% rename from lib/compiler/test/src/test_unity_algorithm.cc rename to lib/compiler/test/src/unity_algorithm.cc diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index ef3fddef63..00032045c0 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -27,9 +27,6 @@ SubParallelComputationGraph ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs(SubParallelComputationGraph const &); -SubParallelComputationGraph get_pcg_subgraph(ParallelComputationGraph const &, - std::unordered_set const &); - parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name); diff --git 
a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index c6e75f1841..0bbe0e97a7 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -62,26 +62,6 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( // }; } -SubParallelComputationGraph - get_pcg_subgraph(ParallelComputationGraph const &pcg, - std::unordered_set const &nodes) { - auto as_open = view_as_labelled_open_dataflow_graph(pcg.raw_graph); - OpenDataflowSubgraphResult subgraph_result = get_subgraph(as_open, nodes); - return SubParallelComputationGraph{with_labelling( - subgraph_result.graph, - generate_map(nodes, [&](Node const &node) { return as_open.at(node); }), - generate_map(get_open_dataflow_values(subgraph_result.graph), - [&](OpenDataflowValue const &value) { - if (value.has()) { - return as_open.at( - subgraph_result.full_graph_values_to_subgraph_inputs - .at_r(value.get())); - } else { - return as_open.at(value); - } - }))}; -} - parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name) { From 89ed108f4dad59f9fb9e45b25adb64e4de855280 Mon Sep 17 00:00:00 2001 From: Mengdi Wu Date: Wed, 11 Sep 2024 20:44:17 -0400 Subject: [PATCH 12/29] remove debug codes --- .../machine_mapping/get_optimal_machine_mapping.cc | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index f99f7f94bd..4894e5f5de 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -88,10 +88,6 @@ float singleton_subgraph_cost( machine_views.at(input), layer_machine_view)); } - std::cerr << "layer inputs: " << 
get_layer_inputs(pcg, layer).size() - << std::endl; - std::cerr << "computation_cost: " << computation_cost - << " communication_cost: " << communication_cost << std::endl; return computation_cost + communication_cost; } @@ -216,8 +212,6 @@ MachineMappingResult get_optimal_machine_mapping_internal( machine_views_for_split_edge)))); } - std::cerr << "serial result: " << optimal_result.runtime << std::endl; - return optimal_result; } @@ -264,8 +258,6 @@ MachineMappingResult get_optimal_machine_mapping_internal( fixed_machine_views2))); } - std::cerr << "parallel result: " << optimal_result.runtime << std::endl; - return optimal_result; } @@ -303,8 +295,6 @@ MachineMappingResult get_optimal_machine_mapping_internal( minimize_runtime(optimal_result, curr_result); } - std::cerr << "node result: " << optimal_result.runtime << std::endl; - return optimal_result; } From 3a3595163736d85f3ccddb934a97cccc04415ae9 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 25 Sep 2024 21:55:58 -0700 Subject: [PATCH 13/29] A lot of simplifying and modularizing of unity dp code --- .proj.toml | 8 +- .../comm_cost_estimate_key.struct.toml | 25 + .../include/compiler/cost_estimator.h | 21 +- .../estimate_cost_across_split.h | 18 + .../machine_mapping/estimate_layer_cost.h | 16 + .../get_allowed_machine_views_list.h | 2 + .../get_machine_resource_splits.h | 15 + .../get_optimal_machine_mapping.h | 46 +- .../machine_mapping.struct.toml | 2 +- .../machine_mapping/machine_mapping_context.h | 23 + .../machine_mapping_context.struct.toml | 17 +- .../machine_mapping/machine_mapping_result.h | 1 + .../machine_mapping_result.struct.toml | 1 + .../machine_mapping_state.struct.toml | 16 +- .../machine_mapping/partial_machine_mapping.h | 25 + .../partial_machine_mapping.struct.toml | 21 + .../machine_mapping/split_sp_decomposition.h | 16 - .../machine_mapping/transitive_reduced_pcg.h | 36 ++ .../transitive_reduced_pcg.struct.toml | 16 + .../compiler/op_cost_estimate_key.struct.toml | 40 ++ 
...get_pcg_balanced_binary_sp_decomposition.h | 12 + .../get_pcg_series_parallel_decomposition.h | 15 + .../pcg_binary_parallel_split.h | 14 + .../pcg_binary_parallel_split.struct.toml | 22 + .../series_parallel/pcg_binary_series_split.h | 14 + .../pcg_binary_series_split.struct.toml | 22 + .../pcg_binary_sp_decomposition.h | 45 ++ .../pcg_binary_sp_decomposition.struct.toml | 22 + .../estimate_cost_across_split.cc | 37 ++ .../machine_mapping/estimate_layer_cost.cc | 27 ++ .../get_machine_resource_splits.cc | 18 + .../get_optimal_machine_mapping.cc | 411 +++++++--------- .../machine_mapping_context.cc | 31 ++ .../machine_mapping/machine_mapping_result.cc | 3 +- .../partial_machine_mapping.cc | 30 ++ .../machine_mapping/split_sp_decomposition.cc | 36 -- .../machine_mapping/transitive_reduced_pcg.cc | 84 ++++ .../get_pcg_series_parallel_decomposition.cc | 11 + .../pcg_binary_parallel_split.cc | 19 + .../pcg_binary_series_split.cc | 19 + .../pcg_binary_sp_decomposition.cc | 15 + .../cost_estimator_for_test.cc | 78 +++ .../machine_mapping/cost_estimator_for_test.h | 48 +- .../estimate_cost_across_split.cc | 36 ++ .../machine_mapping/estimate_layer_cost.cc | 124 +++++ .../get_machine_resource_splits.cc | 213 +++++++++ .../get_optimal_machine_mapping.cc | 451 +++++++++++------- .../machine_mapping/transitive_reduced_pcg.cc | 26 + lib/kernels/include/kernels/linear_kernels.h | 2 +- lib/local-execution/src/ops/batch_matmul.h | 2 +- lib/local-execution/src/ops/batch_norm.h | 2 +- lib/local-execution/src/ops/cast.h | 2 +- lib/local-execution/src/ops/combine.h | 2 +- lib/local-execution/src/ops/concat.h | 2 +- lib/local-execution/src/ops/conv_2d.h | 2 +- lib/local-execution/src/ops/dropout.h | 2 +- lib/local-execution/src/ops/element_binary.h | 2 +- lib/local-execution/src/ops/element_unary.h | 2 +- lib/local-execution/src/ops/embedding.h | 2 +- lib/local-execution/src/ops/flat.h | 2 +- lib/local-execution/src/ops/gather.h | 2 +- lib/local-execution/src/ops/input.h | 2 +- 
lib/local-execution/src/ops/layer_norm.h | 2 +- lib/local-execution/src/ops/linear.h | 2 +- lib/local-execution/src/ops/noop.h | 4 +- lib/local-execution/src/ops/pool_2d.h | 2 +- lib/local-execution/src/ops/reduce.h | 2 +- lib/local-execution/src/ops/reduction.h | 2 +- lib/local-execution/src/ops/repartition.h | 2 +- lib/local-execution/src/ops/replicate.h | 2 +- lib/local-execution/src/ops/reshape.h | 2 +- lib/local-execution/src/ops/reverse.h | 2 +- lib/local-execution/src/ops/softmax.h | 2 +- lib/local-execution/src/ops/split.h | 2 +- lib/local-execution/src/ops/topk.h | 2 +- lib/local-execution/src/ops/transpose.h | 2 +- .../include/op-attrs/operator_attrs.h | 50 -- lib/op-attrs/src/operator_attrs.cc | 287 ----------- .../parallel_computation_graph.h | 7 + .../parallel_layer_attrs.struct.toml | 2 +- .../parallel_computation_graph.cc | 17 + .../parallel_computation_graph_builder.cc | 19 +- .../parallel_computation_graph.cc | 3 + .../parallel_computation_graph_builder.cc | 1 + lib/runtime/test/src/test_serialization.cc | 1 - .../perform_shape_inference.h | 1 + .../perform_shape_inference.cc | 1 + lib/utils/include/utils/containers.decl.h | 3 - lib/utils/include/utils/containers.h | 28 -- .../utils/containers/cartesian_product.h | 39 ++ .../utils/containers/get_all_assignments.h | 43 ++ .../include/utils/containers/transform.h | 26 +- .../containers/unordered_map_from_pairs.h | 16 + .../get_dataflow_edges_from_node_to_node.h | 14 + .../get_edges_from_subgraph_to_subgraph.h | 13 + .../src/utils/containers/cartesian_product.cc | 1 + .../utils/containers/get_all_assignments.cc | 1 + .../containers/unordered_map_from_pairs.cc | 1 + .../get_dataflow_edges_from_node_to_node.cc | 16 + .../get_edges_from_subgraph_to_subgraph.cc | 20 + .../src/utils/containers/cartesian_product.cc | 62 +++ .../utils/containers/get_all_assignments.cc | 50 ++ .../containers/unordered_map_from_pairs.cc | 53 ++ .../get_dataflow_edges_from_node_to_node.cc | 97 ++++ 
.../get_edges_from_subgraph_to_subgraph.cc | 131 +++++ 105 files changed, 2370 insertions(+), 939 deletions(-) create mode 100644 lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h create mode 100644 lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h create mode 100644 lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h create mode 100644 lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml create mode 100644 lib/compiler/include/compiler/op_cost_estimate_key.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml create mode 100644 
lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc create mode 100644 lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc delete mode 100644 lib/op-attrs/include/op-attrs/operator_attrs.h delete mode 100644 lib/op-attrs/src/operator_attrs.cc create mode 100644 lib/utils/include/utils/containers/cartesian_product.h create mode 100644 lib/utils/include/utils/containers/get_all_assignments.h create mode 100644 lib/utils/include/utils/containers/unordered_map_from_pairs.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h create mode 100644 
lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h create mode 100644 lib/utils/src/utils/containers/cartesian_product.cc create mode 100644 lib/utils/src/utils/containers/get_all_assignments.cc create mode 100644 lib/utils/src/utils/containers/unordered_map_from_pairs.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc create mode 100644 lib/utils/test/src/utils/containers/cartesian_product.cc create mode 100644 lib/utils/test/src/utils/containers/get_all_assignments.cc create mode 100644 lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc diff --git a/.proj.toml b/.proj.toml index 5592f184ad..22649424f8 100644 --- a/.proj.toml +++ b/.proj.toml @@ -22,11 +22,11 @@ test_targets = [ "utils-tests", "op-attrs-tests", "pcg-tests", - "substitutions-tests", + # "substitutions-tests", "compiler-tests", - "substitution-generator-tests", - "local-execution-tests", - "models-tests", + # "substitution-generator-tests", + # "local-execution-tests", + # "models-tests", ] [cmake_flags_extra] diff --git a/lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..ae45b493c3 --- /dev/null +++ b/lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "CommCostEstimateKey" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = 
"::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_view" +type = "::FlexFlow::MachineView" + +[[fields]] +name = "dst_machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator.h index 2e4ff8448b..52e82ad8d5 100644 --- a/lib/compiler/include/compiler/cost_estimator.h +++ b/lib/compiler/include/compiler/cost_estimator.h @@ -1,19 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_H -#ifndef _FLEXFLOW_COMPILER_COST_ESTIMATE_H -#define _FLEXFLOW_COMPILER_COST_ESTIMATE_H - -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/machine_view.dtg.h" namespace FlexFlow { struct ICostEstimator { virtual float estimate_cost(PCGOperatorAttrs const &op, std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, + std::vector const &weights, + std::vector const &outputs, MachineView const &mv) const = 0; virtual float estimate_cost(ParallelTensorShape const &tensor_shape, MachineView const &src, @@ -30,8 +29,8 @@ CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); struct CostEstimator { float estimate_cost(PCGOperatorAttrs const &op, std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, + std::vector const &weights, + std::vector const &outputs, MachineView const &mv) const { return this->implementation_ptr->estimate_cost( op, inputs, weights, outputs, mv); diff --git a/lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h b/lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h new file mode 100644 index 0000000000..a1f061b15b --- 
/dev/null +++ b/lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H + +#include "compiler/cost_estimator.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +float estimate_cost_across_split(TransitiveReducedPCG const &, + CostEstimator const &, + std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h new file mode 100644 index 0000000000..69370aabda --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_LAYER_COST_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_LAYER_COST_H + +#include "compiler/cost_estimator.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +namespace FlexFlow { + +float estimate_layer_cost(ParallelComputationGraph const &pcg, + CostEstimator const &cost_estimator, + parallel_layer_guid_t const &layer, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h b/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h index 8681852cbe..1da08daf1b 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h +++ 
b/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h @@ -3,6 +3,8 @@ #include "compiler/machine_mapping/machine_mapping_context.dtg.h" #include "pcg/machine_specification.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h new file mode 100644 index 0000000000..2800c0a353 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H + +#include "pcg/machine_specification.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 102f4204fd..7b4ba275a2 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -4,6 +4,9 @@ #include "compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" +#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include 
"pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -21,36 +24,37 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &cached_subgraph_results); MachineMappingResult - get_optimal_machine_mapping_internal(MachineMappingContext &context, + get_optimal_machine_mapping_internal(MachineMappingCache &result_cache, + MachineMappingContext const &context, MachineSpecification const &resources); MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, - SerialParallelDecomposition const &decompn, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views); + MachineMappingCache &result_cache, + MachineMappingContext const &context, + PCGBinarySPDecomposition const &sp_decomposition, + MachineSpecification const &resources, + PartialMachineMapping const &); MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, - SerialSplit const &serial, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views); + MachineMappingCache &result_cache, + MachineMappingContext const &context, + PCGBinarySeriesSplit const &series, + MachineSpecification const &resources, + PartialMachineMapping const &); MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, - ParallelSplit const ¶llel, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views); + MachineMappingCache &result_cache, + MachineMappingContext const &context, + PCGBinaryParallelSplit const ¶llel, + MachineSpecification const &resources, + PartialMachineMapping const &); MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, - Node const &node, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views); + MachineMappingCache &result_cache, + MachineMappingContext const &, + parallel_layer_guid_t const &, + MachineSpecification 
const &, + PartialMachineMapping const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml index 8e3b94c891..92517c1110 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml @@ -21,4 +21,4 @@ src_includes = [ [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h new file mode 100644 index 0000000000..894f935015 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONTEXT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONTEXT_H + +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_allowed_machine_views_for_tensor(MachineMappingContext const &, + parallel_tensor_guid_t const &); +std::unordered_set get_allowed_machine_views_for_layer(MachineMappingContext const &, + parallel_layer_guid_t const &); + +MachineMappingContext make_machine_mapping_context(ParallelComputationGraph const &pcg, + CostEstimator const &cost_estimator, + std::function( + ParallelLayerAttrs const &, MachineSpecification const &)> const 
&allowed_machine_views); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index a5f0b8f1f2..270d57fe98 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -1,20 +1,17 @@ namespace = "FlexFlow" name = "MachineMappingContext" -features = [ -] +features = [] includes = [ - "compiler/machine_mapping/machine_mapping.dtg.h", - "pcg/parallel_computation_graph/parallel_computation_graph.h", "compiler/cost_estimator.h", - "pcg/machine_view.h", + "pcg/machine_view.dtg.h", "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/machine_mapping_cache.h" + "compiler/machine_mapping/transitive_reduced_pcg.dtg.h", ] [[fields]] -name = "pcg" -type = "::FlexFlow::ParallelComputationGraph" +name = "transitive_reduced_pcg" +type = "::FlexFlow::TransitiveReducedPCG" [[fields]] name = "cost_estimator" @@ -23,7 +20,3 @@ type = "::FlexFlow::CostEstimator" [[fields]] name = "allowed_machine_views" type = "std::function(::FlexFlow::ParallelLayerAttrs const &, ::FlexFlow::MachineSpecification const &)>" - -[[fields]] -name = "cached_subgraph_results" -type = "::FlexFlow::MachineMappingCache" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 62d6b7dfbb..621285ae16 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -6,6 +6,7 @@ namespace FlexFlow { MachineMappingResult sequential_combine(MachineMappingResult const &s1, + float comm_cost, MachineMappingResult const &s2); MachineMappingResult parallel_combine(MachineMappingResult const &s1, MachineMappingResult const 
&s2); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml index 9436b9bf47..f2f2e15e9a 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -2,6 +2,7 @@ namespace = "FlexFlow" name = "MachineMappingResult" features = [ "eq", + "hash", "fmt", ] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml index 9142c308be..0fcb065b10 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -7,25 +7,19 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", "pcg/machine_specification.dtg.h", - "pcg/machine_view.dtg.h", - "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", -] - -src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h", + "compiler/machine_mapping/partial_machine_mapping.dtg.h", ] [[fields]] name = "subgraph" -type = "::FlexFlow::SerialParallelDecomposition" +type = "::FlexFlow::PCGBinarySPDecomposition" [[fields]] name = "resource" type = "::FlexFlow::MachineSpecification" [[fields]] -name = "fixed_machine_views" -type = "std::unordered_map<::FlexFlow::parallel_tensor_guid_t, ::FlexFlow::MachineView>" +name = "partial_solution" +type = "::FlexFlow::PartialMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h new file mode 100644 index 0000000000..45be24974f --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" + +namespace FlexFlow { + +PartialMachineMapping get_unconstrained_solution(); + +PartialMachineMapping get_sub_solution(MachineMappingContext const &ctx, + PartialMachineMapping const &partial_solution, + PCGBinarySPDecomposition const &sub_problem); + +PartialMachineMapping with_additional_tensor_machine_views(MachineMappingContext const &ctx, + PartialMachineMapping const &partial_solution, + std::unordered_map const &additional); + +PartialMachineMapping with_additional_layer_machine_views(MachineMappingContext const &ctx, + PartialMachineMapping const &partial_solution, + std::unordered_map const &additional); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml new file mode 100644 index 0000000000..f63c51a4c4 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "PartialMachineMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "machine_views" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h b/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h 
deleted file mode 100644 index cab8d8d988..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/split_sp_decomposition.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_SPLIT_SP_DECOMPOSITION_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_SPLIT_SP_DECOMPOSITION_H - -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -std::pair - split_sp_decomposition(SerialSplit const &serial); - -std::pair - split_sp_decomposition(ParallelSplit const ¶llel); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h new file mode 100644 index 0000000000..fcd3b47204 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H + +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +TransitiveReducedPCG get_pcg_transitive_reduction(ParallelComputationGraph const &); + +std::unordered_set get_transitive_reduced_predecessors(TransitiveReducedPCG const &, + parallel_layer_guid_t const &); +std::unordered_set get_transitive_reduced_successors(TransitiveReducedPCG const &, + parallel_layer_guid_t const &); + +std::unordered_set + get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, + PCGBinarySeriesSplit const &); + +std::unordered_set + get_transitive_reduced_tensors_across_split(TransitiveReducedPCG 
const &, + PCGBinarySeriesSplit const &); + +std::pair< + std::unordered_set, + std::unordered_set +> get_split_transitive_reduced_boundary_layers(TransitiveReducedPCG const &, + PCGBinarySeriesSplit const &); + + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml new file mode 100644 index 0000000000..bb76ec2ff7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "TransitiveReducedPCG" +features = [] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h", +] + +[[fields]] +name = "full_pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/compiler/include/compiler/op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..8fd860d00d --- /dev/null +++ b/lib/compiler/include/compiler/op_cost_estimate_key.struct.toml @@ -0,0 +1,40 @@ +namespace = "FlexFlow" +name = "OpCostEstimateKey" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git 
a/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h new file mode 100644 index 0000000000..65a7a69ef8 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition( + ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h new file mode 100644 index 0000000000..04f84d76fd --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition( + ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h new file mode 100644 index 0000000000..0bbc81ded3 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h @@ -0,0 +1,14 @@ +#ifndef 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H + +#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" + +namespace FlexFlow { + +PCGBinarySPDecomposition get_left_child(PCGBinaryParallelSplit const &); +PCGBinarySPDecomposition get_right_child(PCGBinaryParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..75e1fec52f --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PCGBinaryParallelSplit" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h new file mode 100644 index 0000000000..196e0e502c --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H +#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H + +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" + +namespace FlexFlow { + +PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &); +PCGBinarySPDecomposition get_right_child(PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml new file mode 100644 index 0000000000..63fc7562cd --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PCGBinarySeriesSplit" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h new file mode 100644 index 0000000000..a75f1a8116 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H + +#include 
"compiler/series_parallel/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &); + +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); + +PCGBinarySeriesSplit make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); +PCGBinarySeriesSplit make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); +PCGBinarySeriesSplit make_pcg_leaf_node(parallel_layer_guid_t const &); + +PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &); +PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &); +parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &); + +template +ReturnType visit(PCGBinarySPDecomposition const &d, F &&f) { + SPDecompositionTreeNodeType node_type = get_node_type(d); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return f(require_series(d)); + case SPDecompositionTreeNodeType::PARALLEL: + return f(require_parallel(d)); + case SPDecompositionTreeNodeType::NODE: + return f(require_leaf(d)); + default: + throw mk_runtime_error(fmt::format("Unknown node type {}", node_type)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml new file mode 100644 index 0000000000..c9950bf3f4 --- /dev/null +++ 
b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PCGBinarySPDecomposition" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc b/lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc new file mode 100644 index 0000000000..8d7d6ccd03 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc @@ -0,0 +1,37 @@ +#include "compiler/machine_mapping/estimate_cost_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/keys.h" +#include "utils/containers/sum.h" + +namespace FlexFlow { + +float estimate_cost_across_split(TransitiveReducedPCG const &tr_pcg, + CostEstimator const &cost_estimator, + std::unordered_map const &pre_machine_views, + std::unordered_map const &post_machine_views) { + std::unordered_set + edges_across_split = get_transitive_reduced_edges_across_split(tr_pcg, + keys(pre_machine_views), + keys(post_machine_views)); + + auto get_cost_of_edge = [&](ParallelComputationGraphEdge const &e) { + 
MachineView src_view = pre_machine_views.at(get_src_layer(e)); + MachineView dst_view = post_machine_views.at(get_dst_layer(e)); + ParallelTensorShape tensor_shape = get_parallel_tensor_shape(tr_pcg.full_pcg, + get_parallel_tensor(e)); + + return cost_estimator.estimate_cost(tensor_shape, src_view, dst_view); + }; + + // note this is only correct for certain split types, and for others (tensor reuse, etc.) this is + // an overapproximation. This should eventually get fixed. + return sum(transform(edges_across_split, get_cost_of_edge)); +} + + +} // namespace FlexFlow + + diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc new file mode 100644 index 0000000000..02d31ec7f0 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -0,0 +1,27 @@ +#include "compiler/machine_mapping/estimate_layer_cost.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" + +namespace FlexFlow { + +float estimate_layer_cost(ParallelComputationGraph const &pcg, + CostEstimator const &cost_estimator, + parallel_layer_guid_t const &layer, + MachineView const &machine_view) { + PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, layer).op_attrs; + + auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { + return get_parallel_tensor_shape(pcg, t); + }; + + std::vector input_tensors = get_incoming_inputs(pcg, layer); + std::vector weight_tensors = get_incoming_weights(pcg, layer); + std::vector output_tensors = get_layer_outputs(pcg, layer); + + return cost_estimator.estimate_cost(op_attrs, + transform(input_tensors, get_tensor_shape), + transform(weight_tensors, get_tensor_shape), + transform(output_tensors, get_tensor_shape), + machine_view); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc new 
file mode 100644 index 0000000000..c77d53a928 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,18 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "utils/hash/pair.h" + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource) { + std::unordered_set> result; + for (int i = 1; i < resource.num_nodes; ++i) { + MachineSpecification sub_resource1 = resource, sub_resource2 = resource; + sub_resource1.num_nodes = i; + sub_resource2.num_nodes = resource.num_nodes - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + } + return result; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 4894e5f5de..e20e3a1883 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,95 +1,34 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "compiler/cost_estimator.h" #include "compiler/machine_mapping/get_allowed_machine_views_list.h" +#include "compiler/machine_mapping/get_machine_resource_splits.h" #include "compiler/machine_mapping/machine_mapping_result.h" -#include "compiler/machine_mapping/split_sp_decomposition.h" +#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" +#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.dtg.h" #include "pcg/machine_view.h" #include 
"pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.h" -#include "utils/containers.h" -#include "utils/containers/as_vector.h" #include "utils/containers/contains.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/filter.h" #include "utils/containers/generate_map.h" -#include "utils/containers/get_only.h" +#include "utils/containers/get_all_assignments.h" #include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" -#include "utils/containers/restrict_keys.h" -#include "utils/containers/set_minus.h" #include "utils/containers/unordered_set_of.h" -#include "utils/containers/values.h" #include "utils/exception.h" -#include "utils/graph/dataflow_graph/algorithms.h" -#include "utils/graph/graph_split.dtg.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" #include "utils/overload.h" +#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg_binary_series_split.h" +#include "compiler/machine_mapping/machine_mapping_context.h" +#include "utils/containers/flatmap.h" +#include "compiler/machine_mapping/estimate_layer_cost.h" -namespace FlexFlow { - -std::vector> - get_resource_split(MachineSpecification const &resource) { - std::vector> result; - for (int i = 1; i < resource.num_nodes; ++i) { - MachineSpecification sub_resource1 = resource, sub_resource2 = resource; - sub_resource1.num_nodes = i; - sub_resource2.num_nodes = resource.num_nodes - i; - result.push_back(std::make_pair(sub_resource1, sub_resource2)); - } - return result; -} - -GraphSplit - 
get_graph_split(SerialParallelDecomposition const &pre_decomposition, - SerialParallelDecomposition const &post_decomposition) { - return GraphSplit{get_nodes(pre_decomposition), - get_nodes(post_decomposition)}; -} -float singleton_subgraph_cost( - ParallelComputationGraph const &pcg, - CostEstimator const &estimator, - parallel_layer_guid_t const &layer, - std::unordered_map const - &machine_views) { - // TODO: Replace it with the actual implementation. - auto get_input_shapes = [&](parallel_layer_guid_t) { - return std::vector{}; - }; - auto get_weight_attrs = [&](parallel_layer_guid_t) { - return std::vector{}; - }; - auto get_output_attrss = [&](parallel_layer_guid_t) { - return std::vector{}; - }; - - assert(contains_key(machine_views, get_layer_outputs(pcg, layer)[0])); - MachineView layer_machine_view = - machine_views.at(get_layer_outputs(pcg, layer)[0]); - float computation_cost = - estimator.estimate_cost(get_parallel_layer_attrs(pcg, layer).op_attrs, - get_input_shapes(layer), - get_weight_attrs(layer), - get_output_attrss(layer), - layer_machine_view); - float communication_cost = 0; - for (parallel_tensor_guid_t const &input : get_layer_inputs(pcg, layer)) { - assert(contains_key(machine_views, input)); - communication_cost = std::max( - communication_cost, - estimator.estimate_cost(get_parallel_tensor_attrs(pcg, input).shape, - machine_views.at(input), - layer_machine_view)); - } - return computation_cost + communication_cost; -} +namespace FlexFlow { MachineMappingResult get_optimal_machine_mapping( ParallelComputationGraph const &pcg, @@ -98,164 +37,196 @@ MachineMappingResult get_optimal_machine_mapping( &allowed_machine_views, CostEstimator const &cost_estimator, MachineSpecification const &resources, - MachineMappingCache &cached_subgraph_results) { + MachineMappingCache &result_cache) { + + MachineMappingContext context = make_machine_mapping_context( + pcg, + cost_estimator, + allowed_machine_views); - MachineMappingContext context( - pcg, 
cost_estimator, allowed_machine_views, cached_subgraph_results); MachineMappingResult result = - get_optimal_machine_mapping_internal(context, resources); - cached_subgraph_results = context.cached_subgraph_results; + get_optimal_machine_mapping_internal(result_cache, context, resources); + return result; } MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, MachineSpecification const &resources) { - std::optional decompn_optional = - get_serial_parallel_decomposition(context.pcg.raw_graph); - - if (!decompn_optional) { - throw mk_runtime_error("Failed to get serial parallel decomposition"); - } - - SerialParallelDecomposition decompn = decompn_optional.value(); - - return get_optimal_machine_mapping_internal(context, decompn, resources, {}); + MachineMappingCache &result_cache, + MachineMappingContext &context, + MachineSpecification const &resources) { + + PCGBinarySPDecomposition sp_decomposition_tree = ({ + std::optional returned = get_pcg_balanced_binary_sp_decomposition(context.transitive_reduced_pcg.full_pcg); + if (!returned.has_value()) { + throw mk_runtime_error("Failed to get serial parallel decomposition"); + } + returned.value(); + }); + + return get_optimal_machine_mapping_internal(result_cache, + context, + sp_decomposition_tree, + resources, + get_unconstrained_solution()); } MachineMappingResult get_optimal_machine_mapping_internal( + MachineMappingCache &result_cache, MachineMappingContext &context, - SerialParallelDecomposition const &decompn, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - - MachineMappingState state(decompn, resource, fixed_machine_views); - std::optional cached_result = - context.cached_subgraph_results.load(state); - if (cached_result) { - return cached_result.value(); + PCGBinarySPDecomposition const &sp_decomposition_tree, + MachineSpecification const &resources, + PartialMachineMapping const &partial_solution) { + + MachineMappingState 
state = MachineMappingState{ + sp_decomposition_tree, resources, partial_solution, + }; + + { + std::optional cached_result = + result_cache.load(state); + if (cached_result) { + return cached_result.value(); + } } - MachineMappingResult result = decompn.visit( - overload{[&](SerialSplit const &serial) { - return get_optimal_machine_mapping_internal( - context, serial, resource, fixed_machine_views); - }, - [&](ParallelSplit const ¶llel) { - return get_optimal_machine_mapping_internal( - context, parallel, resource, fixed_machine_views); - }, - [&](Node const &node) { - return get_optimal_machine_mapping_internal( - context, node, resource, fixed_machine_views); - }}); - - context.cached_subgraph_results.save(state, result); + MachineMappingResult result = visit( + sp_decomposition_tree, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping_internal(context, decomp_tree_node, resources, partial_solution); + }); + + result_cache.save(state, result); return result; } MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, - SerialSplit const &serial, + MachineMappingCache &result_cache, + MachineMappingContext const &context, + PCGBinarySeriesSplit const &series, MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + PartialMachineMapping const &partial_solution) { - auto [decompn1, decompn2] = split_sp_decomposition(serial); - - GraphSplit graph_split = get_graph_split(decompn1, decompn2); + MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); auto is_subgraph_input = [&](std::unordered_set const &subgraph_nodes, parallel_tensor_guid_t const &input_tensor) { return !contains(subgraph_nodes, input_tensor.raw_graph_output.node); }; - std::unordered_set all_edges1 = - set_union(transform(graph_split.first, [&](Node const &node) { - return unordered_set_of( - 
get_layer_outputs(context.pcg, parallel_layer_guid_t(node))); - })); - std::unordered_set all_edges2 = - set_union(transform(graph_split.second, [&](Node const &node) { - return unordered_set_of( - get_layer_inputs(context.pcg, parallel_layer_guid_t(node))); - })); - std::unordered_set split_edges = - filter(all_edges2, [&](parallel_tensor_guid_t const &input_tensor) { - return is_subgraph_input(graph_split.second, input_tensor); - }); - - std::unordered_map fixed_machine_views1 = - restrict_keys(fixed_machine_views, all_edges1); - std::unordered_map fixed_machine_views2 = - restrict_keys(fixed_machine_views, all_edges2); - std::vector> - machine_views_list_for_split_edges = - get_allowed_src_machine_views_list(context, split_edges, resource); - - for (std::unordered_map const - &machine_views_for_split_edge : machine_views_list_for_split_edges) { - minimize_runtime( + PCGBinarySPDecomposition pre_sub_tree = get_left_child(series); + PCGBinarySPDecomposition post_sub_tree = get_right_child(series); + + std::pair< + std::unordered_set, + std::unordered_set + > boundary_layers = + get_split_transitive_reduced_boundary_layers(context.transitive_reduced_pcg, + series); + + auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) + -> std::unordered_set> + { + std::unordered_map> + allowed = generate_map(layers, + [&](parallel_layer_guid_t const &l) { + return get_allowed_machine_views_for_layer(context, l); + }); + return get_all_assignments(allowed); + }; + + for (std::unordered_map const &assigned_pre_machine_views + : get_boundary_machine_view_assignments(boundary_layers.first)) { + + PartialMachineMapping pre_candidate = + with_additional_layer_machine_views( + context, + get_sub_solution(context, partial_solution, pre_sub_tree), + assigned_pre_machine_views); + + MachineMappingResult pre_result = + get_optimal_machine_mapping_internal(result_cache, + context, + pre_sub_tree, + resource, + pre_candidate); + + + for (std::unordered_map 
const &assigned_post_machine_views + : get_boundary_machine_view_assignments(boundary_layers.second)) { + + PartialMachineMapping post_candidate = + with_additional_layer_machine_views( + context, + get_sub_solution(context, partial_solution, post_sub_tree), + assigned_post_machine_views); + + MachineMappingResult post_result = + get_optimal_machine_mapping_internal(result_cache, + context, + post_sub_tree, + resource, + post_candidate); + + float cost_across_split = estimate_cost_across_split(context, + assigned_pre_machine_views, + assigned_post_machine_views); + + minimize_runtime( optimal_result, - sequential_combine( - get_optimal_machine_mapping_internal( - context, - decompn1, - resource, - merge_maps(fixed_machine_views1, machine_views_for_split_edge)), - get_optimal_machine_mapping_internal( - context, - decompn2, - resource, - merge_maps(fixed_machine_views2, - machine_views_for_split_edge)))); + sequential_combine(pre_result, cost_across_split, post_result)); + } } return optimal_result; } + + MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, - ParallelSplit const ¶llel, - MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - auto [decompn1, decompn2] = split_sp_decomposition(parallel); - - GraphSplit graph_split = get_graph_split(decompn1, decompn2); - - std::unordered_set all_edges1 = - set_union(transform(graph_split.first, [&](Node const &node) { - return unordered_set_of( - get_layer_outputs(context.pcg, parallel_layer_guid_t(node))); - })); - std::unordered_set all_edges2 = - set_union(transform(graph_split.second, [&](Node const &node) { - return unordered_set_of( - get_layer_inputs(context.pcg, parallel_layer_guid_t(node))); - })); - std::unordered_map fixed_machine_views1 = - restrict_keys(fixed_machine_views, all_edges1); - std::unordered_map fixed_machine_views2 = - restrict_keys(fixed_machine_views, all_edges2); - - MachineMappingResult optimal_result = 
sequential_combine( - get_optimal_machine_mapping_internal( - context, decompn1, resource, fixed_machine_views1), - get_optimal_machine_mapping_internal( - context, decompn2, resource, fixed_machine_views2)); - - for (auto const &resource_split : get_resource_split(resource)) { + MachineMappingCache &result_cache, + MachineMappingContext const &context, + PCGBinaryParallelSplit const ¶llel, + MachineSpecification const &resources, + PartialMachineMapping const &partial_solution) { + + PCGBinarySPDecomposition left_subtree = get_left_child(parallel); + PartialMachineMapping left_sub_solution = get_sub_solution(context, + partial_solution, + left_subtree); + + PCGBinarySPDecomposition right_subtree = get_right_child(parallel); + PartialMachineMapping right_sub_solution = get_sub_solution(context, + partial_solution, + right_subtree); + + MachineMappingResult optimal_result = [&] { + PCGBinarySeriesSplit series = make_pcg_series_split( + get_left_child(parallel), + get_right_child(parallel)); + return get_optimal_machine_mapping_internal(result_cache, + context, + series, + resources, + partial_solution); + }(); + + for (auto const &resource_split : get_machine_resource_splits(resources)) { + MachineMappingResult left_result = + get_optimal_machine_mapping_internal(result_cache, + context, + left_subtree, + resource_split.first, + left_sub_solution); + MachineMappingResult right_result = + get_optimal_machine_mapping_internal(result_cache, + context, + right_subtree, + resource_split.second, + right_sub_solution); + minimize_runtime( optimal_result, - parallel_combine( - get_optimal_machine_mapping_internal( - context, decompn1, resource_split.first, fixed_machine_views1), - get_optimal_machine_mapping_internal(context, - decompn2, - resource_split.second, - fixed_machine_views2))); + parallel_combine(left_result, right_result)); } return optimal_result; @@ -263,39 +234,21 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingResult 
get_optimal_machine_mapping_internal( MachineMappingContext &context, - Node const &node, + parallel_layer_guid_t const &layer, MachineSpecification const &resource, - std::unordered_map const - &fixed_machine_views) { - - parallel_layer_guid_t layer = parallel_layer_guid_t(node); - std::unordered_set machine_views_not_fixed = - set_minus(unordered_set_of(get_layer_outputs(context.pcg, layer)), - keys(fixed_machine_views)); + PartialMachineMapping const &partial_solution) { - std::vector> - machine_views_list_for_not_fixed = get_allowed_src_machine_views_list( - context, machine_views_not_fixed, resource); + assert (keys(partial_solution.machine_views) == std::unordered_set{layer}); - MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + float cost = estimate_layer_cost(context.transitive_reduced_pcg.full_pcg, + context.cost_estimator, + layer, + partial_solution.machine_views.at(layer)); - for (std::unordered_map const - &machine_views_for_not_fixed : machine_views_list_for_not_fixed) { - std::unordered_map full_machine_views = - merge_maps(fixed_machine_views, machine_views_for_not_fixed); - float runtime = singleton_subgraph_cost( - context.pcg, context.cost_estimator, layer, full_machine_views); - MachineMapping machine_mapping = - MachineMapping{std::unordered_map{ - {layer, - full_machine_views.at(get_layer_outputs(context.pcg, layer)[0])}, - }}; - MachineMappingResult curr_result = - MachineMappingResult(runtime, machine_mapping); - minimize_runtime(optimal_result, curr_result); - } - - return optimal_result; + return MachineMappingResult{ + /*runtime=*/cost, + /*machine_mapping=*/MachineMapping{partial_solution.machine_views}, + }; } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc new file mode 100644 index 0000000000..36e12ff5eb --- /dev/null +++ 
b/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc @@ -0,0 +1,31 @@ +#include "compiler/machine_mapping/machine_mapping_context.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/keys.h" +#include "utils/containers/sum.h" + +namespace FlexFlow { + +std::unordered_set get_allowed_machine_views_for_tensor(MachineMappingContext const &, + parallel_tensor_guid_t const &) { + NOT_IMPLEMENTED(); +} + +MachineMappingContext make_machine_mapping_context(ParallelComputationGraph const &pcg, + CostEstimator const &cost_estimator, + std::function( + ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views) { + NOT_IMPLEMENTED(); +} + +std::unordered_set get_transitively_reduced_predecessors(MachineMappingContext const &ctx, + parallel_layer_guid_t const &l) { + NOT_IMPLEMENTED(); +} + +std::unordered_set get_transitively_reduced_successors(MachineMappingContext const &ctx, + parallel_layer_guid_t const &l) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 721a3bd27f..5e630cdef7 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -4,9 +4,10 @@ namespace FlexFlow { MachineMappingResult sequential_combine(MachineMappingResult const &s1, + float comm_cost, MachineMappingResult const &s2) { return MachineMappingResult{ - s1.runtime + s2.runtime, + s1.runtime + comm_cost + s2.runtime, combine_disjoint_mappings(s1.machine_mapping, s2.machine_mapping)}; } diff --git a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc new file mode 100644 index 
0000000000..09b59c75cd --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc @@ -0,0 +1,30 @@ +#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_context.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/restrict_keys.h" + +namespace FlexFlow { + +PartialMachineMapping get_sub_solution(MachineMappingContext const &ctx, + PartialMachineMapping const &partial_solution, + PCGBinarySPDecomposition const &sub_problem) { + std::unordered_set sub_solution_layers = + flatmap(get_parallel_layers(sub_problem), + [&](parallel_layer_guid_t l) { + return set_union( + get_transitively_reduced_predecessors(ctx, l), + get_transitively_reduced_successors(ctx, l)); + }); + + return PartialMachineMapping{ + restrict_keys(partial_solution.machine_views, sub_solution_layers), + }; +} + +MachineMapping require_complete(MachineMappingContext const &ctx, + PartialMachineMapping const &partial_solution) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc b/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc deleted file mode 100644 index b5abe383d3..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/split_sp_decomposition.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "compiler/machine_mapping/split_sp_decomposition.h" -#include "utils/containers/as_vector.h" -#include "utils/containers/transform.h" -#include "utils/variant.h" - -namespace FlexFlow { - -std::pair - split_sp_decomposition(SerialSplit const &serial) { - if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; - } - SerialSplit decompn1 = serial; - decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; -} - -std::pair - 
split_sp_decomposition(ParallelSplit const ¶llel) { - if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); - }); - return {children[0], children[1]}; - } - ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); - decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc new file mode 100644 index 0000000000..9c1d5bb05e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -0,0 +1,84 @@ +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedPCG get_pcg_transitive_reduction(ParallelComputationGraph const &pcg) { + DiGraphView raw_digraph = pcg.raw_graph; + DiGraphView transitively_reduced = transitive_reduction(raw_digraph); + + return TransitiveReducedPCG{ + /*pcg=*/pcg, + /*transitive_reduction=*/transitively_reduced, + }; +} + +std::unordered_set get_transitive_reduced_predecessors(TransitiveReducedPCG const &tr_pcg, + parallel_layer_guid_t const &layer) { + std::unordered_set raw_predecessors = get_predecessors(tr_pcg.transitive_reduction, 
layer.raw_graph_node); + return transform(raw_predecessors, [](Node const &n) { return parallel_layer_guid_t{n}; }); +} + +std::unordered_set get_transitive_reduced_successors(TransitiveReducedPCG const &tr_pcg, + parallel_layer_guid_t const &layer) { + std::unordered_set raw_successors = get_successors(tr_pcg.transitive_reduction, layer.raw_graph_node); + return transform(raw_successors, [](Node const &n) { return parallel_layer_guid_t{n}; }); +} + +std::unordered_set + get_transitively_reduced_edges_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + std::unordered_set src_subgraph = unordered_set_of(get_parallel_layers(get_left_child(split))); + std::unordered_set dst_subgraph = unordered_set_of(get_parallel_layers(get_right_child(split))); + + std::unordered_set raw_src_subgraph = transform(src_subgraph, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }); + std::unordered_set raw_dst_subgraph = transform(dst_subgraph, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }); + + std::unordered_set raw_edges = get_edges_from_subgraph_to_subgraph(tr_pcg.transitive_reduction, + raw_src_subgraph, + raw_dst_subgraph); + + return flatmap(raw_edges, + [&](DirectedEdge const &e) { + return get_pcg_edges_from_layer_to_layer(tr_pcg.full_pcg, + parallel_layer_guid_t{e.src}, + parallel_layer_guid_t{e.dst}); + }); +} + +std::unordered_set + get_transitively_reduced_tensors_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + return transform(get_transitively_reduced_edges_across_split(tr_pcg, split), + [](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e); }); +} + +std::pair< + std::unordered_set, + std::unordered_set +> get_split_transitively_reduced_boundary_layers(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + std::unordered_set edges = get_transitive_reduced_edges_across_split(tr_pcg, split); + + std::unordered_set 
src_boundary_layers = transform(edges, + [](ParallelComputationGraphEdge const &e) { return get_src_layer(e); }); + + std::unordered_set dst_boundary_layers = transform(edges, + [](ParallelComputationGraphEdge const &e) { return get_dst_layer(e); }); + + return { + src_boundary_layers, + dst_boundary_layers, + }; +} + + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc new file mode 100644 index 0000000000..5559465fa3 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc @@ -0,0 +1,11 @@ +#include "compiler/series_parallel/get_pcg_series_parallel_decomposition.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition( + ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc new file mode 100644 index 0000000000..0fe344aef8 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc @@ -0,0 +1,19 @@ +#include "compiler/series_parallel/pcg_binary_parallel_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +PCGBinarySPDecomposition get_left_child(PCGBinaryParallelSplit const &s) { + return PCGBinarySPDecomposition{ + get_left_child(s.raw_split), + }; +} + +PCGBinarySPDecomposition get_right_child(PCGBinaryParallelSplit const &s) { + return PCGBinarySPDecomposition{ + get_right_child(s.raw_split), + }; +} + +} // namespace FlexFlow diff --git 
a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc new file mode 100644 index 0000000000..636671c8fa --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc @@ -0,0 +1,19 @@ +#include "compiler/series_parallel/pcg_binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &s) { + return PCGBinarySPDecomposition{ + get_left_child(s.raw_split), + }; +} + +PCGBinarySPDecomposition get_right_child(PCGBinarySeriesSplit const &s) { + return PCGBinarySPDecomposition{ + get_right_child(s.raw_split), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc new file mode 100644 index 0000000000..be9ccd3ab2 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -0,0 +1,15 @@ +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc new file mode 100644 index 0000000000..c3e7a8f3bf --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -0,0 +1,78 @@ +#include "./cost_estimator_for_test.h" + 
+namespace FlexFlow { + +TestCostEstimator::TestCostEstimator( + std::function const &, + std::vector const &, + std::vector const &, + MachineView const &)> const &get_operator_cost, + std::function const &get_communication_cost) + : get_operator_cost(get_operator_cost), + get_communication_cost(get_communication_cost) +{ } + +float TestCostEstimator::estimate_cost(PCGOperatorAttrs const &op, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs, + MachineView const &mv) const { + return this->get_operator_cost(op, inputs, weights, outputs, mv); +} + +float TestCostEstimator::estimate_cost(ParallelTensorShape const &tensor_shape, + MachineView const &src, + MachineView const &dst) const { + return this->get_communication_cost(tensor_shape, src, dst); +} + +CostEstimator make_cost_estimator( + std::function const &, + std::vector const &, + std::vector const &, + MachineView const &)> const &get_operator_cost, + std::function const &get_communication_cost) { + return CostEstimator::create(get_operator_cost, get_communication_cost); +} + +CostEstimator make_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map) { + return make_cost_estimator( + [op_cost_map](PCGOperatorAttrs const &op_attrs, + std::vector const &input_shapes, + std::vector const &weight_shapes, + std::vector const &output_shapes, + MachineView const &machine_view) { + + OpCostEstimateKey key = OpCostEstimateKey{ + /*op_attrs=*/op_attrs, + /*input_shapes=*/input_shapes, + /*weight_shapes=*/weight_shapes, + /*output_shapes=*/output_shapes, + /*machine_view=*/machine_view, + }; + + return op_cost_map.at(key); + }, + [comm_cost_map](ParallelTensorShape const ¶llel_tensor_shape, + MachineView const &src_machine_view, + MachineView const &dst_machine_view) { + + CommCostEstimateKey key = CommCostEstimateKey{ + /*parallel_tensor_shape=*/parallel_tensor_shape, + /*src_machine_view=*/src_machine_view, + 
/*dst_machine_view=*/dst_machine_view, + }; + + return comm_cost_map.at(key); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index 86eb824dd3..bfb3f6d8eb 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -2,24 +2,50 @@ #define _FLEXFLOW_TEST_COST_ESTIMATOR_H #include "compiler/cost_estimator.h" +#include "compiler/op_cost_estimate_key.dtg.h" +#include "compiler/comm_cost_estimate_key.dtg.h" namespace FlexFlow { -struct CostEstimatorForTest : public ICostEstimator { - inline float estimate_cost(PCGOperatorAttrs const &op, +struct TestCostEstimator : public ICostEstimator { + std::function const &, + std::vector const &, + std::vector const &, + MachineView const &)> get_operator_cost; + std::function get_communication_cost; + + TestCostEstimator() = delete; + TestCostEstimator(decltype(get_operator_cost) const &get_operator_cost, + decltype(get_communication_cost) const &get_communication_cost); + + float estimate_cost(PCGOperatorAttrs const &op, std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const override { - return 1; - } - inline float estimate_cost(ParallelTensorShape const &tensor_shape, + std::vector const &weights, + std::vector const &outputs, + MachineView const &mv) const override; + + float estimate_cost(ParallelTensorShape const &tensor_shape, MachineView const &src, - MachineView const &dst) const override { - return 1; - } + MachineView const &dst) const override; }; +CostEstimator make_cost_estimator( + std::function const &, + std::vector const &, + std::vector const &, + MachineView const &)> const &get_operator_cost, + std::function const &get_communication_cost); + +CostEstimator make_cost_estimator( + 
std::unordered_map const &, + std::unordered_map const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc new file mode 100644 index 0000000000..49200fdd50 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc @@ -0,0 +1,36 @@ +#include "compiler/machine_mapping/estimate_cost_across_split.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("estimate_cost_across_split") { + SUBCASE("single edge across split") { + SUBCASE("src and dst layers have same MachineView") { + FAIL("TODO"); + } + + SUBCASE("src and dst layers have different MachineViews") { + FAIL("TODO"); + } + } + + SUBCASE("single tensor, multiple consumers across split") { + SUBCASE("consumers have same view") { + FAIL("TODO"); + } + + SUBCASE("consumers have non-overlapping views") { + FAIL("TODO"); + } + + SUBCASE("consumers have different but overlapping views") { + FAIL("TODO"); + } + } + + SUBCASE("multiple tensors, multiple consumers across split") { + FAIL("TODO"); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc new file mode 100644 index 0000000000..1daa6aa272 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -0,0 +1,124 @@ +#include "compiler/machine_mapping/estimate_layer_cost.h" +#include "./cost_estimator_for_test.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("estimate_layer_cost") { + ParallelTensorShape input_shape = ParallelTensorShape{ + 
ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{8, 2}, + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + LinearAttrs linear_attrs = LinearAttrs{ + /*out_channels=*/12, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }; + + ParallelTensorShape projection_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); + ParallelTensorShape bias_shape = throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(linear_attrs, input_shape)); + + auto make_tensor_attrs = [](ParallelTensorShape const &shape) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_grad=*/CreateGrad::YES, + }; + }; + + auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + ParallelLayerAddedResult input = add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs(PCGOperatorAttrs{InputAttrs{}}), + /*inputs=*/{}, + /*output_labels=*/{make_tensor_attrs(input_shape)}); + parallel_tensor_guid_t input_tensor = get_only(input.outputs); + + ParallelLayerAddedResult projection = add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs( + PCGOperatorAttrs{ + WeightAttrs{ + /*tensor_shape=*/get_reduced_shape(projection_shape), + }, + }), + /*inputs=*/{}, + /*output_labels=*/{make_tensor_attrs(projection_shape)}); + parallel_tensor_guid_t projection_tensor = get_only(projection.outputs); + + ParallelLayerAddedResult bias = add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs( + PCGOperatorAttrs{ + WeightAttrs{ + /*tensor_shape=*/get_reduced_shape(bias_shape), + }, + }), + /*inputs=*/{}, + 
/*output_labels=*/{make_tensor_attrs(bias_shape)}); + parallel_tensor_guid_t bias_tensor = get_only(bias.outputs); + + ParallelLayerAddedResult linear = add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs(PCGOperatorAttrs{linear_attrs}), + /*inputs=*/{ + get_only(input.outputs), + get_only(projection.outputs), + get_only(bias.outputs), + }, + /*output_labels=*/{make_tensor_attrs(output_shape)}); + parallel_tensor_guid_t linear_output = get_only(linear.outputs); + + MachineView machine_view = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); + + + CostEstimator cost_estimator = make_cost_estimator( + { + { + OpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{linear_attrs}, + /*input_shapes=*/{input_shape}, + /*weight_shapes=*/{projection_shape, bias_shape}, + /*output_shapes=*/{output_shape}, + /*machine_view=*/machine_view, + }, + 2.0, + }, + }, + {} + ); + + SUBCASE("returns just the layer cost if the layer exists") { + float result = estimate_layer_cost(pcg, + cost_estimator, + linear.parallel_layer, + machine_view); + float correct = 2.0; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..af2814bca0 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,213 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include +#include "utils/hash/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_resource_splits") { + auto make_machine_spec = [](int num_nodes, int num_gpus_per_node) { + return MachineSpecification{ + /*num_nodes=*/num_nodes, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/num_gpus_per_node, + /*inter_node_bandwidth=*/1.0, + /*intra_node_bandwidth=*/1.0, + }; + }; + + SUBCASE("returns no splits if no splits are possible") { + 
MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1); + + std::unordered_set> result = get_machine_resource_splits(input); + std::unordered_set> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns splits in gpu and node dimensions, but not at the same time") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/2); + + std::unordered_set> result = get_machine_resource_splits(input); + + std::unordered_set> correct = { + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + + }; + + CHECK(result == correct); + } + + SUBCASE("returns splits in node dimension in powers of two") { + SUBCASE("num_nodes is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/8, + /*num_gpus_per_node=*/1); + + std::unordered_set> result = get_machine_resource_splits(input); + + std::unordered_set> correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_nodes is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/6, + 
/*num_gpus_per_node=*/1); + + std::unordered_set> result = get_machine_resource_splits(input); + + std::unordered_set> correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + + } + } + + SUBCASE("returns splits in gpu dimension in powers of two") { + SUBCASE("num_gpus_per_node is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/8); + + std::unordered_set> result = get_machine_resource_splits(input); + + std::unordered_set> correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_gpus_per_node is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6); + + std::unordered_set> result = get_machine_resource_splits(input); + 
+ std::unordered_set> correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + } + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index e77b6f76cf..16e7b46b09 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,10 +1,12 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "cost_estimator_for_test.h" -#include "doctest/doctest.h" +#include #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" using namespace FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping") { auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, @@ -13,191 +15,278 @@ TEST_SUITE(FF_TEST_SUITE) { make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; CostEstimator estimator1 = CostEstimator::create(); - MachineSpecification machine_spec1(2, 1, 1, 1, 1); - MachineMappingCache cached_results1; - - SUBCASE("simple PCG") { - - ParallelComputationGraph pcg_simple = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape0 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{32, 1}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - 
SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - ParallelTensorShape input_shape1 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - ShardParallelDim{8, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape0); - parallel_tensor_guid_t input1 = - builder.create_input_tensor(input_shape1); - parallel_tensor_guid_t dense0 = builder.batch_matmul(input0, input1); - - return builder.pcg; - }(); - - MachineMappingResult result = - get_optimal_machine_mapping(pcg_simple, - allowed_machine_views1, - estimator1, - machine_spec1, - cached_results1); - - CHECK(result.runtime == 3); - } + MachineSpecification machine_spec = MachineSpecification{ + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + + SUBCASE("single layer") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + MachineView mv1 = make_1d_machine_view(gpu_id_t{1}, gpu_id_t{2}); + + auto allowed_machine_views = [&](ParallelLayerAttrs const &, + MachineSpecification const &) { + return std::unordered_set{mv1}; + }; - SUBCASE("PCG is a chain") { - ParallelComputationGraph pcg_chain = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape); - parallel_tensor_guid_t layer1 = builder.identity(input0); - parallel_tensor_guid_t layer2 = builder.identity(layer1); - parallel_tensor_guid_t layer3 = builder.identity(layer2); - parallel_tensor_guid_t layer4 = builder.identity(layer3); - 
parallel_tensor_guid_t layer5 = builder.identity(layer4); - parallel_tensor_guid_t layer6 = builder.identity(layer5); - - return builder.pcg; - }(); - - MachineMappingResult result = - get_optimal_machine_mapping(pcg_chain, - allowed_machine_views1, - estimator1, - machine_spec1, - cached_results1); - CHECK(result.runtime == 13); + CostEstimator cost_estimator = make_cost_estimator( + [&](PCGOperatorAttrs const &, + std::vector const &, + std::vector const &, + std::vector const &, + MachineView const &) { + return 1.0; + }, + [&](ParallelTensorShape const &, + MachineView const &, + MachineView const &) { + return 0.5; + }); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + InputAttrs{}, + }, + std::nullopt, + }; + + ParallelTensorAttrs output_tensor_attrs = ParallelTensorAttrs{ + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + ParallelLayerAddedResult added = add_parallel_layer(pcg, + layer_attrs, + {}, + {output_tensor_attrs}); + parallel_layer_guid_t layer = added.parallel_layer; + parallel_tensor_guid_t output_tensor = get_only(added.outputs); + + MachineMappingCache cache; + + MachineMappingResult result = get_optimal_machine_mapping(pcg, + allowed_machine_views, + cost_estimator, + machine_spec, + cache); + MachineMappingResult correct = MachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/MachineMapping{{ + {layer, mv1}, + }}, + }; + + CHECK(result == correct); } - SUBCASE("PCG has multiple chains") { - ParallelComputationGraph pcg_multiple_chains = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape0 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{32, 1}, - ShardParallelDim{16, 1}, - }, 
- ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - ParallelTensorShape input_shape1 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - ShardParallelDim{8, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape0); - parallel_tensor_guid_t input1 = - builder.create_input_tensor(input_shape1); - parallel_tensor_guid_t relu0 = builder.relu(input0); - parallel_tensor_guid_t relu1 = builder.relu(input1); - parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); - - return builder.pcg; - }(); - - MachineMappingResult result = - get_optimal_machine_mapping(pcg_multiple_chains, - allowed_machine_views1, - estimator1, - machine_spec1, - cached_results1); - CHECK(result.runtime == 5); + SUBCASE("pair of layers in sequence") { + FAIL("TODO"); } - SUBCASE("PCG is not sp-izable due to multiple inputs") { - ParallelComputationGraph pcg_non_sp = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape); - parallel_tensor_guid_t dense0 = builder.dense(input0, 8); - parallel_tensor_guid_t dense1 = builder.dense(input0, 4); - parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); - parallel_tensor_guid_t add0 = builder.add(dense0, dense2); - - return builder.pcg; - }(); - - // TODO: Handle this case in compiler - // TODO: separate testcases for this too that actually check the graph - // manipulation - if (false) { - MachineMappingResult result = - get_optimal_machine_mapping(pcg_non_sp, - allowed_machine_views1, - 
estimator1, - machine_spec1, - cached_results1); - CHECK(bool(result.runtime > 0)); - CHECK(result.runtime == 7); - } + SUBCASE("pair of layers in parallel") { + FAIL("TODO"); } + + // SUBCASE("simple PCG") { + // + // ParallelComputationGraph pcg_simple = [&] { + // ParallelComputationGraphBuilder builder; + // + // ParallelTensorShape input_shape0 = + // ParallelTensorShape{ParallelTensorDims{ + // FFOrdered{ + // ShardParallelDim{32, 2}, + // ShardParallelDim{32, 1}, + // ShardParallelDim{16, 1}, + // }, + // ReplicaParallelDimSet{ + // SumDegree{1}, + // DiscardCopyDegree{1}, + // }, + // }, + // DataType::FLOAT}; + // + // ParallelTensorShape input_shape1 = + // ParallelTensorShape{ParallelTensorDims{ + // FFOrdered{ + // ShardParallelDim{32, 2}, + // ShardParallelDim{16, 1}, + // ShardParallelDim{8, 1}, + // }, + // ReplicaParallelDimSet{ + // SumDegree{1}, + // DiscardCopyDegree{1}, + // }, + // }, + // DataType::FLOAT}; + // + // parallel_tensor_guid_t input0 = + // builder.create_input_tensor(input_shape0); + // parallel_tensor_guid_t input1 = + // builder.create_input_tensor(input_shape1); + // parallel_tensor_guid_t dense0 = builder.batch_matmul(input0, input1); + // + // return builder.pcg; + // }(); + // + // MachineMappingResult result = + // get_optimal_machine_mapping(pcg_simple, + // allowed_machine_views1, + // estimator1, + // machine_spec1, + // cached_results1); + // + // CHECK(result.runtime == 3); + // } + + // SUBCASE("PCG is a chain") { + // ParallelComputationGraph pcg_chain = [&] { + // ParallelComputationGraphBuilder builder; + // + // ParallelTensorShape input_shape = + // ParallelTensorShape{ParallelTensorDims{ + // FFOrdered{ + // ShardParallelDim{32, 2}, + // ShardParallelDim{16, 1}, + // }, + // ReplicaParallelDimSet{ + // SumDegree{1}, + // DiscardCopyDegree{1}, + // }, + // }, + // DataType::FLOAT}; + // + // parallel_tensor_guid_t input0 = + // builder.create_input_tensor(input_shape); + // parallel_tensor_guid_t layer1 = 
builder.identity(input0); + // parallel_tensor_guid_t layer2 = builder.identity(layer1); + // parallel_tensor_guid_t layer3 = builder.identity(layer2); + // parallel_tensor_guid_t layer4 = builder.identity(layer3); + // parallel_tensor_guid_t layer5 = builder.identity(layer4); + // parallel_tensor_guid_t layer6 = builder.identity(layer5); + // + // return builder.pcg; + // }(); + // + // MachineMappingResult result = + // get_optimal_machine_mapping(pcg_chain, + // allowed_machine_views1, + // estimator1, + // machine_spec1, + // cached_results1); + // CHECK(result.runtime == 13); + // } + // + // SUBCASE("PCG has multiple chains") { + // ParallelComputationGraph pcg_multiple_chains = [&] { + // ParallelComputationGraphBuilder builder; + // + // ParallelTensorShape input_shape0 = + // ParallelTensorShape{ParallelTensorDims{ + // FFOrdered{ + // ShardParallelDim{32, 2}, + // ShardParallelDim{32, 1}, + // ShardParallelDim{16, 1}, + // }, + // ReplicaParallelDimSet{ + // SumDegree{1}, + // DiscardCopyDegree{1}, + // }, + // }, + // DataType::FLOAT}; + // + // ParallelTensorShape input_shape1 = + // ParallelTensorShape{ParallelTensorDims{ + // FFOrdered{ + // ShardParallelDim{32, 2}, + // ShardParallelDim{16, 1}, + // ShardParallelDim{8, 1}, + // }, + // ReplicaParallelDimSet{ + // SumDegree{1}, + // DiscardCopyDegree{1}, + // }, + // }, + // DataType::FLOAT}; + // + // parallel_tensor_guid_t input0 = + // builder.create_input_tensor(input_shape0); + // parallel_tensor_guid_t input1 = + // builder.create_input_tensor(input_shape1); + // parallel_tensor_guid_t relu0 = builder.relu(input0); + // parallel_tensor_guid_t relu1 = builder.relu(input1); + // parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); + // + // return builder.pcg; + // }(); + // + // MachineMappingResult result = + // get_optimal_machine_mapping(pcg_multiple_chains, + // allowed_machine_views1, + // estimator1, + // machine_spec1, + // cached_results1); + // CHECK(result.runtime == 5); 
+ // } + // + // SUBCASE("PCG is not sp-izable due to multiple inputs") { + // ParallelComputationGraph pcg_non_sp = [&] { + // ParallelComputationGraphBuilder builder; + // + // ParallelTensorShape input_shape = + // ParallelTensorShape{ParallelTensorDims{ + // FFOrdered{ + // ShardParallelDim{32, 2}, + // ShardParallelDim{16, 1}, + // }, + // ReplicaParallelDimSet{ + // SumDegree{1}, + // DiscardCopyDegree{1}, + // }, + // }, + // DataType::FLOAT}; + // + // parallel_tensor_guid_t input0 = + // builder.create_input_tensor(input_shape); + // parallel_tensor_guid_t dense0 = builder.dense(input0, 8); + // parallel_tensor_guid_t dense1 = builder.dense(input0, 4); + // parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); + // parallel_tensor_guid_t add0 = builder.add(dense0, dense2); + // + // return builder.pcg; + // }(); + // + // // TODO: Handle this case in compiler + // // TODO: separate testcases for this too that actually check the graph + // // manipulation + // if (false) { + // MachineMappingResult result = + // get_optimal_machine_mapping(pcg_non_sp, + // allowed_machine_views1, + // estimator1, + // machine_spec1, + // cached_results1); + // CHECK(bool(result.runtime > 0)); + // CHECK(result.runtime == 7); + // } + // } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc new file mode 100644 index 0000000000..4bb1afab53 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -0,0 +1,26 @@ +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_predecessors") { + FAIL("TODO"); + } + + TEST_CASE("get_transitive_reduced_successors") { + FAIL("TODO"); + } + + TEST_CASE("get_transitive_reduced_edges_across_split") { + FAIL("TODO"); + } + + 
TEST_CASE("get_transitive_reduced_tensors_across_split") { + FAIL("TODO"); + } + + TEST_CASE("get_split_transitive_reduced_boundary_layers") { + FAIL("TODO"); + } +} diff --git a/lib/kernels/include/kernels/linear_kernels.h b/lib/kernels/include/kernels/linear_kernels.h index c761eaf1d9..3128e39fd0 100644 --- a/lib/kernels/include/kernels/linear_kernels.h +++ b/lib/kernels/include/kernels/linear_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/batch_matmul.h b/lib/local-execution/src/ops/batch_matmul.h index c082dec020..a7e29b1931 100644 --- a/lib/local-execution/src/ops/batch_matmul.h +++ b/lib/local-execution/src/ops/batch_matmul.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/op_task_signature.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_matmul.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/batch_norm.h b/lib/local-execution/src/ops/batch_norm.h index 1f6cceec19..36aa8ffa4e 100644 --- a/lib/local-execution/src/ops/batch_norm.h +++ b/lib/local-execution/src/ops/batch_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/cast.h b/lib/local-execution/src/ops/cast.h index b4a1e91c91..e7af6aca6b 100644 --- a/lib/local-execution/src/ops/cast.h +++ b/lib/local-execution/src/ops/cast.h @@ -17,7 +17,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/combine.h 
b/lib/local-execution/src/ops/combine.h index c6157a2955..e85e8fba39 100644 --- a/lib/local-execution/src/ops/combine.h +++ b/lib/local-execution/src/ops/combine.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/combine_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/concat.h b/lib/local-execution/src/ops/concat.h index 1f1443f25d..eab70d621c 100644 --- a/lib/local-execution/src/ops/concat.h +++ b/lib/local-execution/src/ops/concat.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/conv_2d.h b/lib/local-execution/src/ops/conv_2d.h index f70d36d514..0358d71eea 100644 --- a/lib/local-execution/src/ops/conv_2d.h +++ b/lib/local-execution/src/ops/conv_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h index 84b67a29c2..a3dc5ff8af 100644 --- a/lib/local-execution/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" #include "local-execution/task_id_t.dtg.h" -#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_binary.h b/lib/local-execution/src/ops/element_binary.h index 05273e34b4..72c0976df8 100644 --- a/lib/local-execution/src/ops/element_binary.h +++ b/lib/local-execution/src/ops/element_binary.h @@ -3,7 +3,7 @@ #include "local-execution/sim_environment.h" #include 
"local-execution/task_signature_impl.h" -#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_unary.h b/lib/local-execution/src/ops/element_unary.h index 4d1783f1f6..04a72e2e12 100644 --- a/lib/local-execution/src/ops/element_unary.h +++ b/lib/local-execution/src/ops/element_unary.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/embedding.h b/lib/local-execution/src/ops/embedding.h index 0463984122..995d2296e1 100644 --- a/lib/local-execution/src/ops/embedding.h +++ b/lib/local-execution/src/ops/embedding.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/flat.h b/lib/local-execution/src/ops/flat.h index d1501f85ca..e019bfc654 100644 --- a/lib/local-execution/src/ops/flat.h +++ b/lib/local-execution/src/ops/flat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_FLAT_H #include "local-execution/sim_environment.h" -#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/flat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/gather.h b/lib/local-execution/src/ops/gather.h index 74db276e35..e339683381 100644 --- a/lib/local-execution/src/ops/gather.h +++ b/lib/local-execution/src/ops/gather.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/gather_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/src/ops/input.h index 97985585e1..baad25b798 100644 --- 
a/lib/local-execution/src/ops/input.h +++ b/lib/local-execution/src/ops/input.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_INPUT_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" +#include "op-attrs/ops/input_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/layer_norm.h b/lib/local-execution/src/ops/layer_norm.h index 4f8d87153b..8e034ac519 100644 --- a/lib/local-execution/src/ops/layer_norm.h +++ b/lib/local-execution/src/ops/layer_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/linear.h b/lib/local-execution/src/ops/linear.h index 2c76483df4..2aaf13a95a 100644 --- a/lib/local-execution/src/ops/linear.h +++ b/lib/local-execution/src/ops/linear.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/src/ops/noop.h index 959f7dc054..1097adeb5e 100644 --- a/lib/local-execution/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -2,9 +2,7 @@ #define _FLEXFLOW_NOOP_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" -#include "op-attrs/ops/noop.h" -#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/pool_2d.h b/lib/local-execution/src/ops/pool_2d.h index e8624185ac..908fd5462f 100644 --- a/lib/local-execution/src/ops/pool_2d.h +++ b/lib/local-execution/src/ops/pool_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/pool_2d.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" 
namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduce.h b/lib/local-execution/src/ops/reduce.h index 92f0578757..7900c28159 100644 --- a/lib/local-execution/src/ops/reduce.h +++ b/lib/local-execution/src/ops/reduce.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduce.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduction.h b/lib/local-execution/src/ops/reduction.h index a0af4f3aea..56833602e6 100644 --- a/lib/local-execution/src/ops/reduction.h +++ b/lib/local-execution/src/ops/reduction.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/repartition.h b/lib/local-execution/src/ops/repartition.h index b38a93f8b1..5187d04ca0 100644 --- a/lib/local-execution/src/ops/repartition.h +++ b/lib/local-execution/src/ops/repartition.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/replicate.h b/lib/local-execution/src/ops/replicate.h index 77bda411c1..85d1dff41a 100644 --- a/lib/local-execution/src/ops/replicate.h +++ b/lib/local-execution/src/ops/replicate.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reshape.h b/lib/local-execution/src/ops/reshape.h index 06a6b32597..37f07534ee 100644 --- a/lib/local-execution/src/ops/reshape.h +++ b/lib/local-execution/src/ops/reshape.h @@ -3,7 +3,7 @@ 
#include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reshape.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reverse.h b/lib/local-execution/src/ops/reverse.h index 10072860b0..7c16073be7 100644 --- a/lib/local-execution/src/ops/reverse.h +++ b/lib/local-execution/src/ops/reverse.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reverse.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/softmax.h b/lib/local-execution/src/ops/softmax.h index b5756d92ff..d440fe7239 100644 --- a/lib/local-execution/src/ops/softmax.h +++ b/lib/local-execution/src/ops/softmax.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/softmax.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/split.h b/lib/local-execution/src/ops/split.h index c82152b06a..dde46c20bf 100644 --- a/lib/local-execution/src/ops/split.h +++ b/lib/local-execution/src/ops/split.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/split.h" +#include "op-attrs/ops/split_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/topk.h b/lib/local-execution/src/ops/topk.h index b04d807400..c8f3175ebd 100644 --- a/lib/local-execution/src/ops/topk.h +++ b/lib/local-execution/src/ops/topk.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/topk.h" +#include "op-attrs/ops/topk_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/transpose.h b/lib/local-execution/src/ops/transpose.h index 3feffc7d86..0f3a2e80a0 100644 --- 
a/lib/local-execution/src/ops/transpose.h +++ b/lib/local-execution/src/ops/transpose.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/transpose.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h deleted file mode 100644 index 268554b5be..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _OPERATOR_PARAMS_H -#define _OPERATOR_PARAMS_H - -#include "op-attrs/ops/core.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "ops/attention.h" -#include "ops/batch_matmul.h" -#include "ops/batch_norm.h" -#include "ops/broadcast.h" -#include "ops/cast.h" -#include "ops/combine.h" -#include "ops/concat.h" -#include "ops/conv_2d.h" -#include "ops/dropout.h" -#include "ops/element_binary.h" -#include "ops/element_unary.h" -#include "ops/embedding.h" -#include "ops/flat.h" -#include "ops/gather.h" -#include "ops/input.h" -#include "ops/layer_norm.h" -#include "ops/linear.h" -#include "ops/noop.h" -#include "ops/pool_2d.h" -#include "ops/reduce.h" -#include "ops/reduction.h" -#include "ops/repartition.h" -#include "ops/replicate.h" -#include "ops/reshape.h" -#include "ops/reverse.h" -#include "ops/softmax.h" -#include "ops/split.h" -#include "ops/topk.h" -#include "ops/transpose.h" -#include "utils/record_formatter.h" -#include "utils/variant.h" -#include - -namespace FlexFlow { - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes); - -bool is_valid(PCGOperatorAttrs const &, - std::vector const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc deleted file mode 100644 index e6459c6819..0000000000 --- a/lib/op-attrs/src/operator_attrs.cc +++ /dev/null @@ -1,287 +0,0 @@ -#include 
"op-attrs/operator_attrs.h" -#include "utils/fmt.h" -#include "utils/record_formatter.h" -#include "utils/type_traits.h" - -namespace FlexFlow { - -/* OperatorType GetOpType::operator()(BatchMatmulAttrs const &p) const { return - * OP_BATCHMATMUL; } */ -/* OperatorType GetOpType::operator()(Conv2DAttrs const &p) const { return - * OP_CONV2D; } */ -/* OperatorType GetOpType::operator()(ConcatAttrs const &p) const { return - * OP_CONCAT; } */ -/* OperatorType GetOpType::operator()(CastAttrs const &p) const { return - * OP_CAST; } */ -/* OperatorType GetOpType::operator()(ElementBinaryAttrs const &p) const { - * return p.type; } */ -/* OperatorType GetOpType::operator()(ElementUnaryAttrs const &p) const { return - * p.op_type; } */ -/* OperatorType GetOpType::operator()(DropoutAttrs const &p) const { return - * OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(EmbeddingAttrs const &p) const { return - * OP_EMBEDDING; } */ -/* OperatorType GetOpType::operator()(FlatAttrs const &p) const { return - * OP_FLAT; } */ -/* OperatorType GetOpType::operator()(LayerNormAttrs const &p) const { return - * OP_LAYERNORM; } */ -/* OperatorType GetOpType::operator()(LinearAttrs const &p) const { return - * OP_LINEAR; } */ -/* OperatorType GetOpType::operator()(MultiHeadAttentionAttrs const &p) const { - * return OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(Pool2DAttrs const &p) const { return - * OP_POOL2D; } */ -/* OperatorType GetOpType::operator()(ReshapeAttrs const &p) const { return - * OP_RESHAPE; } */ -/* OperatorType GetOpType::operator()(SplitAttrs const &p) const { return - * OP_SPLIT; } */ -/* OperatorType GetOpType::operator()(SoftmaxAttrs const &p) const { return - * OP_SOFTMAX; } */ -/* OperatorType GetOpType::operator()(TransposeAttrs const &p) const { return - * OP_TRANSPOSE; } */ -/* OperatorType GetOpType::operator()(RepartitionAttrs const &p) const { return - * OP_REPARTITION; } */ -/* OperatorType GetOpType::operator()(ReplicateAttrs const &p) 
const { return - * OP_REPLICATE; } */ -/* OperatorType GetOpType::operator()(ReductionAttrs const &p) const { return - * OP_REDUCTION; } */ -/* OperatorType GetOpType::operator()(CombineAttrs const &p) const { return - * OP_COMBINE; } */ -/* OperatorType GetOpType::operator()(FusedParallelOpAttrs const &p) const { - * return OP_FUSED_PARALLEL; } */ - -/* struct AsOpAttrs { */ -/* template */ -/* OpAttrsInterface const &operator()(T const &p) { */ -/* return p; */ -/* } */ -/* }; */ - -/* OperatorType get_op_type(OpAttrsInterface const &o) { */ -/* return o.op_type(); */ -/* } */ -/* // */ -/* OperatorType get_op_type(CompGraphOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* OperatorType get_op_type(PCGOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* std::vector get_output_shapes(PCGOperatorAttrs const - * &op_params, std::vector const &input_tensor_shapes) { */ -/* return mpark::visit(AsOpAttrs{}, - * op_params).output_shapes(input_tensor_shapes); */ -/* } */ - -/* bool is_parallel_op(PCGOperatorAttrs const &o) { */ -/* return is_parallel_op(get_op_type(o)); */ -/* } */ -template -typename std::enable_if<(is_streamable::value && - !is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - std::ostringstream oss; - oss << t; - r << oss; -} - -template -typename std::enable_if<(is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - r << fmt::to_string(t); -} -void as_dot(int x, RecordFormatter &r) { - r << std::to_string(x); -} - -void as_dot(std::string const &s, RecordFormatter &r) { - r << s; -} - -template -void as_dot(std::vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -template -void as_dot(stack_vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -struct as_dot_visitor { - as_dot_visitor() = delete; - 
as_dot_visitor(RecordFormatter &result) : result(result) {} - - RecordFormatter &result; - - template - void operator()(char const *name, T const &t) { - RecordFormatter kv; - kv << name; - as_dot(t, result); - result << kv; - } - - template - void operator()(T const &t) { - as_dot(t, result); - } - - /* template */ - /* void operator()(const char *name, std::vector const &t) { */ - /* RecordFormatter kv; */ - /* kv << name; */ - /* RecordFormatter v; */ - /* for (V const &vv : t) { */ - /* v << as_dot_str(vv); */ - /* } */ - /* kv << v; */ - /* } */ -}; - -template -typename std::enable_if::value>::type - as_dot(T const &t, RecordFormatter &r) { - as_dot_visitor vis(r); - visit_struct::for_each(t, vis); -} - -struct AsDot { - template - RecordFormatter operator()(T const &t) { - return as_dot(t); - } -}; - -template -RecordFormatter as_dot(std::variant const &o) { - return std::visit(AsDot{}, o); -} - -struct IsValidFunctor { - IsValidFunctor(std::vector const &_input_shapes) - : input_shapes(_input_shapes) {} - - std::vector const &input_shapes; - - // bool operator()(AggregateAttrs const &attrs) { - // return is_valid(attrs, - // input_shapes.at(0), - // input_shapes.at(1), - // input_shapes.at(2), - // input_shapes.at(3), - // subvec(input_shapes, 4, nullopt)); - // } - - template - bool operator()(T const &) { - return true; // TODO FIXME @lockshaw - } -}; - -bool is_valid(PCGOperatorAttrs const &attrs, - std::vector const &input_shapes) { - NOT_IMPLEMENTED(); -} - -/* int num_outputs(OperatorParameters const &o) { */ -/* switch (get_op_type(o)) { */ -/* case OP_SPLIT: */ -/* } */ -/* } */ - -// tl::optional get_op_parameters(Op const *op) { -// switch (op->op_type) { -// case OP_LINEAR: -// return ((Linear *)op)->get_params(); -// case OP_CONV2D: -// return ((Conv2D *)op)->get_params(); -// case OP_EW_ADD: -// case OP_EW_SUB: -// case OP_EW_MUL: -// case OP_EW_DIV: -// return ((ElementBinary *)op)->get_params(); -// case OP_EXP: -// case OP_SIN: -// case 
OP_COS: -// case OP_SCALAR_MULTIPLY: -// case OP_SCALAR_ADD: -// case OP_SCALAR_SUB: -// case OP_SCALAR_TRUE_DIV: -// case OP_RELU: -// case OP_SIGMOID: -// case OP_TANH: -// case OP_IDENTITY: -// case OP_GELU: -// case OP_ELU: -// return ((ElementUnary *)op)->get_params(); -// case OP_CONCAT: -// return ((Concat *)op)->get_params(); -// case OP_POOL2D: -// return ((Pool2D *)op)->get_params(); -// case OP_CAST: -// return ((Cast *)op)->get_params(); -// case OP_DROPOUT: -// return ((Dropout *)op)->get_params(); -// case OP_EMBEDDING: -// return ((Embedding *)op)->get_params(); -// case OP_FLAT: -// return ((Flat *)op)->get_params(); -// case OP_MULTIHEAD_ATTENTION: -// return ((MultiHeadAttention *)op)->get_params(); -// case OP_LAYERNORM: -// return ((LayerNorm *)op)->get_params(); -// case OP_RESHAPE: -// return ((Reshape *)op)->get_params(); -// case OP_SOFTMAX: -// return ((Softmax *)op)->get_params(); -// case OP_REPARTITION: -// return ((Repartition *)op)->get_params(); -// case OP_REPLICATE: -// return ((Replicate *)op)->get_params(); -// case OP_REDUCTION: -// return ((Reduction *)op)->get_params(); -// case OP_COMBINE: -// return ((Combine *)op)->get_params(); -// case OP_FUSED_PARALLEL: -// return ((FusedParallelOp *)op)->get_params(); -// case OP_TRANSPOSE: -// return ((Transpose *)op)->get_params(); -// case OP_BATCHMATMUL: -// return ((BatchMatmul *)op)->get_params(); -// case OP_SPLIT: -// return ((Split *)op)->get_params(); -// -// // TODO: implement the get_params() function for the operators below -// and -// // uncomment the lines below -// -// // case OP_NOOP: -// // return ((NoOp *)op)->get_params(); -// // case OP_TOPK: -// // return ((TopK *)op)->get_params(); -// // case OP_MEAN: -// // return ((Mean *)op)->get_params(); -// // case OP_CACHE: -// // return ((Cache *)op)->get_params(); -// // case OP_REVERSE: -// // return ((Reverse *)op)->get_params(); -// // case OP_BATCHNORM: -// // return ((BatchNorm *)op)->get_params(); -// -// default: 
-// return tl::nullopt; -// } -// } - -} // namespace FlexFlow diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index d7248afde4..1239e75ce1 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" @@ -21,6 +22,10 @@ ParallelLayerAddedResult std::vector const &inputs, std::vector const &output_labels); +std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, + parallel_layer_guid_t const &, + parallel_layer_guid_t const &); + std::vector get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -39,6 +44,8 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); +ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); std::vector topological_ordering(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 4d61f24d37..027b9f6c80 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ 
b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/operator_attrs.h", + "op-attrs/pcg_operator_attrs.dtg.h", "utils/stack_string.h", "", ] diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index b04d9d37b3..7d9c217b25 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -42,6 +43,16 @@ ParallelLayerAddedResult }; } +std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &src, + parallel_layer_guid_t const &dst) { + std::unordered_set raw_edges = get_dataflow_edges_from_node_to_node(pcg.raw_graph, + src.raw_graph_node, + dst.raw_graph_node); + return transform(raw_edges, [](DataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); +} + + std::vector get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -116,6 +127,12 @@ ParallelTensorAttrs return pcg.raw_graph.at(t.raw_graph_output); } +ParallelTensorShape + get_parallel_tensor_shape(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + return get_parallel_tensor_attrs(pcg, t).shape; +} + std::vector topological_ordering(ParallelComputationGraph const &pcg) { return transform(get_topological_ordering(pcg.raw_graph), diff --git 
a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index ce00ea62f4..2d425f5c6c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -9,6 +9,19 @@ #include "utils/containers/enumerate_vector.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/reduction.h" namespace FlexFlow { @@ -182,7 +195,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, std::optional const &maybe_name) { LinearAttrs attrs = LinearAttrs{ @@ -205,9 +218,9 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::vector weights; { - ParallelTensorShape kernel_shape = + ParallelTensorShape projection_shape = throw_if_unexpected(get_projection_shape(attrs, input_shape)); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + weights.push_back(make_weight_attrs(projection_shape, projection_initializer)); } if (use_bias) { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 77d938e08a..b88fe38042 100644 --- 
a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/rapidcheck.h" #include "utils/containers/get_only.h" diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index c445085635..20bd0ac92d 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/ops/conv_2d.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index e46a481a1a..471f2a2709 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "legion/legion_utilities.h" #include "op-attrs/ffconst.h" -#include "op-attrs/operator_attrs.h" #include "serialization.h" #include diff --git a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h index de9d1cd78a..b7ce13db0e 100644 --- a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h +++ 
b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" diff --git a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc index 0bde326bd1..9fa91d75b7 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,5 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/get_output_shapes.h" #include "utils/containers/map_keys.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 20ab6ce440..8620529b88 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -11,9 +11,6 @@ namespace FlexFlow { -template -Element sum(Container const &container); - template diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 33f94811d6..68a4cffe80 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -147,34 +147,6 @@ bool are_all_same(C const &c) { return true; } -template -std::vector flatmap(std::vector const &v, F const &f) { - std::vector result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set 
flatmap(std::unordered_set const &v, F const &f) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap_v2(std::unordered_set const &v, - std::unordered_set (*f)(In const &)) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - template std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; diff --git a/lib/utils/include/utils/containers/cartesian_product.h b/lib/utils/include/utils/containers/cartesian_product.h new file mode 100644 index 0000000000..bcba52113e --- /dev/null +++ b/lib/utils/include/utils/containers/cartesian_product.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H + +#include "utils/containers/vector_of.h" +#include "utils/hash/vector.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set> + cartesian_product(std::vector> const &containers) { + std::unordered_set> result; + + std::function &, size_t)> recurse = + [&](std::vector ¤t, size_t depth) { + if (depth == containers.size()) { + result.insert(current); + return; + } + + for (E const &item : containers.at(depth)) { + current.push_back(item); + recurse(current, depth + 1); + current.pop_back(); + } + }; + + std::vector current; + recurse(current, 0); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h new file mode 100644 index 0000000000..5980996c27 --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H + 
+#include +#include +#include "utils/containers/vector_of.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_map_from_pairs.h" +#include "utils/containers/keys.h" +#include "utils/containers/zip.h" +#include "utils/containers/cartesian_product.h" +#include "utils/hash/unordered_map.h" +#include + +namespace FlexFlow { + +/** + * @note If \p options_per_key is empty, an empty set is returned from the + * function (not a set containing an empty set, as the "empty" assignment is + * not considered a valid assignment) + */ +template +std::unordered_set> get_all_assignments(std::unordered_map> const &options_per_key) { + if (options_per_key.empty()) { + return {}; + } + + std::vector ordered_keys = vector_of(keys(options_per_key)); + std::vector> ordered_value_option_sets = transform(ordered_keys, + [&](K const &k) { return options_per_key.at(k); }); + + std::unordered_set> result = transform( + cartesian_product(ordered_value_option_sets), + [&](std::vector const &chosen_values) { + return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); + }); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index ec3d5f5612..05a955c485 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -24,7 +24,7 @@ auto transform(req const &c, F const &f) template ()(std::declval()))> + typename Out = std::invoke_result_t> std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; for (auto const &e : v) { @@ -35,7 +35,18 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { template ()(std::declval()))> + typename Out = std::invoke_result_t> +std::unordered_multiset transform(std::unordered_multiset const &v, F const &f) { + std::unordered_multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + 
return result; +} + +template > std::set transform(std::set const &v, F const &f) { std::set result; for (auto const &e : v) { @@ -44,6 +55,17 @@ std::set transform(std::set const &v, F const &f) { return result; } +template > +std::multiset transform(std::multiset const &v, F const &f) { + std::multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + template std::string transform(std::string const &s, F const &f) { std::string result; diff --git a/lib/utils/include/utils/containers/unordered_map_from_pairs.h b/lib/utils/include/utils/containers/unordered_map_from_pairs.h new file mode 100644 index 0000000000..34a1d91e86 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_map_from_pairs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H + +#include + +namespace FlexFlow { + +template +std::unordered_map + unordered_map_from_pairs(C const &c) { + return std::unordered_map(c.cbegin(), c.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h new file mode 100644 index 0000000000..5c4632ca2a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node(DataflowGraphView const &g, + Node const &src, + Node const &dst); + +} // namespace FlexFlow + +#endif diff 
--git a/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h new file mode 100644 index 0000000000..ddb3ca1c68 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" +namespace FlexFlow { + +std::unordered_set get_edges_from_subgraph_to_subgraph(DiGraphView const &, + std::unordered_set const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/cartesian_product.cc b/lib/utils/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..b716a49ad5 --- /dev/null +++ b/lib/utils/src/utils/containers/cartesian_product.cc @@ -0,0 +1 @@ +#include "utils/containers/cartesian_product.h" diff --git a/lib/utils/src/utils/containers/get_all_assignments.cc b/lib/utils/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..3a7cf6377a --- /dev/null +++ b/lib/utils/src/utils/containers/get_all_assignments.cc @@ -0,0 +1 @@ +#include "utils/containers/get_all_assignments.h" diff --git a/lib/utils/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..60cc978be7 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1 @@ +#include "utils/containers/unordered_map_from_pairs.h" diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 
0000000000..53e17f3917 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,16 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node(DataflowGraphView const &g, + Node const &src, + Node const &dst) { + return g.query_edges(DataflowEdgeQuery{ + /*src_nodes=*/query_set{src}, + /*src_idxs=*/query_set::matchall(), + /*dst_nodes=*/query_set{dst}, + /*dst_idxs=*/query_set::matchall(), + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..da6cd1d493 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,20 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/containers/are_disjoint.h" + +namespace FlexFlow { + +std::unordered_set get_edges_from_subgraph_to_subgraph(DiGraphView const &g, + std::unordered_set const &src_subgraph, + std::unordered_set const &dst_subgraph) { + if (!are_disjoint(src_subgraph, dst_subgraph)) { + throw mk_runtime_error(fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) 
expected src_subgraph and dst_subgraph to be disjoint, " + "but found src_subgraph={}, dst_subgraph={}", src_subgraph, dst_subgraph)); + } + + return g.query_edges(DirectedEdgeQuery{ + /*srcs=*/query_set{src_subgraph}, + /*dsts=*/query_set{dst_subgraph}, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/containers/cartesian_product.cc b/lib/utils/test/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..42b8a10439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/cartesian_product.cc @@ -0,0 +1,62 @@ +#include "utils/containers/cartesian_product.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cartesian_product") { + + SUBCASE("empty") { + std::vector> containers = {}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{}}; + CHECK(result == correct); + } + + SUBCASE("single container, one element") { + std::vector> containers = {{1}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1}}; + CHECK(result == correct); + } + + SUBCASE("single container, multiple elements") { + std::vector> containers = {{1, 2, 3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1}, {2}, {3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, one element each") { + std::vector> containers = {{1}, {2}, {3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1, 2, 3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, multiple elements") { + std::vector> containers = {{1, 2}, {3, 4}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = { + {1, 3}, {1, 4}, {2, 3}, {2, 4}}; + CHECK(result == correct); + } + + 
SUBCASE("1 empty container, 1 non-empty container") { + std::vector> containers = {{}, {2, 3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_assignments.cc b/lib/utils/test/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..2b2810efe5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_assignments.cc @@ -0,0 +1,50 @@ +#include "utils/containers/get_all_assignments.h" +#include +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_assignments") { + SUBCASE("empty input") { + std::unordered_map> input = {}; + + std::unordered_set> result = get_all_assignments(input); + std::unordered_set> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("non-empty input") { + std::unordered_map> input = { + {"a", {1, 2, 3}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = get_all_assignments(input); + std::unordered_set> correct = { + {{"a", 1}, {"b", 2}}, + {{"a", 1}, {"b", 3}}, + {{"a", 2}, {"b", 2}}, + {{"a", 2}, {"b", 3}}, + {{"a", 3}, {"b", 2}}, + {{"a", 3}, {"b", 3}}, + }; + + CHECK(result == correct); + } + + SUBCASE("one possible-values set is empty") { + std::unordered_map> input = { + {"a", {}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = get_all_assignments(input); + std::unordered_set> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..a87e54ed8e --- /dev/null +++ b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1,53 @@ +#include "utils/containers/unordered_map_from_pairs.h" +#include +#include +#include +#include 
"utils/containers/contains.h" +#include "test/utils/doctest/fmt/unordered_map.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("unordered_map_from_pairs") { + SUBCASE("nonempty input") { + std::vector> input = { + {1, "hello"}, + {3, "world"}, + }; + + std::unordered_map result = unordered_map_from_pairs(input); + std::unordered_map correct = { + {1, "hello"}, + {3, "world"}, + }; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector> input = {}; + + std::unordered_map result = unordered_map_from_pairs(input); + std::unordered_map correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input with duplicate keys") { + std::vector> input = { + {1, "a"}, + {2, "c"}, + {1, "b"}, + }; + + std::unordered_map result = unordered_map_from_pairs(input); + + std::vector> possible_correct_values = { + {{1, "a"}, {2, "c"}}, + {{1, "b"}, {2, "c"}}, + }; + + CHECK(contains(possible_correct_values, result)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..a93f22802c --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,97 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dataflow_edges_from_node_to_node") { + DataflowGraph g = DataflowGraph::create(); + + SUBCASE("gets edges if there are multiple") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o0 = n1_added.outputs.at(0); + DataflowOutput n1_o1 = 
n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o0, n1_o0, n1_o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = { + DataflowEdge{ + n1_o0, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o0, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not get edges to/from other nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n3); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("does not get flipped edges (i.e., respects from vs to direction)") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n2, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if no edges exist between the given nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + + std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if src node == dst node (as cycles cannot exist in DataflowGraph") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = 
n1_added.node; + + std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..bf51bd028e --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,131 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_edges_from_subgraph_to_subgraph") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + SUBCASE("basic tests") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(5)}; + std::unordered_set dst_subgraph = {n.at(2), n.at(3)}; + + SUBCASE("returns all edges between subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(5), n.at(2)}, + }; + + add_edges(g, e); + + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = unordered_set_of(e); + + CHECK(result == correct); + } + + SUBCASE("does not return reverse edges") { + std::vector e = { + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(0)}, + }; + + add_edges(g, e); + + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(0)}; + + CHECK(result == correct); + } + + SUBCASE("does not return edges within subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + }; + + add_edges(g, e); + + 
std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if there are no edges from src_subgraph to dst_subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(2), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("empty subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + SUBCASE("returns no edges if no nodes in src_subgraph") { + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, {}, unordered_set_of(n)); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if no nodes in dst_subgraph") { + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, unordered_set_of(n), {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if both subgraphs are empty") { + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, {}, {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("if subgraphs do not cover graph, then does not return external edges") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set src_subgraph = {n.at(0)}; + std::unordered_set dst_subgraph = {n.at(3)}; + + std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("throws an error if subgraphs are not disjoint") { + std::unordered_set src_subgraph = 
{n.at(0), n.at(1), n.at(2)}; + std::unordered_set dst_subgraph = {n.at(1), n.at(3)}; + CHECK_THROWS(get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph)); + } + } +} From 00c2baebea9290e52e619d27a7c4d0f6461d9685 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 27 Sep 2024 14:32:43 -0700 Subject: [PATCH 14/29] Get tests building again --- .../include/compiler/cost_estimator.h | 60 ----- .../compiler/cost_estimator/cost_estimator.h | 45 ++++ .../op_cost_estimate_key.struct.toml | 0 .../single_tensor_movement.struct.toml} | 15 +- .../tensor_set_movement.struct.toml | 21 ++ .../estimate_cost_across_split.h | 18 -- .../machine_mapping/estimate_layer_cost.h | 2 +- .../get_tensor_set_movement_across_split.h | 18 ++ .../include_unconstrained.struct.toml | 16 ++ .../machine_mapping_context.struct.toml | 2 +- .../machine_mapping/partial_machine_mapping.h | 24 +- .../partial_machine_mapping.struct.toml | 4 +- .../pcg_binary_sp_decomposition.h | 6 +- .../include/compiler/unity_algorithm.h | 2 +- .../compiler/cost_estimator/cost_estimator.cc | 16 ++ .../estimate_cost_across_split.cc | 37 --- .../machine_mapping/estimate_layer_cost.cc | 14 +- .../get_allowed_machine_views_list.cc | 104 ++++---- .../get_optimal_machine_mapping.cc | 56 +++-- .../get_tensor_set_movement_across_split.cc | 56 +++++ .../machine_mapping_context.cc | 5 + .../partial_machine_mapping.cc | 76 ++++-- .../machine_mapping/transitive_reduced_pcg.cc | 14 +- .../pcg_binary_sp_decomposition.cc | 28 +++ lib/compiler/src/graph_optimize_state.cc | 6 +- .../cost_estimator_for_test.cc | 73 ++---- .../machine_mapping/cost_estimator_for_test.h | 46 ++-- .../estimate_cost_across_split.cc | 36 --- .../machine_mapping/estimate_layer_cost.cc | 2 +- .../get_machine_resource_splits.cc | 2 + .../get_optimal_machine_mapping.cc | 20 +- .../get_tensor_set_movement_across_split.cc | 233 ++++++++++++++++++ .../machine_mapping/machine_mapping.cc | 1 + .../machine_mapping/machine_mapping_cache.cc | 83 +++---- 
.../machine_mapping/machine_mapping_result.cc | 9 +- lib/compiler/test/src/graph_optimize_state.cc | 4 +- .../parallel_computation_graph.h | 4 + .../parallel_computation_graph.cc | 20 ++ .../parallel_computation_graph.cc | 43 ++++ lib/utils/include/utils/containers/values.h | 8 +- .../graph/instances/adjacency_multidigraph.cc | 1 + .../algorithms/get_subgraph_inputs.cc | 1 + 42 files changed, 797 insertions(+), 434 deletions(-) delete mode 100644 lib/compiler/include/compiler/cost_estimator.h create mode 100644 lib/compiler/include/compiler/cost_estimator/cost_estimator.h rename lib/compiler/include/compiler/{ => cost_estimator}/op_cost_estimate_key.struct.toml (100%) rename lib/compiler/include/compiler/{comm_cost_estimate_key.struct.toml => cost_estimator/single_tensor_movement.struct.toml} (51%) create mode 100644 lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml create mode 100644 lib/compiler/src/compiler/cost_estimator/cost_estimator.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc delete mode 100644 lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc diff --git a/lib/compiler/include/compiler/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator.h deleted file mode 100644 index 52e82ad8d5..0000000000 --- a/lib/compiler/include/compiler/cost_estimator.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_H -#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_H - -#include -#include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "pcg/machine_view.dtg.h" - -namespace FlexFlow { - -struct ICostEstimator { - virtual float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const = 0; - virtual float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const = 0; - - ICostEstimator() = default; - ICostEstimator(ICostEstimator const &) = delete; - ICostEstimator &operator=(ICostEstimator const &) = delete; - - virtual ~ICostEstimator() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); - -struct CostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const { - return this->implementation_ptr->estimate_cost( - op, inputs, weights, outputs, mv); - } - - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const { - return this->implementation_ptr->estimate_cost(tensor_shape, src, dst); - } - - template - static typename std::enable_if::value, - CostEstimator>::type - create(Args &&...args) { - return CostEstimator(std::make_shared(std::forward(args)...)); - } - -private: - CostEstimator(std::shared_ptr implementation_ptr) - : implementation_ptr(implementation_ptr) {} - std::shared_ptr implementation_ptr; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h new file mode 100644 index 0000000000..7d3aa6bb9f --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -0,0 +1,45 @@ +#ifndef 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H + +#include +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/machine_view.dtg.h" + +namespace FlexFlow { + +struct ICostEstimator { + virtual float estimate_cost(OpCostEstimateKey const &) const = 0; + virtual float estimate_cost(TensorSetMovement const &) const = 0; + + ICostEstimator() = default; + ICostEstimator(ICostEstimator const &) = delete; + ICostEstimator &operator=(ICostEstimator const &) = delete; + + virtual ~ICostEstimator() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); + +struct CostEstimator { + float estimate_cost(OpCostEstimateKey const &k) const; + float estimate_cost(TensorSetMovement const &m) const; + + template + static typename std::enable_if::value, + CostEstimator>::type + create(Args &&...args) { + return CostEstimator(std::make_shared(std::forward(args)...)); + } + +private: + CostEstimator(std::shared_ptr implementation_ptr); + +private: + std::shared_ptr implementation_ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml similarity index 100% rename from lib/compiler/include/compiler/op_cost_estimate_key.struct.toml rename to lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml diff --git a/lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml similarity index 51% rename from lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml rename to 
lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml index ae45b493c3..52f66f3420 100644 --- a/lib/compiler/include/compiler/comm_cost_estimate_key.struct.toml +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml @@ -1,8 +1,7 @@ namespace = "FlexFlow" -name = "CommCostEstimateKey" +name = "SingleTensorMovement" features = [ "eq", - "ord", "hash", "fmt", ] @@ -10,6 +9,12 @@ features = [ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", "pcg/machine_view.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", ] [[fields]] @@ -17,9 +22,9 @@ name = "parallel_tensor_shape" type = "::FlexFlow::ParallelTensorShape" [[fields]] -name = "src_machine_view" -type = "::FlexFlow::MachineView" +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" [[fields]] name = "dst_machine_view" -type = "::FlexFlow::MachineView" +type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml new file mode 100644 index 0000000000..3625605239 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::SingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h b/lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h deleted file mode 100644 index a1f061b15b..0000000000 --- 
a/lib/compiler/include/compiler/machine_mapping/estimate_cost_across_split.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H - -#include "compiler/cost_estimator.h" -#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" - -namespace FlexFlow { - -float estimate_cost_across_split(TransitiveReducedPCG const &, - CostEstimator const &, - std::unordered_map const &, - std::unordered_map const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h index 69370aabda..dcb8856fe8 100644 --- a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h +++ b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_LAYER_COST_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_LAYER_COST_H -#include "compiler/cost_estimator.h" +#include "compiler/cost_estimator/cost_estimator.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..9becde61c3 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H +#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split, + PartialMachineMapping const &pre_mapping, + PartialMachineMapping const &post_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml new file mode 100644 index 0000000000..b9a7f9ac59 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "IncludeUnconstrained" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [] + +[[fields]] +name = "raw_bool" +type = "bool" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index 270d57fe98..272d4c2097 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -3,7 +3,7 @@ name = "MachineMappingContext" features = [] includes = [ - "compiler/cost_estimator.h", + "compiler/cost_estimator/cost_estimator.h", "pcg/machine_view.dtg.h", "pcg/machine_specification.dtg.h", "compiler/machine_mapping/transitive_reduced_pcg.dtg.h", diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h 
b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h index 45be24974f..4ed43b3470 100644 --- a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h @@ -1,25 +1,31 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H +#include "compiler/machine_mapping/machine_mapping.dtg.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" #include "compiler/machine_mapping/partial_machine_mapping.dtg.h" +#include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" namespace FlexFlow { -PartialMachineMapping get_unconstrained_solution(); +PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set const &); -PartialMachineMapping get_sub_solution(MachineMappingContext const &ctx, - PartialMachineMapping const &partial_solution, - PCGBinarySPDecomposition const &sub_problem); +std::unordered_set get_all_layers(PartialMachineMapping const &, + IncludeUnconstrained const &); -PartialMachineMapping with_additional_tensor_machine_views(MachineMappingContext const &ctx, - PartialMachineMapping const &partial_solution, - std::unordered_map const &additional); +std::optional get_machine_view_for_layer(PartialMachineMapping const &, + parallel_layer_guid_t const &); -PartialMachineMapping with_additional_layer_machine_views(MachineMappingContext const &ctx, - PartialMachineMapping const &partial_solution, +PartialMachineMapping get_sub_solution(PartialMachineMapping const &partial_solution, + PCGBinarySPDecomposition const &sub_problem); + +PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping const &partial_solution, std::unordered_map const 
&additional); +MachineMapping require_complete_mapping(PartialMachineMapping const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml index f63c51a4c4..b1955185ad 100644 --- a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml @@ -9,13 +9,15 @@ features = [ includes = [ "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", "pcg/machine_view.dtg.h", + "", ] src_includes = [ "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", ] [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, std::optional<::FlexFlow::MachineView>>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h index a75f1a8116..eca0cd7d0b 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h @@ -17,9 +17,9 @@ std::unordered_multiset SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); -PCGBinarySeriesSplit make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); -PCGBinarySeriesSplit make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); -PCGBinarySeriesSplit make_pcg_leaf_node(parallel_layer_guid_t const &); +PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); +PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition 
const &); +PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &); PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &); PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &); diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 9eeb9fe563..7ac9759650 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #include "compiler/graph_optimize_result.dtg.h" -#include "cost_estimator.h" +#include "compiler/cost_estimator/cost_estimator.h" #include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" diff --git a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc new file mode 100644 index 0000000000..051ffcd190 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc @@ -0,0 +1,16 @@ +#include "compiler/cost_estimator/cost_estimator.h" + +namespace FlexFlow { + +CostEstimator::CostEstimator(std::shared_ptr implementation_ptr) + : implementation_ptr(implementation_ptr) {} + +float CostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->implementation_ptr->estimate_cost(k); +} + +float CostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->implementation_ptr->estimate_cost(m); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc b/lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc deleted file mode 100644 index 8d7d6ccd03..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/estimate_cost_across_split.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "compiler/machine_mapping/estimate_cost_across_split.h" -#include "compiler/machine_mapping/transitive_reduced_pcg.h" 
-#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "utils/containers/keys.h" -#include "utils/containers/sum.h" - -namespace FlexFlow { - -float estimate_cost_across_split(TransitiveReducedPCG const &tr_pcg, - CostEstimator const &cost_estimator, - std::unordered_map const &pre_machine_views, - std::unordered_map const &post_machine_views) { - std::unordered_set - edges_across_split = get_transitive_reduced_edges_across_split(tr_pcg, - keys(pre_machine_views), - keys(post_machine_views)); - - auto get_cost_of_edge = [&](ParallelComputationGraphEdge const &e) { - MachineView src_view = pre_machine_views.at(get_src_layer(e)); - MachineView dst_view = post_machine_views.at(get_dst_layer(e)); - ParallelTensorShape tensor_shape = get_parallel_tensor_shape(tr_pcg.full_pcg, - get_parallel_tensor(e)); - - return cost_estimator.estimate_cost(tensor_shape, src_view, dst_view); - }; - - // note this is only correct for certain split types, and for others (tensor reuse, etc.) this is - // an overapproximation. This should eventually get fixed. 
- return sum(transform(edges_across_split, get_cost_of_edge)); -} - - -} // namespace FlexFlow - - diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc index 02d31ec7f0..1caa31aefc 100644 --- a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc +++ b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -17,11 +17,15 @@ float estimate_layer_cost(ParallelComputationGraph const &pcg, std::vector weight_tensors = get_incoming_weights(pcg, layer); std::vector output_tensors = get_layer_outputs(pcg, layer); - return cost_estimator.estimate_cost(op_attrs, - transform(input_tensors, get_tensor_shape), - transform(weight_tensors, get_tensor_shape), - transform(output_tensors, get_tensor_shape), - machine_view); + OpCostEstimateKey key = OpCostEstimateKey{ + /*op_attrs=*/op_attrs, + /*input_shapes=*/transform(input_tensors, get_tensor_shape), + /*weight_shapes=*/transform(weight_tensors, get_tensor_shape), + /*output_shapes=*/transform(output_tensors, get_tensor_shape), + /*machine_view=*/machine_view, + }; + + return cost_estimator.estimate_cost(key); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc b/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc index 3c80d75289..717aa66a9b 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc @@ -13,35 +13,37 @@ std::vector> MachineMappingContext const &context, std::unordered_set const &layers, MachineSpecification const &resource) { - if (layers.empty()) { - return {{}}; - } - parallel_layer_guid_t curr_layer = get_first(layers); - std::unordered_set other_layers = - set_minus(layers, {curr_layer}); + NOT_IMPLEMENTED(); - std::vector> - other_machine_views_from_recursion = - 
get_allowed_machine_views_list(context, other_layers, resource); - - ParallelLayerAttrs curr_layer_attrs = - get_parallel_layer_attrs(context.pcg, curr_layer); - std::unordered_set allowed_machine_views_for_curr_layer = - context.allowed_machine_views(curr_layer_attrs, resource); - - std::vector> result; - - for (MachineView const &for_curr_node : - allowed_machine_views_for_curr_layer) { - for (std::unordered_map const - &for_other_layers : other_machine_views_from_recursion) { - result.push_back( - merge_maps(for_other_layers, - std::unordered_map{ - {curr_layer, for_curr_node}})); - } - } - return result; + // if (layers.empty()) { + // return {{}}; + // } + // parallel_layer_guid_t curr_layer = get_first(layers); + // std::unordered_set other_layers = + // set_minus(layers, {curr_layer}); + // + // std::vector> + // other_machine_views_from_recursion = + // get_allowed_machine_views_list(context, other_layers, resource); + // + // ParallelLayerAttrs curr_layer_attrs = + // get_parallel_layer_attrs(context.pcg, curr_layer); + // std::unordered_set allowed_machine_views_for_curr_layer = + // context.allowed_machine_views(curr_layer_attrs, resource); + // + // std::vector> result; + // + // for (MachineView const &for_curr_node : + // allowed_machine_views_for_curr_layer) { + // for (std::unordered_map const + // &for_other_layers : other_machine_views_from_recursion) { + // result.push_back( + // merge_maps(for_other_layers, + // std::unordered_map{ + // {curr_layer, for_curr_node}})); + // } + // } + // return result; } std::vector> @@ -49,29 +51,31 @@ std::vector> MachineMappingContext const &context, std::unordered_set const &tensors, MachineSpecification const &resource) { - std::unordered_set layers; - for (parallel_tensor_guid_t const &tensor : tensors) { - layers.insert(get_source_layer(tensor)); - } - - std::vector> - machine_views_for_layers_list = - get_allowed_machine_views_list(context, layers, resource); - - std::vector> result; - - for 
(std::unordered_map - machine_views_for_layers : machine_views_for_layers_list) { - std::unordered_map - machine_views_for_tensors; - for (parallel_tensor_guid_t const &tensor : tensors) { - machine_views_for_tensors.emplace( - tensor, machine_views_for_layers.at(get_source_layer(tensor))); - } - result.push_back(machine_views_for_tensors); - } + NOT_IMPLEMENTED(); - return result; + // std::unordered_set layers; + // for (parallel_tensor_guid_t const &tensor : tensors) { + // layers.insert(get_source_layer(tensor)); + // } + // + // std::vector> + // machine_views_for_layers_list = + // get_allowed_machine_views_list(context, layers, resource); + // + // std::vector> result; + // + // for (std::unordered_map + // machine_views_for_layers : machine_views_for_layers_list) { + // std::unordered_map + // machine_views_for_tensors; + // for (parallel_tensor_guid_t const &tensor : tensors) { + // machine_views_for_tensors.emplace( + // tensor, machine_views_for_layers.at(get_source_layer(tensor))); + // } + // result.push_back(machine_views_for_tensors); + // } + // + // return result; } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index e20e3a1883..5accd84260 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,5 +1,5 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "compiler/cost_estimator.h" +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/get_allowed_machine_views_list.h" #include "compiler/machine_mapping/get_machine_resource_splits.h" #include "compiler/machine_mapping/machine_mapping_result.h" @@ -16,7 +16,6 @@ #include "utils/containers/contains.h" #include "utils/containers/generate_map.h" #include 
"utils/containers/get_all_assignments.h" -#include "utils/containers/keys.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" #include "utils/overload.h" @@ -52,7 +51,7 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, - MachineMappingContext &context, + MachineMappingContext const &context, MachineSpecification const &resources) { PCGBinarySPDecomposition sp_decomposition_tree = ({ @@ -63,16 +62,18 @@ MachineMappingResult get_optimal_machine_mapping_internal( returned.value(); }); + std::unordered_set all_layers = get_parallel_layers(context.transitive_reduced_pcg.full_pcg); + return get_optimal_machine_mapping_internal(result_cache, context, sp_decomposition_tree, resources, - get_unconstrained_solution()); + get_unconstrained_solution_for_layers(all_layers)); } MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, - MachineMappingContext &context, + MachineMappingContext const &context, PCGBinarySPDecomposition const &sp_decomposition_tree, MachineSpecification const &resources, PartialMachineMapping const &partial_solution) { @@ -92,7 +93,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingResult result = visit( sp_decomposition_tree, [&](auto const &decomp_tree_node) { - return get_optimal_machine_mapping_internal(context, decomp_tree_node, resources, partial_solution); + return get_optimal_machine_mapping_internal(result_cache, context, decomp_tree_node, resources, partial_solution); }); result_cache.save(state, result); @@ -102,7 +103,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingResult get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySeriesSplit const &series, + PCGBinarySeriesSplit const &series_split, MachineSpecification const &resource, PartialMachineMapping const 
&partial_solution) { @@ -113,15 +114,15 @@ MachineMappingResult get_optimal_machine_mapping_internal( return !contains(subgraph_nodes, input_tensor.raw_graph_output.node); }; - PCGBinarySPDecomposition pre_sub_tree = get_left_child(series); - PCGBinarySPDecomposition post_sub_tree = get_right_child(series); + PCGBinarySPDecomposition pre_sub_tree = get_left_child(series_split); + PCGBinarySPDecomposition post_sub_tree = get_right_child(series_split); std::pair< std::unordered_set, std::unordered_set > boundary_layers = get_split_transitive_reduced_boundary_layers(context.transitive_reduced_pcg, - series); + series_split); auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) -> std::unordered_set> @@ -139,8 +140,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( PartialMachineMapping pre_candidate = with_additional_layer_machine_views( - context, - get_sub_solution(context, partial_solution, pre_sub_tree), + get_sub_solution(partial_solution, pre_sub_tree), assigned_pre_machine_views); MachineMappingResult pre_result = @@ -156,8 +156,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( PartialMachineMapping post_candidate = with_additional_layer_machine_views( - context, - get_sub_solution(context, partial_solution, post_sub_tree), + get_sub_solution(partial_solution, post_sub_tree), assigned_post_machine_views); MachineMappingResult post_result = @@ -167,9 +166,13 @@ MachineMappingResult get_optimal_machine_mapping_internal( resource, post_candidate); - float cost_across_split = estimate_cost_across_split(context, - assigned_pre_machine_views, - assigned_post_machine_views); + TensorSetMovement comm_across_split = get_tensor_set_movement_across_split( + /*transitive_reduced_pcg=*/context.transitive_reduced_pcg, + /*split=*/series_split, + /*pre_mapping=*/pre_candidate, + /*post_mapping=*/post_candidate); + + float cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); minimize_runtime( 
optimal_result, @@ -190,19 +193,17 @@ MachineMappingResult get_optimal_machine_mapping_internal( PartialMachineMapping const &partial_solution) { PCGBinarySPDecomposition left_subtree = get_left_child(parallel); - PartialMachineMapping left_sub_solution = get_sub_solution(context, - partial_solution, + PartialMachineMapping left_sub_solution = get_sub_solution(partial_solution, left_subtree); PCGBinarySPDecomposition right_subtree = get_right_child(parallel); - PartialMachineMapping right_sub_solution = get_sub_solution(context, - partial_solution, + PartialMachineMapping right_sub_solution = get_sub_solution(partial_solution, right_subtree); MachineMappingResult optimal_result = [&] { - PCGBinarySeriesSplit series = make_pcg_series_split( + PCGBinarySeriesSplit series = require_series(make_pcg_series_split( get_left_child(parallel), - get_right_child(parallel)); + get_right_child(parallel))); return get_optimal_machine_mapping_internal(result_cache, context, series, @@ -233,21 +234,22 @@ MachineMappingResult get_optimal_machine_mapping_internal( } MachineMappingResult get_optimal_machine_mapping_internal( - MachineMappingContext &context, + MachineMappingCache &result_cache, + MachineMappingContext const &context, parallel_layer_guid_t const &layer, MachineSpecification const &resource, PartialMachineMapping const &partial_solution) { - assert (keys(partial_solution.machine_views) == std::unordered_set{layer}); + assert (get_all_layers(partial_solution, IncludeUnconstrained{true}) == std::unordered_set{layer}); float cost = estimate_layer_cost(context.transitive_reduced_pcg.full_pcg, context.cost_estimator, layer, - partial_solution.machine_views.at(layer)); + get_machine_view_for_layer(partial_solution, layer).value()); return MachineMappingResult{ /*runtime=*/cost, - /*machine_mapping=*/MachineMapping{partial_solution.machine_views}, + /*machine_mapping=*/require_complete_mapping(partial_solution), }; } diff --git 
a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..38f2bf2344 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,56 @@ +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/sum.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split, + PartialMachineMapping const &pre_mapping, + PartialMachineMapping const &post_mapping) { + std::unordered_set + edges_across_split = get_transitive_reduced_edges_across_split(tr_pcg, split); + + auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { + std::unordered_set tensor_edges = filter(edges_across_split, + [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); + + std::unordered_set src_machine_views = + transform(tensor_edges, + [&](ParallelComputationGraphEdge const &e) { + return get_machine_view_for_layer(pre_mapping, get_src_layer(e)).value(); + }); + + std::unordered_set dst_machine_views = + transform(tensor_edges, + [&](ParallelComputationGraphEdge const &e) { + return get_machine_view_for_layer(post_mapping, get_dst_layer(e)).value(); + }); + + return SingleTensorMovement{ + /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + 
/*src_machine_views=*/src_machine_views, + /*dst_machine_views=*/dst_machine_views, + }; + }; + + std::unordered_map single_tensor_movements = + generate_map(get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); + + return TensorSetMovement{ + values(single_tensor_movements), + }; +} + + +} // namespace FlexFlow + + diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc index 36e12ff5eb..c45e964a3a 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc @@ -11,6 +11,11 @@ std::unordered_set get_allowed_machine_views_for_tensor(MachineMapp NOT_IMPLEMENTED(); } +std::unordered_set get_allowed_machine_views_for_layer(MachineMappingContext const &, + parallel_layer_guid_t const &) { + NOT_IMPLEMENTED(); +} + MachineMappingContext make_machine_mapping_context(ParallelComputationGraph const &pcg, CostEstimator const &cost_estimator, std::function( diff --git a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc index 09b59c75cd..5ae3126184 100644 --- a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc @@ -1,30 +1,78 @@ #include "compiler/machine_mapping/partial_machine_mapping.h" #include "compiler/machine_mapping/machine_mapping_context.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" #include "utils/containers/restrict_keys.h" +#include "utils/containers/map_values.h" namespace FlexFlow { -PartialMachineMapping get_sub_solution(MachineMappingContext const &ctx, - 
PartialMachineMapping const &partial_solution, - PCGBinarySPDecomposition const &sub_problem) { - std::unordered_set sub_solution_layers = - flatmap(get_parallel_layers(sub_problem), - [&](parallel_layer_guid_t l) { - return set_union( - get_transitively_reduced_predecessors(ctx, l), - get_transitively_reduced_successors(ctx, l)); - }); +PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set const &layers) { + return PartialMachineMapping{ + generate_map(layers, + [](parallel_layer_guid_t const &) -> std::optional { + return std::nullopt; + }), + }; +} + +std::unordered_set get_all_layers(PartialMachineMapping const &partial_solution, + IncludeUnconstrained const &include_unconstrained) { + std::unordered_set with_unconstrained = keys(partial_solution.machine_views); + + if (include_unconstrained.raw_bool) { + return with_unconstrained; + } else { + return filter(with_unconstrained, + [&](parallel_layer_guid_t const &l) { return partial_solution.machine_views.at(l).has_value(); }); + } +} + +std::optional get_machine_view_for_layer(PartialMachineMapping const &partial_solution, + parallel_layer_guid_t const &layer) { + return partial_solution.machine_views.at(layer); +} + +PartialMachineMapping get_sub_solution(PartialMachineMapping const &partial_solution, + PCGBinarySPDecomposition const &sub_problem) { + + std::unordered_set sub_problem_layers = unordered_set_of(get_parallel_layers(sub_problem)); return PartialMachineMapping{ - restrict_keys(partial_solution.machine_views, sub_solution_layers), + restrict_keys(partial_solution.machine_views, sub_problem_layers), }; } -MachineMapping require_complete(MachineMappingContext const &ctx, - PartialMachineMapping const &partial_solution) { - NOT_IMPLEMENTED(); +PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping const &partial_solution, + std::unordered_map const &additional) { + PartialMachineMapping result = partial_solution; + + for (auto const &[layer, 
machine_view] : additional) { + std::optional current_machine_view = result.machine_views.at(layer); + + if (!current_machine_view.has_value()) { + result.machine_views.at(layer) = machine_view; + } else { + if (current_machine_view.value() != machine_view) { + throw mk_runtime_error(fmt::format("with_additional_layer_machine_views received machine view assignment for layer {} " + "to machine view {}, but that layer is already assigned to machine view {}.", + layer, machine_view, current_machine_view.value())); + } + } + } + + return result; +} + + +MachineMapping require_complete_mapping(PartialMachineMapping const &partial_mapping) { + return MachineMapping{ + map_values(partial_mapping.machine_views, + [](std::optional const &mv) { return mv.value(); }), + }; } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index 9c1d5bb05e..fe7b05e7b2 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -13,11 +13,11 @@ namespace FlexFlow { TransitiveReducedPCG get_pcg_transitive_reduction(ParallelComputationGraph const &pcg) { DiGraphView raw_digraph = pcg.raw_graph; - DiGraphView transitively_reduced = transitive_reduction(raw_digraph); + DiGraphView transitive_reduced = transitive_reduction(raw_digraph); return TransitiveReducedPCG{ /*pcg=*/pcg, - /*transitive_reduction=*/transitively_reduced, + /*transitive_reduction=*/transitive_reduced, }; } @@ -34,7 +34,7 @@ std::unordered_set get_transitive_reduced_successors(Tran } std::unordered_set - get_transitively_reduced_edges_across_split(TransitiveReducedPCG const &tr_pcg, + get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { std::unordered_set src_subgraph = unordered_set_of(get_parallel_layers(get_left_child(split))); 
std::unordered_set dst_subgraph = unordered_set_of(get_parallel_layers(get_right_child(split))); @@ -55,16 +55,16 @@ std::unordered_set } std::unordered_set - get_transitively_reduced_tensors_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { - return transform(get_transitively_reduced_edges_across_split(tr_pcg, split), + get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + return transform(get_transitive_reduced_edges_across_split(tr_pcg, split), [](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e); }); } std::pair< std::unordered_set, std::unordered_set -> get_split_transitively_reduced_boundary_layers(TransitiveReducedPCG const &tr_pcg, +> get_split_transitive_reduced_boundary_layers(TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { std::unordered_set edges = get_transitive_reduced_edges_across_split(tr_pcg, split); diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index be9ccd3ab2..5c5233a494 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -12,4 +12,32 @@ std::unordered_multiset NOT_IMPLEMENTED(); } +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + +PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + +PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + +PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &) { + NOT_IMPLEMENTED(); +} + +PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + 
+PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + +parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc index 71dd7f0ec1..4b4f323ea4 100644 --- a/lib/compiler/src/graph_optimize_state.cc +++ b/lib/compiler/src/graph_optimize_state.cc @@ -22,9 +22,9 @@ bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { return false; } - auto inputs1 = get_layer_inputs(graph_optimize_result.pcg, layers1[i]); + auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); auto inputs2 = - get_layer_inputs(other.graph_optimize_result.pcg, layers2[i]); + get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); if (inputs1.size() != inputs2.size()) { return false; } @@ -68,7 +68,7 @@ size_t hash<::FlexFlow::GraphOptimizeState>::operator()( for (auto layer : layers) { ::FlexFlow::hash_combine( seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); - auto inputs = get_layer_inputs(state.graph_optimize_result.pcg, layer); + auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); ::FlexFlow::hash_combine(seed, inputs.size()); for (auto input : inputs) { for (size_t i = 0; i < layers.size(); ++i) { diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc index c3e7a8f3bf..5a7f56eb79 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -3,75 +3,36 @@ namespace FlexFlow { TestCostEstimator::TestCostEstimator( - std::function const &, - std::vector const &, - std::vector const &, - MachineView const &)> 
const &get_operator_cost, - std::function const &get_communication_cost) + std::function const &get_operator_cost, + std::function const & get_communication_cost) : get_operator_cost(get_operator_cost), get_communication_cost(get_communication_cost) { } -float TestCostEstimator::estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const { - return this->get_operator_cost(op, inputs, weights, outputs, mv); +float TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->get_operator_cost(k); } -float TestCostEstimator::estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const { - return this->get_communication_cost(tensor_shape, src, dst); +float TestCostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->get_communication_cost(m); } -CostEstimator make_cost_estimator( - std::function const &, - std::vector const &, - std::vector const &, - MachineView const &)> const &get_operator_cost, - std::function const &get_communication_cost) { +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const &get_communication_cost) { + return CostEstimator::create(get_operator_cost, get_communication_cost); } -CostEstimator make_cost_estimator( +CostEstimator make_fake_cost_estimator( std::unordered_map const &op_cost_map, - std::unordered_map const &comm_cost_map) { - return make_cost_estimator( - [op_cost_map](PCGOperatorAttrs const &op_attrs, - std::vector const &input_shapes, - std::vector const &weight_shapes, - std::vector const &output_shapes, - MachineView const &machine_view) { - - OpCostEstimateKey key = OpCostEstimateKey{ - /*op_attrs=*/op_attrs, - /*input_shapes=*/input_shapes, - /*weight_shapes=*/weight_shapes, - /*output_shapes=*/output_shapes, - /*machine_view=*/machine_view, - }; - - return op_cost_map.at(key); + 
std::unordered_map const &comm_cost_map) { + return make_fake_cost_estimator( + [op_cost_map](OpCostEstimateKey const &k) { + return op_cost_map.at(k); }, - [comm_cost_map](ParallelTensorShape const ¶llel_tensor_shape, - MachineView const &src_machine_view, - MachineView const &dst_machine_view) { - - CommCostEstimateKey key = CommCostEstimateKey{ - /*parallel_tensor_shape=*/parallel_tensor_shape, - /*src_machine_view=*/src_machine_view, - /*dst_machine_view=*/dst_machine_view, - }; - - return comm_cost_map.at(key); + [comm_cost_map](TensorSetMovement const &m) { + return comm_cost_map.at(m); }); } diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index bfb3f6d8eb..2fa9e6028f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -1,50 +1,32 @@ #ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H #define _FLEXFLOW_TEST_COST_ESTIMATOR_H -#include "compiler/cost_estimator.h" -#include "compiler/op_cost_estimate_key.dtg.h" -#include "compiler/comm_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" namespace FlexFlow { struct TestCostEstimator : public ICostEstimator { - std::function const &, - std::vector const &, - std::vector const &, - MachineView const &)> get_operator_cost; - std::function get_communication_cost; + std::function get_operator_cost; + std::function get_communication_cost; TestCostEstimator() = delete; TestCostEstimator(decltype(get_operator_cost) const &get_operator_cost, decltype(get_communication_cost) const &get_communication_cost); - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const 
&mv) const override; + float estimate_cost(OpCostEstimateKey const &) const override; - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override; + float estimate_cost(TensorSetMovement const &) const override; }; -CostEstimator make_cost_estimator( - std::function const &, - std::vector const &, - std::vector const &, - MachineView const &)> const &get_operator_cost, - std::function const &get_communication_cost); - -CostEstimator make_cost_estimator( - std::unordered_map const &, - std::unordered_map const &); +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const &get_communication_cost); + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map); } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc deleted file mode 100644 index 49200fdd50..0000000000 --- a/lib/compiler/test/src/compiler/machine_mapping/estimate_cost_across_split.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "compiler/machine_mapping/estimate_cost_across_split.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("estimate_cost_across_split") { - SUBCASE("single edge across split") { - SUBCASE("src and dst layers have same MachineView") { - FAIL("TODO"); - } - - SUBCASE("src and dst layers have different MachineViews") { - FAIL("TODO"); - } - } - - SUBCASE("single tensor, multiple consumers across split") { - SUBCASE("consumers have same view") { - FAIL("TODO"); - } - - SUBCASE("consumers have non-overlapping views") { - FAIL("TODO"); - } - - SUBCASE("consumers have different but overlapping views") { - FAIL("TODO"); - } - } - - SUBCASE("multiple tensors, multiple consumers across split") { - FAIL("TODO"); - } - } -} diff --git 
a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc index 1daa6aa272..00a99bfbf8 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -95,7 +95,7 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); - CostEstimator cost_estimator = make_cost_estimator( + CostEstimator cost_estimator = make_fake_cost_estimator( { { OpCostEstimateKey{ diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc index af2814bca0..1c4aee109a 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -1,6 +1,8 @@ #include "compiler/machine_mapping/get_machine_resource_splits.h" #include #include "utils/hash/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/pair.h" using namespace ::FlexFlow; diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 16e7b46b09..3c4ac1174c 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,5 +1,5 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "cost_estimator_for_test.h" +#include "./cost_estimator_for_test.h" #include #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" @@ -14,7 +14,6 @@ TEST_SUITE(FF_TEST_SUITE) { return std::unordered_set{ make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; - 
CostEstimator estimator1 = CostEstimator::create(); MachineSpecification machine_spec = MachineSpecification{ /*num_nodes=*/2, /*num_cpus_per_node=*/1, @@ -23,6 +22,9 @@ TEST_SUITE(FF_TEST_SUITE) { /*intra_node_bandwidth=*/1, }; + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{}, + std::unordered_map{}); SUBCASE("single layer") { ParallelComputationGraph pcg = empty_parallel_computation_graph(); @@ -34,20 +36,6 @@ TEST_SUITE(FF_TEST_SUITE) { return std::unordered_set{mv1}; }; - CostEstimator cost_estimator = make_cost_estimator( - [&](PCGOperatorAttrs const &, - std::vector const &, - std::vector const &, - std::vector const &, - MachineView const &) { - return 1.0; - }, - [&](ParallelTensorShape const &, - MachineView const &, - MachineView const &) { - return 0.5; - }); - ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ PCGOperatorAttrs{ InputAttrs{}, diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..b77325fe86 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,233 @@ +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include +#include "./cost_estimator_for_test.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("estimate_cost_across_split") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = + ParallelTensorShape{ + 
ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAttrs relu_attrs + = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + ParallelLayerAddedResult relu_1 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + ParallelLayerAddedResult relu_2 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(relu_1.outputs)}, + {relu_output_attrs}); + + MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); + MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); + MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); + MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); + + SUBCASE("single edge across split") { + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_leaf_node(relu_2.parallel_layer))); + + PartialMachineMapping pre_mapping = PartialMachineMapping{{ + {relu_1.parallel_layer, pre_mv1}, + }}; + + PartialMachineMapping post_mapping = PartialMachineMapping{{ + {relu_2.parallel_layer, post_mv1}, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + 
/*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") { + + } + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult relu_3 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(relu_1.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_parallel_split( + make_pcg_leaf_node(relu_2.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer)))); + + SUBCASE("consumers have same view") { + PartialMachineMapping pre_mapping = PartialMachineMapping{{ + {relu_1.parallel_layer, pre_mv1}, + }}; + + PartialMachineMapping post_mapping = PartialMachineMapping{{ + {relu_2.parallel_layer, post_mv1}, + {relu_3.parallel_layer, post_mv1}, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("consumers have different views") { + PartialMachineMapping pre_mapping = PartialMachineMapping{{ + {relu_1.parallel_layer, pre_mv1}, + }}; + + PartialMachineMapping post_mapping = PartialMachineMapping{{ + {relu_2.parallel_layer, post_mv1}, + {relu_3.parallel_layer, post_mv2}, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + 
/*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, + }, + }, + }; + + CHECK(result == correct); + } + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult relu_3 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + + ParallelLayerAddedResult relu_4 + = add_parallel_layer(pcg, + relu_attrs, + // relu's don't have two inputs, but for the purposes of this test it's fine. + {get_only(relu_1.outputs), get_only(relu_3.outputs)}, + {relu_output_attrs}); + + PartialMachineMapping pre_mapping = PartialMachineMapping{{ + {relu_1.parallel_layer, pre_mv1}, + {relu_3.parallel_layer, pre_mv2}, + }}; + + PartialMachineMapping post_mapping = PartialMachineMapping{{ + {relu_2.parallel_layer, post_mv1}, + {relu_4.parallel_layer, post_mv2}, + }}; + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_parallel_split( + make_pcg_leaf_node(relu_1.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer))), + make_pcg_parallel_split( + make_pcg_leaf_node(relu_2.parallel_layer), + make_pcg_leaf_node(relu_4.parallel_layer)))); + + TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc index 
ffd20c429a..6b16a54c1f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/machine_mapping.h" #include "cost_estimator_for_test.h" #include "doctest/doctest.h" +#include "pcg/machine_view.h" using namespace FlexFlow; diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc index 9c09efe3fa..fc521e110c 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -1,9 +1,7 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" -#include "compiler/machine_mapping/split_sp_decomposition.h" -#include "cost_estimator_for_test.h" +#include "./cost_estimator_for_test.h" #include "doctest/doctest.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" using namespace FlexFlow; @@ -49,44 +47,45 @@ TEST_SUITE(FF_TEST_SUITE) { return builder.pcg; }(); - SerialParallelDecomposition subgraph0 = - get_serial_parallel_decomposition(pcg.raw_graph).value(); - auto [subgraph1, subgraph2] = - split_sp_decomposition(subgraph0.get()); - - MachineSpecification machine_spec(1, 1, 1, 1, 1); - MachineMappingState state0(subgraph0, machine_spec, {}); - MachineMappingState state1(subgraph1, machine_spec, {}); - MachineMappingState state2(subgraph2, machine_spec, {}); - - MachineMappingResult result0( - 2, - MachineMapping( - std::unordered_map{})); - MachineMappingResult result1( - 1, - MachineMapping( - std::unordered_map{})); - MachineMappingResult result2( - 1, - MachineMapping( - std::unordered_map{})); - - MachineMappingCache cache; - - cache.save(state0, result0); - CHECK(cache.load(state0).value() == result0); - 
CHECK(!cache.load(state1)); - CHECK(!cache.load(state2)); - - cache.save(state1, result1); - CHECK(cache.load(state0).value() == result0); - CHECK(cache.load(state1).value() == result1); - CHECK(!cache.load(state2)); - - cache.save(state2, result2); - CHECK(cache.load(state0).value() == result0); - CHECK(cache.load(state1).value() == result1); - CHECK(cache.load(state2).value() == result2); + FAIL("TODO"); + // SerialParallelDecomposition subgraph0 = + // get_serial_parallel_decomposition(pcg.raw_graph).value(); + // auto [subgraph1, subgraph2] = + // split_sp_decomposition(subgraph0.get()); + // + // MachineSpecification machine_spec(1, 1, 1, 1, 1); + // MachineMappingState state0(subgraph0, machine_spec, {}); + // MachineMappingState state1(subgraph1, machine_spec, {}); + // MachineMappingState state2(subgraph2, machine_spec, {}); + // + // MachineMappingResult result0( + // 2, + // MachineMapping( + // std::unordered_map{})); + // MachineMappingResult result1( + // 1, + // MachineMapping( + // std::unordered_map{})); + // MachineMappingResult result2( + // 1, + // MachineMapping( + // std::unordered_map{})); + // + // MachineMappingCache cache; + // + // cache.save(state0, result0); + // CHECK(cache.load(state0).value() == result0); + // CHECK(!cache.load(state1)); + // CHECK(!cache.load(state2)); + // + // cache.save(state1, result1); + // CHECK(cache.load(state0).value() == result0); + // CHECK(cache.load(state1).value() == result1); + // CHECK(!cache.load(state2)); + // + // cache.save(state2, result2); + // CHECK(cache.load(state0).value() == result0); + // CHECK(cache.load(state1).value() == result1); + // CHECK(cache.load(state2).value() == result2); } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index 0157f73ef3..ba06265cec 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc +++ 
b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/machine_mapping_result.h" #include "cost_estimator_for_test.h" #include "doctest/doctest.h" +#include "pcg/machine_view.h" using namespace FlexFlow; @@ -20,15 +21,17 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingResult s1(1, machine_mapping_0); MachineMappingResult s2(2, machine_mapping_1); - MachineMappingResult result0 = sequential_combine(s0, s1); + float comm_cost = 2.0; + + MachineMappingResult result0 = sequential_combine(s0, comm_cost, s1); CHECK(result0.runtime == 1); CHECK(result0.machine_mapping == machine_mapping_0); - MachineMappingResult result1 = sequential_combine(s0, s2); + MachineMappingResult result1 = sequential_combine(s0, comm_cost, s2); CHECK(result1.runtime == 2); CHECK(result1.machine_mapping == machine_mapping_1); - MachineMappingResult result2 = sequential_combine(s1, s2); + MachineMappingResult result2 = sequential_combine(s1, comm_cost, s2); CHECK(result2.runtime == 3); CHECK(result2.machine_mapping == combined); } diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc index fa8385e560..46177ad420 100644 --- a/lib/compiler/test/src/graph_optimize_state.cc +++ b/lib/compiler/test/src/graph_optimize_state.cc @@ -22,7 +22,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT}; parallel_tensor_guid_t input0 = - builder.create_input_tensor(input_shape, true, "input0"); + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); parallel_tensor_guid_t dense0 = builder.dense(input0, 8, Activation::RELU, @@ -58,7 +58,7 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraphBuilder builder_; parallel_tensor_guid_t input0_ = - builder.create_input_tensor(input_shape, true, "input0"); + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); parallel_tensor_guid_t dense0_ = builder.dense(input0, 8, Activation::RELU, diff --git 
a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 1239e75ce1..83a6504ecc 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -22,6 +22,10 @@ ParallelLayerAddedResult std::vector const &inputs, std::vector const &output_labels); +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape); + std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, parallel_layer_guid_t const &, parallel_layer_guid_t const &); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 7d9c217b25..d2987abc46 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -43,6 +43,26 @@ ParallelLayerAddedResult }; } +ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape) { + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + return add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); +} + std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &src, parallel_layer_guid_t const &dst) { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc 
b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index b88fe38042..72d062e61d 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -265,4 +265,47 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("pcg_add_input_layer") { + ParallelTensorShape tensor_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelComputationGraph result = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + pcg_add_input_layer(pcg, tensor_shape); + return pcg; + }(); + + ParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + return pcg; + }(); + + FAIL("TODO"); + // CHECK(pcgs_are_isomorphic(result, correct)); + } } diff --git a/lib/utils/include/utils/containers/values.h b/lib/utils/include/utils/containers/values.h index 7c487d1d43..2a730ccc42 100644 --- a/lib/utils/include/utils/containers/values.h +++ b/lib/utils/include/utils/containers/values.h @@ -1,15 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H -#include +#include namespace FlexFlow { template -std::vector values(C const &c) { - std::vector result; +std::unordered_multiset values(C const &c) { + std::unordered_multiset result; for (auto const &kv : c) { - result.push_back(kv.second); + 
result.insert(kv.second); } return result; } diff --git a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc index 4f4c846433..941c8e8e3e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc @@ -6,6 +6,7 @@ #include "utils/containers/values.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index 4ade34941c..b07423a21a 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -6,6 +6,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/overload.h" +#include "utils/hash/vector.h" namespace FlexFlow { From 7e73162c857cd949feaf8c180c7ca6f5cc0b2246 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 27 Sep 2024 19:18:00 -0700 Subject: [PATCH 15/29] Get all the new testcases working --- .../pcg_split_boundary_layers.struct.toml | 24 ++++ .../machine_mapping/transitive_reduced_pcg.h | 18 +-- .../series_parallel/pcg_binary_series_split.h | 3 + .../get_machine_resource_splits.cc | 17 ++- .../get_optimal_machine_mapping.cc | 13 +- .../get_tensor_set_movement_across_split.cc | 4 +- .../machine_mapping/transitive_reduced_pcg.cc | 84 +++++------ ...mputation_graph_binary_sp_decomposition.cc | 2 +- .../pcg_binary_series_split.cc | 8 ++ .../pcg_binary_sp_decomposition.cc | 46 +++--- .../machine_mapping/estimate_layer_cost.cc | 2 +- .../get_tensor_set_movement_across_split.cc | 16 +-- 
.../machine_mapping/transitive_reduced_pcg.cc | 26 ---- lib/pcg/include/pcg/computation_graph.h | 5 + .../parallel_computation_graph.h | 5 + lib/pcg/src/pcg/computation_graph.cc | 26 ++++ .../parallel_computation_graph.cc | 25 ++++ .../parallel_computation_graph.cc | 8 +- .../sub_parallel_computation_graph.cc | 8 +- lib/utils/include/utils/containers.h | 3 - ...nsitive_reduced_boundary_nodes_for_split.h | 15 ++ ...et_transitive_reduced_edges_across_split.h | 15 ++ ..._transitive_reduced_outputs_across_split.h | 15 ++ .../split_boundary_nodes.struct.toml | 25 ++++ .../transitive_reduced_dataflow_graph.h | 12 ++ ...nsitive_reduced_dataflow_graph.struct.toml | 17 +++ ...zy_copy_of_labelled_dataflow_graph_view.h} | 14 +- .../algorithms/rewrite_node_labels.h | 22 +++ .../binary_parallel_split.h | 14 ++ .../binary_parallel_split.struct.toml | 22 +++ .../binary_series_split.h | 14 ++ .../binary_series_split.struct.toml | 22 +++ .../binary_sp_decomposition_tree.h | 6 + .../require.h | 2 +- .../transform.h | 28 ++-- ...sitive_reduced_boundary_nodes_for_split.cc | 24 ++++ ...t_transitive_reduced_edges_across_split.cc | 30 ++++ ...transitive_reduced_outputs_across_split.cc | 14 ++ .../transitive_reduced_dataflow_graph.cc | 16 +++ ...zy_copy_of_labelled_dataflow_graph_view.cc | 1 + .../algorithms/rewrite_node_labels.cc | 1 + .../binary_parallel_split.cc | 19 +++ .../binary_series_split.cc | 19 +++ .../binary_sp_decomposition_tree.cc | 17 +++ ...sitive_reduced_boundary_nodes_for_split.cc | 50 +++++++ ...t_transitive_reduced_edges_across_split.cc | 136 ++++++++++++++++++ ...transitive_reduced_outputs_across_split.cc | 47 ++++++ .../get_edges_from_subgraph_to_subgraph.cc | 4 +- 48 files changed, 815 insertions(+), 149 deletions(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml delete mode 100644 lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc create mode 100644 
lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml rename lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/{create_lazy_copy_of_labelled_open_dataflow_graph_view.h => create_lazy_copy_of_labelled_dataflow_graph_view.h} (88%) create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc create mode 100644 
lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc diff --git a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml new file mode 100644 index 0000000000..155e526672 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "PCGSplitBoundaryLayers" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "post_split_boundary" +type = 
"std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h index fcd3b47204..3545c4fa63 100644 --- a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -2,32 +2,28 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h" #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" namespace FlexFlow { -TransitiveReducedPCG get_pcg_transitive_reduction(ParallelComputationGraph const &); +TransitiveReducedDataflowGraphView get_underlying_transitive_reduced_dataflow_graph(TransitiveReducedPCG const &); -std::unordered_set get_transitive_reduced_predecessors(TransitiveReducedPCG const &, - parallel_layer_guid_t const &); -std::unordered_set get_transitive_reduced_successors(TransitiveReducedPCG const &, - parallel_layer_guid_t const &); +TransitiveReducedPCG pcg_get_transitive_reduction(ParallelComputationGraph const &); std::unordered_set - get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, + pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); std::unordered_set - get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &, + pcg_get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &, PCGBinarySeriesSplit 
const &); -std::pair< - std::unordered_set, - std::unordered_set -> get_split_transitive_reduced_boundary_layers(TransitiveReducedPCG const &, +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split(TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h index 196e0e502c..386bfee4f4 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h @@ -3,9 +3,12 @@ #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" namespace FlexFlow { +BinarySeriesSplit get_raw_graph_series_split(PCGBinarySeriesSplit const &); + PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &); PCGBinarySPDecomposition get_right_child(PCGBinarySeriesSplit const &); diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc index c77d53a928..51ed1f7ff4 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -6,12 +6,25 @@ namespace FlexFlow { std::unordered_set> get_machine_resource_splits(MachineSpecification const &resource) { std::unordered_set> result; - for (int i = 1; i < resource.num_nodes; ++i) { - MachineSpecification sub_resource1 = resource, sub_resource2 = resource; + + for (int i = 1; i < resource.num_nodes; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; sub_resource1.num_nodes = i; sub_resource2.num_nodes = resource.num_nodes - i; 
result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); } + + for (int i = 1; i < resource.num_gpus_per_node; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_gpus_per_node = i; + sub_resource2.num_gpus_per_node = resource.num_gpus_per_node - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + return result; } diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 5accd84260..b731913627 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -117,12 +117,9 @@ MachineMappingResult get_optimal_machine_mapping_internal( PCGBinarySPDecomposition pre_sub_tree = get_left_child(series_split); PCGBinarySPDecomposition post_sub_tree = get_right_child(series_split); - std::pair< - std::unordered_set, - std::unordered_set - > boundary_layers = - get_split_transitive_reduced_boundary_layers(context.transitive_reduced_pcg, - series_split); + PCGSplitBoundaryLayers boundary_layers = + pcg_get_transitive_reduced_boundary_layers_for_split(context.transitive_reduced_pcg, + series_split); auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) -> std::unordered_set> @@ -136,7 +133,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( }; for (std::unordered_map const &assigned_pre_machine_views - : get_boundary_machine_view_assignments(boundary_layers.first)) { + : get_boundary_machine_view_assignments(boundary_layers.pre_split_boundary)) { PartialMachineMapping pre_candidate = with_additional_layer_machine_views( @@ -152,7 +149,7 @@ MachineMappingResult get_optimal_machine_mapping_internal( for 
(std::unordered_map const &assigned_post_machine_views - : get_boundary_machine_view_assignments(boundary_layers.second)) { + : get_boundary_machine_view_assignments(boundary_layers.post_split_boundary)) { PartialMachineMapping post_candidate = with_additional_layer_machine_views( diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 38f2bf2344..8c84e227a7 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -16,7 +16,7 @@ TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG cons PartialMachineMapping const &pre_mapping, PartialMachineMapping const &post_mapping) { std::unordered_set - edges_across_split = get_transitive_reduced_edges_across_split(tr_pcg, split); + edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { std::unordered_set tensor_edges = filter(edges_across_split, @@ -42,7 +42,7 @@ TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG cons }; std::unordered_map single_tensor_movements = - generate_map(get_transitive_reduced_tensors_across_split(tr_pcg, split), + generate_map(pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), get_movement_for_tensor); return TensorSetMovement{ diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index fe7b05e7b2..ccb6ae2eed 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -4,14 +4,24 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include 
"pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" #include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" namespace FlexFlow { -TransitiveReducedPCG get_pcg_transitive_reduction(ParallelComputationGraph const &pcg) { +TransitiveReducedDataflowGraphView get_underlying_transitive_reduced_dataflow_graph(TransitiveReducedPCG const &tr_pcg) { + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, + /*transitive_reduction=*/tr_pcg.transitive_reduction, + }; +} + +TransitiveReducedPCG pcg_get_transitive_reduction(ParallelComputationGraph const &pcg) { DiGraphView raw_digraph = pcg.raw_graph; DiGraphView transitive_reduced = transitive_reduction(raw_digraph); @@ -21,62 +31,44 @@ TransitiveReducedPCG get_pcg_transitive_reduction(ParallelComputationGraph const }; } -std::unordered_set get_transitive_reduced_predecessors(TransitiveReducedPCG const &tr_pcg, - parallel_layer_guid_t const &layer) { - std::unordered_set raw_predecessors = get_predecessors(tr_pcg.transitive_reduction, layer.raw_graph_node); - return transform(raw_predecessors, [](Node const &n) { return parallel_layer_guid_t{n}; }); -} +std::unordered_set + pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { -std::unordered_set 
get_transitive_reduced_successors(TransitiveReducedPCG const &tr_pcg, - parallel_layer_guid_t const &layer) { - std::unordered_set raw_successors = get_successors(tr_pcg.transitive_reduction, layer.raw_graph_node); - return transform(raw_successors, [](Node const &n) { return parallel_layer_guid_t{n}; }); -} + TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); -std::unordered_set - get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { - std::unordered_set src_subgraph = unordered_set_of(get_parallel_layers(get_left_child(split))); - std::unordered_set dst_subgraph = unordered_set_of(get_parallel_layers(get_right_child(split))); - - std::unordered_set raw_src_subgraph = transform(src_subgraph, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }); - std::unordered_set raw_dst_subgraph = transform(dst_subgraph, [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }); - - std::unordered_set raw_edges = get_edges_from_subgraph_to_subgraph(tr_pcg.transitive_reduction, - raw_src_subgraph, - raw_dst_subgraph); - - return flatmap(raw_edges, - [&](DirectedEdge const &e) { - return get_pcg_edges_from_layer_to_layer(tr_pcg.full_pcg, - parallel_layer_guid_t{e.src}, - parallel_layer_guid_t{e.dst}); - }); + BinarySeriesSplit raw_split = get_raw_graph_series_split(split); + + std::unordered_set raw_edges = get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); + + return transform(raw_edges, + [](DataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); } std::unordered_set - get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &tr_pcg, + pcg_get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { - return transform(get_transitive_reduced_edges_across_split(tr_pcg, split), - [](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e); 
}); + TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = get_raw_graph_series_split(split); + + std::unordered_set raw_outputs = get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); + + return transform(raw_outputs, + [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } -std::pair< - std::unordered_set, - std::unordered_set -> get_split_transitive_reduced_boundary_layers(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { - std::unordered_set edges = get_transitive_reduced_edges_across_split(tr_pcg, split); +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - std::unordered_set src_boundary_layers = transform(edges, - [](ParallelComputationGraphEdge const &e) { return get_src_layer(e); }); + BinarySeriesSplit raw_split = get_raw_graph_series_split(split); - std::unordered_set dst_boundary_layers = transform(edges, - [](ParallelComputationGraphEdge const &e) { return get_dst_layer(e); }); + SplitBoundaryNodes raw_boundary = get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); - return { - src_boundary_layers, - dst_boundary_layers, + return PCGSplitBoundaryLayers{ + /*pre_split_boundary=*/transform(raw_boundary.pre_split_boundary, [](Node const &n) { return parallel_layer_guid_t{n}; }), + /*post_split_boundary=*/transform(raw_boundary.post_split_boundary, [](Node const &n) { return parallel_layer_guid_t{n}; }), }; } diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc index 63054385ac..63d1231ae7 100644 --- 
a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -33,7 +33,7 @@ ComputationGraphBinarySPDecomposition } layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { - return require_node(d.raw_tree); + return require_leaf(d.raw_tree); } std::optional diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc index 636671c8fa..efa919d5b9 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc @@ -1,9 +1,17 @@ #include "compiler/series_parallel/pcg_binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" namespace FlexFlow { +BinarySeriesSplit get_raw_graph_series_split(PCGBinarySeriesSplit const &s) { + return BinarySeriesSplit{ + transform(s.raw_split, + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), + }; +} + PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &s) { return PCGBinarySPDecomposition{ get_left_child(s.raw_split), diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index 5c5233a494..bdd68da600 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -1,4 +1,8 @@ #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" namespace FlexFlow { @@ -8,36 +12,46 @@ std::optional } std::unordered_multiset - get_parallel_layers(PCGBinarySPDecomposition const &) { - NOT_IMPLEMENTED(); + get_parallel_layers(PCGBinarySPDecomposition const &d) { + return get_leaves(d.raw_tree); } -SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &) { - NOT_IMPLEMENTED(); +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &d) { + return get_node_type(d.raw_tree); } -PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &) { - NOT_IMPLEMENTED(); +PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), + }; } -PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &) { - NOT_IMPLEMENTED(); +PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + }; } -PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &) { - NOT_IMPLEMENTED(); +PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{ + make_generic_binary_sp_leaf(l), + }; } -PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &) 
{ - NOT_IMPLEMENTED(); +PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &d) { + return PCGBinarySeriesSplit{ + require_series(d.raw_tree), + }; } -PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &) { - NOT_IMPLEMENTED(); +PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &d) { + return PCGBinaryParallelSplit{ + require_parallel(d.raw_tree), + }; } -parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &) { - NOT_IMPLEMENTED(); +parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &d) { + return require_leaf(d.raw_tree); } } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc index 00a99bfbf8..cd72f74c61 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -27,7 +27,7 @@ TEST_SUITE(FF_TEST_SUITE) { LinearAttrs linear_attrs = LinearAttrs{ /*out_channels=*/12, - /*use_bias=*/false, + /*use_bias=*/true, /*data_type=*/DataType::FLOAT, /*activation=*/std::nullopt, /*regularizer=*/std::nullopt, diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index b77325fe86..cce5dbb1a2 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -11,7 +11,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("estimate_cost_across_split") { + TEST_CASE("get_tensor_set_movement_across_split") { ParallelComputationGraph pcg = empty_parallel_computation_graph(); ParallelTensorShape input_shape = @@ -79,7 +79,7 @@ TEST_SUITE(FF_TEST_SUITE) { {relu_2.parallel_layer, post_mv1}, }}; - 
TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); @@ -125,7 +125,7 @@ TEST_SUITE(FF_TEST_SUITE) { {relu_3.parallel_layer, post_mv1}, }}; - TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); @@ -153,7 +153,7 @@ TEST_SUITE(FF_TEST_SUITE) { {relu_3.parallel_layer, post_mv2}, }}; - TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); @@ -206,7 +206,7 @@ TEST_SUITE(FF_TEST_SUITE) { make_pcg_leaf_node(relu_2.parallel_layer), make_pcg_leaf_node(relu_4.parallel_layer)))); - TensorSetMovement result = get_tensor_set_movement_across_split(get_pcg_transitive_reduction(pcg), + TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); @@ -217,12 +217,12 @@ TEST_SUITE(FF_TEST_SUITE) { SingleTensorMovement{ /*parallel_tensor_shape=*/input_shape, /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, }, SingleTensorMovement{ /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1, post_mv2}, + /*src_machine_views=*/{pre_mv2}, + /*dst_machine_views=*/{post_mv2}, }, }, }; diff --git a/lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc deleted file mode 100644 index 4bb1afab53..0000000000 --- 
a/lib/compiler/test/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_transitive_reduced_predecessors") { - FAIL("TODO"); - } - - TEST_CASE("get_transitive_reduced_successors") { - FAIL("TODO"); - } - - TEST_CASE("get_transitive_reduced_edges_across_split") { - FAIL("TODO"); - } - - TEST_CASE("get_transitive_reduced_tensors_across_split") { - FAIL("TODO"); - } - - TEST_CASE("get_split_transitive_reduced_boundary_layers") { - FAIL("TODO"); - } -} diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index f70d9f7404..b29d683edb 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -52,6 +52,11 @@ LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +ComputationGraph without_layer_names(ComputationGraph const &); + +bool computation_graphs_are_isomorphic(ComputationGraph const &, + ComputationGraph const &); + std::string as_dot(ComputationGraph const &); void debug_print_dot(ComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 83a6504ecc..b6f7790c49 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -58,6 +58,11 @@ parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name); +ParallelComputationGraph without_layer_names(ParallelComputationGraph const &); + +bool pcgs_are_isomorphic(ParallelComputationGraph const &, + ParallelComputationGraph const &); + } // namespace 
FlexFlow #endif diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index a69e54fd93..32f2335605 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -11,10 +11,12 @@ #include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" #include "utils/record_formatter.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" namespace FlexFlow { @@ -175,6 +177,30 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +ComputationGraph without_layer_names(ComputationGraph const &cg) { + return ComputationGraph{ + LabelledDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels( + cg.raw_graph, + [](Node const &n, LayerAttrs const &old_attrs) { + LayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, + ComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + + std::string as_dot(ComputationGraph const &cg) { std::function get_node_label = [](LayerAttrs const &a) -> std::string { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index d2987abc46..b26478107d 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ 
b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -7,7 +7,9 @@ #include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" namespace FlexFlow { @@ -169,4 +171,27 @@ parallel_layer_guid_t return get_only(found); } +ParallelComputationGraph without_layer_names(ParallelComputationGraph const &pcg) { + return ParallelComputationGraph{ + LabelledDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels( + pcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool pcgs_are_isomorphic(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + } // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 72d062e61d..f0e58191ef 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -302,10 +302,14 @@ TEST_SUITE(FF_TEST_SUITE) { /*create_gradients=*/CreateGrad::NO, }; + add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); + return pcg; }(); - FAIL("TODO"); - // CHECK(pcgs_are_isomorphic(result, correct)); + CHECK(pcgs_are_isomorphic(result, correct)); 
} } diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 0bbe0e97a7..0c673f0a8a 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -4,7 +4,7 @@ #include "utils/containers/values.h" #include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" @@ -54,12 +54,6 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( UnorderedSetLabelledOpenDataflowGraph>( sub_pcg.raw_graph)}; - // return ParallelComputationGraph{ - // make_lazy_copy_of< - // UnorderedSetLabelledOpenDataflowGraph - // >(sub_pcg.raw_graph) - // }; } parallel_layer_guid_t diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 4952a774b5..0e3b1fc0bd 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -126,9 +126,6 @@ std::optional optional_all_of(Container const &container, return true; } - } -} - template std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h new file mode 100644 index 0000000000..79cb6059b3 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split(TransitiveReducedDataflowGraphView const &, + BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h new file mode 100644 index 0000000000..c3a71b0f63 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set + get_transitive_reduced_edges_across_split(TransitiveReducedDataflowGraphView const &, + BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h new file mode 100644 index 0000000000..5fab1fa0b3 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set + get_transitive_reduced_outputs_across_split(TransitiveReducedDataflowGraphView const &, + BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml new file mode 100644 index 0000000000..32582a6b74 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "SplitBoundaryNodes" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h new file mode 100644 index 0000000000..6b711c8382 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView get_dataflow_graph_transitive_reduction(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml new file mode 100644 index 0000000000..54c710b26e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TransitiveReducedDataflowGraphView" +features = [] + +includes = [ + "utils/graph/dataflow_graph/dataflow_graph_view.h", + "utils/graph/digraph/digraph_view.h", +] + +[[fields]] +name = "full_dataflow_graph" +type = "::FlexFlow::DataflowGraphView" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h similarity index 88% rename from lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h rename to lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h index a8e08cb995..2685306bd5 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" @@ -8,6 +8,10 @@ namespace FlexFlow { + // NOTE(@lockshaw) This code is not tested and I don't necessarily trust it. + // Figuring out what to do with it is tracked in + // https://github.com/flexflow/FlexFlow/issues/1513 + template struct LazyLabelledDataflowGraph final : public ILabelledDataflowGraph { @@ -42,11 +46,11 @@ struct LazyLabelledDataflowGraph final return this->get_view().query_outputs(q); } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->get_view().at(n); } - ValueLabel const &at(DataflowOutput const &v) const override { + ValueLabel at(DataflowOutput const &v) const override { return this->get_view().at(v); } @@ -95,7 +99,7 @@ template static typename std::enable_if< std::is_base_of, T>::value, LabelledDataflowGraph>::type - make_lazy_copy_of( + create_lazy_copy_of_labelled_dataflow_graph_view( LabelledDataflowGraphView const &view) { std::function( LabelledDataflowGraphView const &)> diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h new file mode 100644 index 0000000000..1aa5b6b37f --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" + +namespace FlexFlow { + +template > +LabelledDataflowGraphView 
rewrite_node_labels( + LabelledDataflowGraphView const &g, F f) { + return rewrite_node_labels( + view_as_labelled_open_dataflow_graph(g), f); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h new file mode 100644 index 0000000000..db2fbceaed --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_PARALLEL_SPLIT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree get_left_child(BinaryParallelSplit const &); +BinarySPDecompositionTree get_right_child(BinaryParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml new file mode 100644 index 0000000000..985fb3089d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "BinaryParallelSplit" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h new file mode 100644 index 0000000000..f8ef91a5d8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SERIES_SPLIT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree get_left_child(BinarySeriesSplit const &); +BinarySPDecompositionTree get_right_child(BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml new file mode 100644 index 0000000000..c7c89da6d2 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "BinarySeriesSplit" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index b1607e7a76..023c767313 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -2,6 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" #include namespace FlexFlow { @@ -18,6 +20,10 @@ bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); +BinarySeriesSplit require_series(BinarySPDecompositionTree const &); +BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &); +Node require_leaf(BinarySPDecompositionTree const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index a8de1ee8f8..4137585c1a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -19,7 +19,7 @@ GenericBinaryParallelSplit const & } template -T const &require_node(GenericBinarySPDecompositionTree const &t) { +T const &require_leaf(GenericBinarySPDecompositionTree const &t) { return get(t); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h index 4d7fa05960..08ab99a292 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -8,6 +8,24 @@ namespace FlexFlow { +template > +GenericBinarySeriesSplit + transform(GenericBinarySeriesSplit const &s, F f) { + return GenericBinarySeriesSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }; +}; + +template > +GenericBinaryParallelSplit + transform(GenericBinaryParallelSplit const &s, F f) { + return GenericBinaryParallelSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }; +}; + template > GenericBinarySPDecompositionTree transform(GenericBinarySPDecompositionTree const &tt, F f) { @@ -16,18 +34,12 @@ GenericBinarySPDecompositionTree overload{ [&](GenericBinarySeriesSplit const &s) { return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - transform(get_left_child(s), f), - transform(get_right_child(s), 
f), - }, + transform(s, f), }; }, [&](GenericBinaryParallelSplit const &s) { return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - transform(get_left_child(s), f), - transform(get_right_child(s), f), - }, + transform(s, f), }; }, [&](T const &t) { diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..66152b9b13 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split(TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set edges = get_transitive_reduced_edges_across_split(tr_g, split); + + std::unordered_set src_boundary_nodes + = transform(edges, + [](DataflowEdge const &e) { return e.src.node; }); + + std::unordered_set dst_boundary_nodes + = transform(edges, + [](DataflowEdge const &e) { return e.dst.node; }); + + return SplitBoundaryNodes{ + /*pre_split_boundary=*/src_boundary_nodes, + /*post_split_boundary=*/dst_boundary_nodes, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 
0000000000..49783ee0d5 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,30 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +std::unordered_set + get_transitive_reduced_edges_across_split(TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set src_subgraph = unordered_set_of(get_leaves(get_left_child(split))); + std::unordered_set dst_subgraph = unordered_set_of(get_leaves(get_right_child(split))); + + std::unordered_set raw_edges = get_edges_from_subgraph_to_subgraph(tr_g.transitive_reduction, + src_subgraph, + dst_subgraph); + + return flatmap(raw_edges, + [&](DirectedEdge const &e) { + return get_dataflow_edges_from_node_to_node(tr_g.full_dataflow_graph, + e.src, + e.dst); + }); + +} + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..d4e285e5c3 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,14 @@ +#include 
"utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +std::unordered_set + get_transitive_reduced_outputs_across_split(TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + return transform(get_transitive_reduced_edges_across_split(tr_g, split), + [](DataflowEdge const &e) { return e.src; }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc new file mode 100644 index 0000000000..a068679be4 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc @@ -0,0 +1,16 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView get_dataflow_graph_transitive_reduction(DataflowGraphView const &g) { + DiGraphView as_digraph = g; + DiGraphView transitive_reduced = transitive_reduction(as_digraph); + + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/g, + /*transitive_reduction=*/transitive_reduced, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc new file mode 100644 index 0000000000..28d63f9ee1 --- /dev/null +++ 
b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc new file mode 100644 index 0000000000..dc5ce4fbda --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc new file mode 100644 index 0000000000..88bb9d1acc --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +BinarySPDecompositionTree get_left_child(BinaryParallelSplit const &s) { + return BinarySPDecompositionTree{ + get_left_child(s.raw_split), + }; +} + +BinarySPDecompositionTree get_right_child(BinaryParallelSplit const &s) { + return BinarySPDecompositionTree{ + get_right_child(s.raw_split), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc new file mode 100644 index 
0000000000..9b8f0685cd --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +BinarySPDecompositionTree get_left_child(BinarySeriesSplit const &split) { + return BinarySPDecompositionTree{ + get_left_child(split.raw_split), + }; +} + +BinarySPDecompositionTree get_right_child(BinarySeriesSplit const &split) { + return BinarySPDecompositionTree{ + get_right_child(split.raw_split), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 18d1f922c6..f683caef48 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -3,6 +3,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" namespace FlexFlow { @@ -40,4 +41,20 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { return 
get_leaves(tt.raw_tree); } +BinarySeriesSplit require_series(BinarySPDecompositionTree const &tt) { + return BinarySeriesSplit{ + require_series(tt.raw_tree), + }; +} + +BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &tt) { + return BinaryParallelSplit{ + require_parallel(tt.raw_tree), + }; +} + +Node require_leaf(BinarySPDecompositionTree const &tt) { + return require_leaf(tt.raw_tree); +} + } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..1a47dfde25 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,50 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = 
n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = require_series(\ + make_series_split( + make_series_split( + make_leaf_node(n1), + make_leaf_node(n2)), + make_series_split( + make_leaf_node(n3), + make_leaf_node(n4)))); + + SplitBoundaryNodes result = get_transitive_reduced_boundary_nodes_for_split(tr_g, split); + SplitBoundaryNodes correct = SplitBoundaryNodes{ + /*pre_split_boundary=*/{n2}, + /*post_split_boundary=*/{n3}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..915be7261e --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,136 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_edges_across_split") { + DataflowGraph g = DataflowGraph::create(); + + SUBCASE("multiple nodes with edges across") { + 
NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = require_series(\ + make_series_split( + make_parallel_split( + make_leaf_node(n1), + make_leaf_node(n2)), + make_parallel_split( + make_leaf_node(n3), + make_leaf_node(n4)))); + + std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o1, + DataflowInput{n3, 1}, + }, + DataflowEdge{ + o2, + DataflowInput{n3, 0}, + }, + DataflowEdge{ + o1, + DataflowInput{n4, 0}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("nodes each have multiple edges across") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o1 = n1_added.outputs.at(0); + DataflowOutput n1_o2 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o1, n1_o2, n1_o1}, 1); + Node n2 = n2_added.node; + + TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = require_series(\ + make_series_split( + make_leaf_node(n1), + make_leaf_node(n2))); + + std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + n1_o1, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o2, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not return 
edges eliminated by transitive reduction") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = require_series(\ + make_series_split( + make_series_split( + make_leaf_node(n1), + make_leaf_node(n2)), + make_series_split( + make_leaf_node(n3), + make_leaf_node(n4)))); + + std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o2, + DataflowInput{n3, 1}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..2df7c91041 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,47 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/containers/get_only.h" +#include 
"utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_outputs_across_split") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = require_series(\ + make_series_split( + make_series_split( + make_leaf_node(n1), + make_leaf_node(n2)), + make_series_split( + make_leaf_node(n3), + make_leaf_node(n4)))); + + std::unordered_set result = get_transitive_reduced_outputs_across_split(tr_g, split); + std::unordered_set correct = {o2}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc index bf51bd028e..c5e25386d5 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 5); SUBCASE("basic tests") { - std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(5)}; + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(4)}; std::unordered_set dst_subgraph = {n.at(2), 
n.at(3)}; SUBCASE("returns all edges between subgraphs") { @@ -19,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(0), n.at(2)}, DirectedEdge{n.at(0), n.at(3)}, DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(5), n.at(2)}, + DirectedEdge{n.at(4), n.at(2)}, }; add_edges(g, e); From bdcc10e9ceca707b73805d1b6b35cc4730b348da Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 28 Sep 2024 19:16:04 -0700 Subject: [PATCH 16/29] Move over to ProblemTree/ResultTree framework for machine mapping --- .../single_tensor_movement.struct.toml | 2 +- ...tracted_single_tensor_movement.struct.toml | 30 +++ .../abstracted_tensor_set_movement.h | 21 ++ ...abstracted_tensor_set_movement.struct.toml | 21 ++ ...tracted_tensor_set_movement_across_split.h | 15 ++ .../get_optimal_machine_mapping.h | 34 ++-- .../machine_mapping/machine_mapping_cache.h | 8 +- .../machine_mapping_constraints.h | 31 +++ ...> machine_mapping_constraints.struct.toml} | 2 +- .../machine_mapping_context.struct.toml | 5 - .../get_machine_mapping_problem_tree.h | 17 ++ .../machine_mapping_problem_tree.h | 51 +++++ .../machine_mapping_problem_tree.struct.toml | 18 ++ ...mm_problem_tree_parallel_split.struct.toml | 18 ++ ...blem_tree_parallel_split_label.struct.toml | 11 ++ .../mm_problem_tree_series_split.h | 15 ++ .../mm_problem_tree_series_split.struct.toml | 18 ++ ...roblem_tree_series_split_label.struct.toml | 15 ++ .../unmapped_op_cost_estimate_key.struct.toml | 36 ++++ .../machine_mapping/machine_mapping_result.h | 5 - .../machine_mapping_result_tree.h | 19 ++ .../machine_mapping_result_tree.struct.toml | 18 ++ .../mm_result_tree_parallel_split.struct.toml | 18 ++ ...sult_tree_parallel_split_label.struct.toml | 13 ++ .../mm_result_tree_series_split.struct.toml | 18 ++ ...result_tree_series_split_label.struct.toml | 13 ++ .../machine_mapping_state.struct.toml | 12 +- .../machine_mapping/partial_machine_mapping.h | 31 --- ..._graph_binary_sp_decomposition.struct.toml | 10 +- 
.../pcg_binary_parallel_split.struct.toml | 10 +- .../pcg_binary_series_split.struct.toml | 9 +- .../pcg_binary_sp_decomposition.struct.toml | 10 +- .../abstracted_tensor_set_movement.cc | 49 +++++ .../machine_mapping/estimate_layer_cost.cc | 5 +- ...racted_tensor_set_movement_across_split.cc | 48 +++++ .../get_machine_mapping_problem_tree.cc | 45 +++++ .../get_optimal_machine_mapping.cc | 183 +++++++++--------- .../get_tensor_set_movement_across_split.cc | 37 +--- ...ping.cc => machine_mapping_constraints.cc} | 24 ++- .../machine_mapping_problem_tree.cc | 95 +++++++++ .../mm_problem_tree_series_split.cc | 18 ++ .../mm_problem_tree_split_label.cc | 17 ++ .../pcg_binary_parallel_split.cc | 2 +- .../pcg_binary_series_split.cc | 5 +- .../get_machine_mapping_problem_tree.cc | 176 +++++++++++++++++ .../get_optimal_machine_mapping.cc | 9 +- .../parallel_computation_graph.h | 2 + .../parallel_computation_graph.cc | 5 + .../include/utils/full_binary_tree/fmt.h | 37 ++++ .../utils/full_binary_tree/full_binary_tree.h | 87 +++++++++ .../full_binary_tree_node_type.enum.toml | 16 ++ .../utils/full_binary_tree/get_leaves.h | 30 +++ .../utils/full_binary_tree/get_left_child.h | 15 ++ .../utils/full_binary_tree/get_node_type.h | 27 +++ .../utils/full_binary_tree/get_right_child.h | 15 ++ .../include/utils/full_binary_tree/hash.h | 26 +++ .../include/utils/full_binary_tree/require.h | 20 ++ .../utils/full_binary_tree/transform.h | 48 +++++ .../include/utils/full_binary_tree/visit.h | 23 +++ .../binary_parallel_split.struct.toml | 10 +- .../binary_series_split.struct.toml | 10 +- .../binary_sp_decomposition_tree.struct.toml | 10 +- .../fmt.h | 63 ------ .../generic_binary_parallel_split.struct.toml | 29 +++ .../generic_binary_series_split.struct.toml | 30 +++ .../generic_binary_sp_decomposition_tree.h | 155 --------------- ...c_binary_sp_decomposition_tree.struct.toml | 21 ++ .../get.h | 15 -- .../get_leaves.h | 24 +-- .../get_left_child.h | 42 ++-- .../get_node_type.h | 33 ++-- 
.../get_right_child.h | 42 ++-- .../hash.h | 34 ---- .../generic_binary_sp_decomposition_tree/is.h | 20 +- .../is_binary_sp_tree_left_associative.h | 12 +- .../is_binary_sp_tree_right_associative.h | 12 +- .../make.h | 56 +++--- .../require.h | 38 ++-- .../transform.h | 59 ++++-- .../visit.h | 43 ++-- .../get_leaves.h | 16 ++ .../leaf_only_binary_parallel_split.h | 21 ++ ...eaf_only_binary_parallel_split.struct.toml | 23 +++ ...ly_binary_parallel_split_label.struct.toml | 12 ++ .../leaf_only_binary_series_split.h | 21 ++ .../leaf_only_binary_series_split.struct.toml | 23 +++ ...only_binary_series_split_label.struct.toml | 12 ++ ...y_binary_sp_decomposition_tree.struct.toml | 21 ++ .../make.h | 44 +++++ .../require.h | 52 +++++ .../transform.h | 63 ++++++ .../test/src/utils/containers/flatmap.cc | 35 ++++ 92 files changed, 1997 insertions(+), 722 deletions(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h rename lib/compiler/include/compiler/machine_mapping/{partial_machine_mapping.struct.toml => machine_mapping_constraints.struct.toml} (92%) create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h create mode 100644 
lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h create mode 100644 lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc create mode 100644 
lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc rename lib/compiler/src/compiler/machine_mapping/{partial_machine_mapping.cc => machine_mapping_constraints.cc} (69%) create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc create mode 100644 lib/utils/include/utils/full_binary_tree/fmt.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml create mode 100644 lib/utils/include/utils/full_binary_tree/get_leaves.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_left_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_node_type.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_right_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/hash.h create mode 100644 lib/utils/include/utils/full_binary_tree/require.h create mode 100644 lib/utils/include/utils/full_binary_tree/transform.h create mode 100644 lib/utils/include/utils/full_binary_tree/visit.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml delete 
mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml 
create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h create mode 100644 lib/utils/test/src/utils/containers/flatmap.cc diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml index 52f66f3420..70f73ebe51 100644 --- a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml @@ -26,5 +26,5 @@ name = "src_machine_views" type = "std::unordered_set<::FlexFlow::MachineView>" [[fields]] -name = "dst_machine_view" +name = "dst_machine_views" type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml new file mode 100644 index 0000000000..fcae1e2356 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "AbstractedSingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name 
= "src_machine_views" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h new file mode 100644 index 0000000000..80e91b0f85 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); + +std::unordered_set get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &); + +TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &, + MachineMapping const &pre, + MachineMapping const &post); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml new file mode 100644 index 0000000000..4cf184706b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = 
"AbstractedTensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..33f44a3a11 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H + +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 7b4ba275a2..3c71d78093 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -3,10 +3,12 @@ #include 
"compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" -#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -14,7 +16,7 @@ namespace FlexFlow { -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResultTree get_optimal_machine_mapping( ParallelComputationGraph const &pcg, std::function( ParallelLayerAttrs const &, MachineSpecification const &)> const @@ -23,38 +25,38 @@ MachineMappingResult get_optimal_machine_mapping( MachineSpecification const &resources, MachineMappingCache &cached_subgraph_results); -MachineMappingResult +MachineMappingResultTree get_optimal_machine_mapping_internal(MachineMappingCache &result_cache, MachineMappingContext const &context, MachineSpecification const &resources); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySPDecomposition const &sp_decomposition, + MachineMappingProblemTree const &, MachineSpecification const &resources, - PartialMachineMapping const &); + MachineMappingConstraints const &); 
-MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySeriesSplit const &series, + MMProblemTreeSeriesSplit const &, MachineSpecification const &resources, - PartialMachineMapping const &); + MachineMappingConstraints const &); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinaryParallelSplit const &parallel, + MMProblemTreeParallelSplit const &, MachineSpecification const &resources, - PartialMachineMapping const &); + MachineMappingConstraints const &); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &, parallel_layer_guid_t const &, MachineSpecification const &, - PartialMachineMapping const &); + MachineMappingConstraints const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index a721ea29ed..b4608a90e0 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H -#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_state.dtg.h" #include "utils/optional.h" @@ -11,11 +11,11 @@ class MachineMappingCache { public: MachineMappingCache() = default; - std::optional load(MachineMappingState const &) const; - void save(MachineMappingState const &, 
MachineMappingResult const &); + std::optional load(MachineMappingState const &) const; + void save(MachineMappingState const &, MachineMappingResultTree const &); private: - std::unordered_map cache; + std::unordered_map cache; }; } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h new file mode 100644 index 0000000000..320a840bf6 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H + +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &); + +std::unordered_set get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); + +std::optional get_machine_view_for_layer(MachineMappingConstraints const &, + parallel_layer_guid_t const &); + +MachineMappingConstraints restrict_domain(MachineMappingConstraints const &, + std::unordered_set const &); + +MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &, + MachineMapping const &); + +MachineMapping require_fully_constrained(MachineMappingConstraints const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml similarity index 92% rename from lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml index b1955185ad..7211c773bb 100644 --- a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "PartialMachineMapping" +name = "MachineMappingConstraints" features = [ "eq", "hash", diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index 272d4c2097..505141d59f 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -6,13 +6,8 @@ includes = [ "compiler/cost_estimator/cost_estimator.h", "pcg/machine_view.dtg.h", "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/transitive_reduced_pcg.dtg.h", ] -[[fields]] -name = "transitive_reduced_pcg" -type = "::FlexFlow::TransitiveReducedPCG" - [[fields]] name = "cost_estimator" type = "::FlexFlow::CostEstimator" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h new file mode 100644 index 0000000000..b5ab1988ad --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H +#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h new file mode 100644 index 0000000000..29b5cf24d5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -0,0 +1,51 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree + mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &pre, + MachineMappingProblemTree const &post); +MachineMappingProblemTree + mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs); 
+MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &); + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); + +MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &); +MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &); +PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &); + +std::unordered_multiset get_leaves(MachineMappingProblemTree const &); + +template +Result visit(MachineMappingProblemTree const &t, F &&f) { + SPDecompositionTreeNodeType node_type = get_node_type(t); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + Result result = f(require_series_split(t)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + Result result = f(require_parallel_split(t)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + Result result = f(require_leaf(t)); + return result; + } + default: + throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml new file mode 100644 index 0000000000..e322133768 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", + 
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml new file mode 100644 index 0000000000..b277ca44bd --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml new file mode 100644 index 0000000000..367ffb399f --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +fields = [] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h new file mode 100644 index 0000000000..8332da66f9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &); +MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &); +AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml new file mode 100644 index 0000000000..299114862c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml new file mode 100644 index 0000000000..0887d67b49 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", +] + +[[fields]] +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..fe76683eb7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "UnmappedOpCostEstimateKey" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ 
+ "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "<vector>", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 621285ae16..0cdd283582 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -5,11 +5,6 @@ namespace FlexFlow { -MachineMappingResult sequential_combine(MachineMappingResult const &s1, - float comm_cost, - MachineMappingResult const &s2); -MachineMappingResult parallel_combine(MachineMappingResult const &s1, - MachineMappingResult const &s2); MachineMappingResult get_infinity_machine_mapping_result(); void minimize_runtime(MachineMappingResult &m1, MachineMappingResult const &m2); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h new file mode 100644 index 0000000000..0ddbc08297 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_TREE_MACHINE_MAPPING_RESULT_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_TREE_MACHINE_MAPPING_RESULT_TREE_H + +#include
"compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" + +namespace FlexFlow { + +MachineMappingResultTree make_series_split(float comm_cost, + MachineMappingResultTree const &pre, + MachineMappingResultTree const &post); +MachineMappingResultTree make_parallel_split(MachineMappingResultTree const &lhs, + MachineMappingResultTree const &rhs); +MachineMappingResultTree make_leaf_node(float cost, MachineView const &); + +std::optional minimize_cost(std::optional const &, MachineMappingResultTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml new file mode 100644 index 0000000000..69c7a613e0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingResultTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMResultTreeSeriesSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml new file mode 100644 index 
0000000000..ceb85e26eb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMResultTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/parallel_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml new file mode 100644 index 0000000000..6bc880e1fb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "MMResultTreeParallelSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "cost" +type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml new file mode 100644 index 0000000000..9210d1c80c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = 
"MMResultTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMResultTreeSeriesSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml new file mode 100644 index 0000000000..0f0a326fb5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "MMResultTreeSeriesSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "cost" +type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml index 0fcb065b10..4d4a29eac7 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -8,18 +8,18 @@ features = [ includes = [ "pcg/machine_specification.dtg.h", - "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h", - "compiler/machine_mapping/partial_machine_mapping.dtg.h", + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + 
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", ] [[fields]] -name = "subgraph" -type = "::FlexFlow::PCGBinarySPDecomposition" +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" [[fields]] name = "resource" type = "::FlexFlow::MachineSpecification" [[fields]] -name = "partial_solution" -type = "::FlexFlow::PartialMachineMapping" +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h deleted file mode 100644 index 4ed43b3470..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H - -#include "compiler/machine_mapping/machine_mapping.dtg.h" -#include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" -#include "compiler/machine_mapping/include_unconstrained.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set const &); - -std::unordered_set get_all_layers(PartialMachineMapping const &, - IncludeUnconstrained const &); - -std::optional get_machine_view_for_layer(PartialMachineMapping const &, - parallel_layer_guid_t const &); - -PartialMachineMapping get_sub_solution(PartialMachineMapping const &partial_solution, - PCGBinarySPDecomposition const &sub_problem); - -PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping const &partial_solution, - std::unordered_map const &additional); - 
-MachineMapping require_complete_mapping(PartialMachineMapping const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml index 147b1e3acf..98d0fc5faf 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "ComputationGraphBinarySPDecomposition" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ "pcg/layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition/leaf_only_binary_sp_decomposition.dtg.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySPDecomposition<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml index 75e1fec52f..f7d80138c5 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "PCGBinaryParallelSplit" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ 
"pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", ] [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::parallel_layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinaryParallelSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml index 63fc7562cd..48e19022c9 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml @@ -9,14 +9,9 @@ features = [ includes = [ "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", ] [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::parallel_layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySeriesSplit<::FlexFlow::parallel_layer_guid_t>" diff --git 
a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml index c9950bf3f4..bead04b307 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "PCGBinarySPDecomposition" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::parallel_layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc new file mode 100644 index 0000000000..96605fa238 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc @@ -0,0 +1,49 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement 
empty_abstracted_tensor_set_movement() { + return AbstractedTensorSetMovement{{}}; +} + +std::unordered_set get_src_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.src_machine_views; + }); +} + +std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.dst_machine_views; + }); +} + +TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &abstracted, + PartialMachineMapping const &pre_mapping, + PartialMachineMapping const &post_mapping) { + auto concretize_tensor_movement = [&](AbstractedSingleTensorMovement const &a) { + return SingleTensorMovement{ + /*parallel_tensor_shape=*/a.parallel_tensor_shape, + /*src_machine_views=*/transform(a.src_machine_views, + [&](parallel_layer_guid_t const &layer) { + return get_machine_view_for_layer(pre_mapping, layer).value(); + }), + /*dst_machine_views=*/transform(a.dst_machine_views, + [&](parallel_layer_guid_t const &layer) { + return get_machine_view_for_layer(post_mapping, layer).value(); + }), + }; + }; + + return TensorSetMovement{ + /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, concretize_tensor_movement), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc index 1caa31aefc..c01354f68b 100644 --- a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc +++ b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -3,9 +3,8 @@ namespace FlexFlow { -float estimate_layer_cost(ParallelComputationGraph const &pcg, - CostEstimator const &cost_estimator, - parallel_layer_guid_t const &layer, +float estimate_layer_cost(CostEstimator const &cost_estimator, + 
PCGOperatorAttrs const &layer, MachineView const &machine_view) { PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, layer).op_attrs; diff --git a/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..2c17fc089d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,48 @@ +#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + std::unordered_set + edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); + + auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { + std::unordered_set tensor_edges = filter(edges_across_split, + [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); + + std::unordered_set src_layers = + transform(tensor_edges, + [&](ParallelComputationGraphEdge const &e) { + return get_src_layer(e); + }); + + std::unordered_set dst_layers = + transform(tensor_edges, + [&](ParallelComputationGraphEdge const &e) { + return get_dst_layer(e); + }); + + return AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + /*src_machine_views=*/src_layers, + /*dst_machine_views=*/dst_layers, + }; + }; + + std::unordered_map 
single_tensor_movements = + generate_map(pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); + + return AbstractedTensorSetMovement{ + values(single_tensor_movements), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..8472228534 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc @@ -0,0 +1,45 @@ +#include "compiler/machine_mapping/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp_decomposition_tree) { + TransitiveReducedPCG tr_pcg = pcg_get_transitive_reduction(pcg); + + std::function to_problem_tree; + + to_problem_tree = [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { + return visit( + sp, + overload { + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return mm_problem_tree_make_series_split( + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(get_left_child(series)), + /*rhs=*/to_problem_tree(get_right_child(series))); + }, 
+ [&](PCGBinaryParallelSplit const &parallel) { + return mm_problem_tree_make_parallel_split( + to_problem_tree(get_left_child(parallel)), + to_problem_tree(get_right_child(parallel))); + }, + [&](parallel_layer_guid_t const &leaf) { + return mm_problem_tree_make_leaf(pcg_get_op_attrs(pcg, leaf)); + } + }); + }; + + return to_problem_tree(sp_decomposition_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index b731913627..d24ccaf63e 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,8 +1,16 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/get_allowed_machine_views_list.h" #include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" #include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h" +#include "compiler/machine_mapping/mm_problem_tree_series_split.h" #include "compiler/machine_mapping/partial_machine_mapping.dtg.h" #include "compiler/machine_mapping/partial_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" @@ -49,41 +57,34 @@ MachineMappingResult
get_optimal_machine_mapping( return result; } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, MachineSpecification const &resources) { - PCGBinarySPDecomposition sp_decomposition_tree = ({ - std::optional returned = get_pcg_balanced_binary_sp_decomposition(context.transitive_reduced_pcg.full_pcg); - if (!returned.has_value()) { - throw mk_runtime_error("Failed to get serial parallel decomposition"); - } - returned.value(); - }); - std::unordered_set all_layers = get_parallel_layers(context.transitive_reduced_pcg.full_pcg); - return get_optimal_machine_mapping_internal(result_cache, - context, - sp_decomposition_tree, - resources, - get_unconstrained_solution_for_layers(all_layers)); + NOT_IMPLEMENTED(); + // return get_optimal_machine_mapping_internal(result_cache, + // context, + // sp_decomposition_tree, + // resources, + // get_unconstrained_solution_for_layers(all_layers)); } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySPDecomposition const &sp_decomposition_tree, + MachineMappingProblemTree const &problem_tree, MachineSpecification const &resources, - PartialMachineMapping const &partial_solution) { + MachineMappingConstraints const &constraints) { MachineMappingState state = MachineMappingState{ - sp_decomposition_tree, resources, partial_solution, + problem_tree, resources, constraints, }; { - std::optional cached_result = + std::optional cached_result = result_cache.load(state); if (cached_result) { return cached_result.value(); @@ -100,107 +101,106 @@ MachineMappingResult get_optimal_machine_mapping_internal( return result; } -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( 
MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySeriesSplit const &series_split, + MMProblemTreeSeriesSplit const &series_split, MachineSpecification const &resource, - PartialMachineMapping const &partial_solution) { + MachineMappingConstraints const &partial_solution) { - MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + std::optional result = std::nullopt; auto is_subgraph_input = [&](std::unordered_set const &subgraph_nodes, parallel_tensor_guid_t const &input_tensor) { return !contains(subgraph_nodes, input_tensor.raw_graph_output.node); }; - PCGBinarySPDecomposition pre_sub_tree = get_left_child(series_split); - PCGBinarySPDecomposition post_sub_tree = get_right_child(series_split); - - PCGSplitBoundaryLayers boundary_layers = - pcg_get_transitive_reduced_boundary_layers_for_split(context.transitive_reduced_pcg, - series_split); + AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) - -> std::unordered_set> + -> std::unordered_set { std::unordered_map> allowed = generate_map(layers, [&](parallel_layer_guid_t const &l) { return get_allowed_machine_views_for_layer(context, l); }); - return get_all_assignments(allowed); + return transform(get_all_assignments(allowed), + [](std::unordered_map const &m) { + return MachineMapping{m}; + }); }; - for (std::unordered_map const &assigned_pre_machine_views - : get_boundary_machine_view_assignments(boundary_layers.pre_split_boundary)) { + for (MachineMapping const &assigned_pre_machine_views + : get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { - PartialMachineMapping pre_candidate = - with_additional_layer_machine_views( - get_sub_solution(partial_solution, pre_sub_tree), + MachineMappingConstraints pre_candidate = + with_additional_constraints( + restrict_domain(partial_solution, 
get_leaves(get_pre_child(series_split))), assigned_pre_machine_views); - MachineMappingResult pre_result = - get_optimal_machine_mapping_internal(result_cache, - context, - pre_sub_tree, - resource, - pre_candidate); - + MachineMappingResultTree pre_result = ({ + std::optional returned + = get_optimal_machine_mapping_internal(result_cache, + context, + get_pre_child(series_split), + resource, + pre_candidate); + if (!returned.has_value()) { + continue; + } + returned.value(); + }); - for (std::unordered_map const &assigned_post_machine_views - : get_boundary_machine_view_assignments(boundary_layers.post_split_boundary)) { + for (MachineMapping const &assigned_post_machine_views + : get_boundary_machine_view_assignments(get_dst_layers(tensor_movement))) { - PartialMachineMapping post_candidate = - with_additional_layer_machine_views( - get_sub_solution(partial_solution, post_sub_tree), + MachineMappingConstraints post_candidate = + with_additional_constraints( + restrict_domain(partial_solution, get_leaves(get_post_child(series_split))), assigned_post_machine_views); - MachineMappingResult post_result = - get_optimal_machine_mapping_internal(result_cache, - context, - post_sub_tree, - resource, - post_candidate); - - TensorSetMovement comm_across_split = get_tensor_set_movement_across_split( - /*transitive_reduced_pcg=*/context.transitive_reduced_pcg, - /*split=*/series_split, - /*pre_mapping=*/pre_candidate, - /*post_mapping=*/post_candidate); - + MachineMappingResultTree post_result = ({ + std::optional returned + = get_optimal_machine_mapping_internal(result_cache, + context, + get_post_child(series_split), + resource, + post_candidate); + if (!returned.has_value()) { + continue; + } + returned.value(); + }); + + TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement(tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); float cost_across_split = 
context.cost_estimator.estimate_cost(comm_across_split); - minimize_runtime( - optimal_result, - sequential_combine(pre_result, cost_across_split, post_result)); + result = minimize_cost(result, make_series_split(cost_across_split, pre_result, post_result)); } } - return optimal_result; + return result; } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinaryParallelSplit const &parallel, + MMProblemTreeParallelSplit const &parallel, MachineSpecification const &resources, - PartialMachineMapping const &partial_solution) { - - PCGBinarySPDecomposition left_subtree = get_left_child(parallel); - PartialMachineMapping left_sub_solution = get_sub_solution(partial_solution, - left_subtree); - - PCGBinarySPDecomposition right_subtree = get_right_child(parallel); - PartialMachineMapping right_sub_solution = get_sub_solution(partial_solution, - right_subtree); + MachineMappingConstraints const &partial_solution) { MachineMappingResult optimal_result = [&] { - PCGBinarySeriesSplit series = require_series(make_pcg_series_split( - get_left_child(parallel), - get_right_child(parallel))); + MMProblemTreeSeriesSplit series = MMProblemTreeSeriesSplit{ + MMProblemTreeSeriesSplitLabel{empty_abstracted_tensor_set_movement()}, + parallel.left, + parallel.right, + }; + return get_optimal_machine_mapping_internal(result_cache, context, series, @@ -208,17 +208,22 @@ MachineMappingResult get_optimal_machine_mapping_internal( partial_solution); }(); + MachineMappingConstraints left_sub_solution = restrict_domain(partial_solution, + get_leaves(parallel.left)); + MachineMappingConstraints right_sub_solution = restrict_domain(partial_solution, + get_leaves(parallel.right)); + for (auto const &resource_split : get_machine_resource_splits(resources)) { MachineMappingResult left_result = get_optimal_machine_mapping_internal(result_cache, context, -
left_subtree, + parallel.left, resource_split.first, left_sub_solution); MachineMappingResult right_result = get_optimal_machine_mapping_internal(result_cache, context, - right_subtree, + parallel.right, resource_split.second, right_sub_solution); @@ -230,23 +235,25 @@ MachineMappingResult get_optimal_machine_mapping_internal( return optimal_result; } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - parallel_layer_guid_t const &layer, + PCGOperatorAttrs const &layer, MachineSpecification const &resource, - PartialMachineMapping const &partial_solution) { + MachineMappingConstraints const &constraints) { + + assert (get_all_layers(constraints, IncludeUnconstrained{true}) == std::unordered_set{layer}); - assert (get_all_layers(partial_solution, IncludeUnconstrained{true}) == std::unordered_set{layer}); + MachineMapping concrete_mapping = require_fully_constrained(constraints); float cost = estimate_layer_cost(context.transitive_reduced_pcg.full_pcg, context.cost_estimator, layer, - get_machine_view_for_layer(partial_solution, layer).value()); + concrete_mapping.machine_views.at(layer)); - return MachineMappingResult{ + return make_leaf_node( /*runtime=*/cost, - /*machine_mapping=*/require_complete_mapping(partial_solution), + /*machine_mapping=*/concrete_mapping, }; } diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 8c84e227a7..f237fba88f 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,4 +1,6 @@ #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" +#include 
"compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/partial_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -15,39 +17,8 @@ TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG cons PCGBinarySeriesSplit const &split, PartialMachineMapping const &pre_mapping, PartialMachineMapping const &post_mapping) { - std::unordered_set - edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - - auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { - std::unordered_set tensor_edges = filter(edges_across_split, - [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); - - std::unordered_set src_machine_views = - transform(tensor_edges, - [&](ParallelComputationGraphEdge const &e) { - return get_machine_view_for_layer(pre_mapping, get_src_layer(e)).value(); - }); - - std::unordered_set dst_machine_views = - transform(tensor_edges, - [&](ParallelComputationGraphEdge const &e) { - return get_machine_view_for_layer(post_mapping, get_dst_layer(e)).value(); - }); - - return SingleTensorMovement{ - /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), - /*src_machine_views=*/src_machine_views, - /*dst_machine_views=*/dst_machine_views, - }; - }; - - std::unordered_map single_tensor_movements = - generate_map(pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), - get_movement_for_tensor); - - return TensorSetMovement{ - values(single_tensor_movements), - }; + AbstractedTensorSetMovement abstracted = get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + return concretize_abstracted_tensor_set_movement(abstracted, pre_mapping, post_mapping); } diff --git a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc 
b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc similarity index 69% rename from lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc rename to lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 5ae3126184..721fa1e32b 100644 --- a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -10,8 +10,8 @@ namespace FlexFlow { -PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set const &layers) { - return PartialMachineMapping{ +MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &layers) { + return MachineMappingConstraints{ generate_map(layers, [](parallel_layer_guid_t const &) -> std::optional { return std::nullopt; @@ -19,7 +19,7 @@ PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set

get_all_layers(PartialMachineMapping const &partial_solution, +std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution, IncludeUnconstrained const &include_unconstrained) { std::unordered_set with_unconstrained = keys(partial_solution.machine_views); @@ -31,24 +31,22 @@ std::unordered_set get_all_layers(PartialMachineMapping c } } -std::optional get_machine_view_for_layer(PartialMachineMapping const &partial_solution, +std::optional get_machine_view_for_layer(MachineMappingConstraints const &partial_solution, parallel_layer_guid_t const &layer) { return partial_solution.machine_views.at(layer); } -PartialMachineMapping get_sub_solution(PartialMachineMapping const &partial_solution, - PCGBinarySPDecomposition const &sub_problem) { +MachineMappingConstraints get_sub_solution(MachineMappingConstraints const &partial_solution, + std::unordered_set const &sub_problem) { - std::unordered_set sub_problem_layers = unordered_set_of(get_parallel_layers(sub_problem)); - - return PartialMachineMapping{ - restrict_keys(partial_solution.machine_views, sub_problem_layers), + return MachineMappingConstraints{ + restrict_keys(partial_solution.machine_views, sub_problem), }; } -PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping const &partial_solution, +MachineMappingConstraints with_additional_layer_machine_views(MachineMappingConstraints const &partial_solution, std::unordered_map const &additional) { - PartialMachineMapping result = partial_solution; + MachineMappingConstraints result = partial_solution; for (auto const &[layer, machine_view] : additional) { std::optional current_machine_view = result.machine_views.at(layer); @@ -68,7 +66,7 @@ PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping } -MachineMapping require_complete_mapping(PartialMachineMapping const &partial_mapping) { +MachineMapping require_complete_mapping(MachineMappingConstraints const &partial_mapping) { return 
MachineMapping{ map_values(partial_mapping.machine_views, [](std::optional const &mv) { return mv.value(); }), diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..3aace6b332 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc @@ -0,0 +1,95 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/full_binary_tree/get_left_child.h" +#include "compiler/machine_mapping/full_binary_tree/get_right_child.h" +#include "compiler/machine_mapping/full_binary_tree/require.h" +#include "compiler/machine_mapping/full_binary_tree/visit.h" +#include "compiler/machine_mapping/full_binary_tree/get_leaves.h" +#include "utils/overload.h" +#include "compiler/machine_mapping/mm_problem_tree_split_label.h" + +namespace FlexFlow { + +MachineMappingProblemTree mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + FullBinaryTree{ + FullBinaryTreeParentNode{ + /*label=*/MMProblemTreeSplitLabel{ + MMProblemTreeSeriesSplitLabel{ + /*tensor_set_movement=*/tensor_set_movement, + }, + }, + /*lhs=*/lhs.raw_tree, + /*rhs=*/rhs.raw_tree, + }, + }, + }; +} + +MachineMappingProblemTree mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + FullBinaryTree{ + FullBinaryTreeParentNode{ + /*label=*/MMProblemTreeSplitLabel{ + MMProblemTreeParallelSplitLabel{}, + }, + /*lhs=*/lhs.raw_tree, + /*rhs=*/rhs.raw_tree, + }, + }, + }; +} + +MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &layer) { + return MachineMappingProblemTree{ + FullBinaryTree{ + layer, + }, + }; +} + 
+SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { + return visit( + tree.raw_tree, + overload { + [](FullBinaryTreeParentNode const &parent) { + return split_label_get_node_type(parent.label); + }, + [](PCGOperatorAttrs const &) { + return SPDecompositionTreeNodeType::NODE; + } + }); +} + + +MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { + FullBinaryTreeParentNode raw_node = require_parent_node(t.raw_tree); + + return MMProblemTreeSeriesSplit{ + /*label=*/raw_node.label.get(), + /*left=*/MachineMappingProblemTree{get_left_child(raw_node)}, + /*right=*/MachineMappingProblemTree{get_right_child(raw_node)}, + }; +} + +MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { + FullBinaryTreeParentNode raw_node = require_parent_node(t.raw_tree); + + return MMProblemTreeParallelSplit{ + /*label=*/raw_node.label.get(), + /*left=*/MachineMappingProblemTree{get_left_child(raw_node)}, + /*right=*/MachineMappingProblemTree{get_right_child(raw_node)}, + }; +} + +PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &t) { + return require_leaf(t.raw_tree); +} + +std::unordered_multiset get_leaves(MachineMappingProblemTree const &t) { + return get_leaves(t.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc new file mode 100644 index 0000000000..28c6137440 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc @@ -0,0 +1,18 @@ +#include "compiler/machine_mapping/mm_problem_tree_series_split.h" +#include "compiler/machine_mapping/full_binary_tree/require.h" + +namespace FlexFlow { + +MachineMappingProblemTree const &get_left_child(MMProblemTreeSeriesSplit const &s) { + FullBinaryTree< require_parent(s.problem_tree.raw_tree); +} + +MachineMappingProblemTree const 
&get_right_child(MMProblemTreeSeriesSplit const &) { + +} + +AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &) { + +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc new file mode 100644 index 0000000000..54b7a4eaf8 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc @@ -0,0 +1,17 @@ +#include "compiler/machine_mapping/mm_problem_tree_split_label.h" +#include "utils/overload.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType split_label_get_node_type(MMProblemTreeSplitLabel const &l) { + return l.visit(overload { + [](MMProblemTreeSeriesSplitLabel const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplitLabel const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc index 0fe344aef8..dad21c6c8c 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc @@ -1,5 +1,5 @@ #include "compiler/series_parallel/pcg_binary_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc index 
efa919d5b9..31a90533ff 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc @@ -1,7 +1,6 @@ #include "compiler/series_parallel/pcg_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" namespace FlexFlow { diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..de4da010e5 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc @@ -0,0 +1,176 @@ +#include "compiler/machine_mapping/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_mapping_problem_tree") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + auto 
make_output_attrs = [](ParallelTensorShape const &shape) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + }; + + auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }; + }; + + PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; + + SUBCASE("single layer") { + ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_leaf_node(input_layer); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree correct = mm_problem_tree_make_leaf(input_attrs); + + CHECK(result == correct); + } + + SUBCASE("two layers in series") { + ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + parallel_tensor_guid_t input = get_only(input_added.outputs); + + PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }; + ParallelTensorShape relu_output_shape = input_shape; + ParallelLayerAddedResult relu_added = add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {input}, + {make_output_attrs(relu_output_shape)}); + parallel_layer_guid_t relu_layer = relu_added.parallel_layer; + parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_series_split( + make_pcg_leaf_node(input_layer), + make_pcg_leaf_node(relu_layer)); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = \ + mm_problem_tree_make_series_split( + AbstractedTensorSetMovement{{ + 
AbstractedSingleTensorMovement{ + input_shape, + {input_layer}, + {relu_layer}, + }, + }}, + mm_problem_tree_make_leaf(input_attrs), + mm_problem_tree_make_leaf(relu_attrs)); + + CHECK(result == correct); + } + + SUBCASE("two layers in parallel") { + ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + + ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_series_split( + make_pcg_leaf_node(input1_layer), + make_pcg_leaf_node(input2_layer)); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = \ + mm_problem_tree_make_parallel_split( + mm_problem_tree_make_leaf(input_attrs), + mm_problem_tree_make_leaf(input_attrs)); + + CHECK(result == correct); + } + + SUBCASE("multiple tensors across split") { + ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + + ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + + PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }; + ParallelTensorShape ew_op_output_shape = input_shape; + ParallelLayerAddedResult ew_op_added = add_parallel_layer(pcg, + make_layer_attrs(ew_op_attrs), + {input1_tensor, input2_tensor}, + {make_output_attrs(ew_op_output_shape)}); + parallel_layer_guid_t ew_op_layer = 
ew_op_added.parallel_layer; + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_series_split( + make_pcg_parallel_split( + make_pcg_leaf_node(input1_layer), + make_pcg_leaf_node(input2_layer)), + make_pcg_leaf_node(ew_op_layer)); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = \ + mm_problem_tree_make_series_split( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{input1_layer}, + /*dst_machine_views=*/{ew_op_layer}, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{input2_layer}, + /*dst_machine_views=*/{ew_op_layer}, + }, + }}, + /*pre=*/mm_problem_tree_make_parallel_split( + mm_problem_tree_make_leaf(input_attrs), + mm_problem_tree_make_leaf(input_attrs)), + /*post=*/mm_problem_tree_make_leaf(ew_op_attrs)); + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 3c4ac1174c..02b3fe4a03 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -8,12 +8,13 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_optimal_machine_mapping") { + TEST_CASE("get_optimal_machine_mapping_internal") { auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, MachineSpecification const &) { return std::unordered_set{ make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; + MachineSpecification machine_spec = MachineSpecification{ /*num_nodes=*/2, /*num_cpus_per_node=*/1, @@ -32,7 +33,7 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView mv1 = make_1d_machine_view(gpu_id_t{1}, gpu_id_t{2}); auto allowed_machine_views = [&](ParallelLayerAttrs const &, - 
MachineSpecification const &) { + MachineSpecification const &) { return std::unordered_set{mv1}; }; @@ -93,6 +94,10 @@ TEST_SUITE(FF_TEST_SUITE) { FAIL("TODO"); } + SUBCASE("multiple edges across split") { + FAIL("TODO"); + } + // SUBCASE("simple PCG") { // // ParallelComputationGraph pcg_simple = [&] { diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index b6f7790c49..a799e01dbc 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -46,6 +46,8 @@ std::vector ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index b26478107d..1562425a80 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -143,6 +143,11 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, return pcg.raw_graph.at(l.raw_graph_node); } +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(pcg, l).op_attrs; +} + ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, parallel_tensor_guid_t const &t) { diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h 
b/lib/utils/include/utils/full_binary_tree/fmt.h new file mode 100644 index 0000000000..3d94996079 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/fmt.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::string format_as(FullBinaryTreeParentNode const &t) { + return fmt::format("<{} ({} {})>", + t.label, + get_left_child(t), + get_right_child(t)); +} + +template +std::string format_as(FullBinaryTree const &t) { + return visit( + t, + overload{ + [](FullBinaryTreeParentNode const &parent) { + return fmt::to_string(parent); + }, + [](LeafLabel const &leaf) { + return fmt::format("{}", leaf); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h new file mode 100644 index 0000000000..f90ffb88c4 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h @@ -0,0 +1,87 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct FullBinaryTree; + +template +struct FullBinaryTreeParentNode { + explicit FullBinaryTreeParentNode( + ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) + : label(label), + left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) + { } + + FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; + + bool operator==(FullBinaryTreeParentNode const &other) const 
{ + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTreeParentNode const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTreeParentNode const &other) const { + return this->tie() < other.tie(); + } +public: + ParentLabel label; + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; +private: + std::tuple const &, + FullBinaryTree const &> + tie() const { + return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct FullBinaryTree { +public: + FullBinaryTree() = delete; + explicit FullBinaryTree(FullBinaryTreeParentNode const &t) + : root{t} {} + + explicit FullBinaryTree(LeafLabel const &t) + : root{t} {} + + bool operator==(FullBinaryTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTree const &other) const { + return this->tie() < other.tie(); + } +public: + std::variant, LeafLabel> root; +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml new file mode 100644 index 0000000000..1f8af17cf3 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeNodeType" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "PARENT" +key = "parent" + +[[values]] +name = "LEAF" +key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h new file mode 100644 index 0000000000..c58a850a6d --- /dev/null +++ 
b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include +#include "utils/containers/multiset_union.h" + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(FullBinaryTree const &t) { + return visit>( + t, + overload { + [](FullBinaryTreeParentNode const &parent) { + return multiset_union(get_leaves(get_left_child(parent)), + get_leaves(get_right_child(parent))); + }, + [](ChildLabel const &leaf) { + return std::unordered_multiset{leaf}; + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h new file mode 100644 index 0000000000..163503abfd --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_left_child.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +FullBinaryTree const &get_left_child(FullBinaryTreeParentNode const &t) { + return *t.left_child_ptr; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h new file mode 100644 index 0000000000..e1cbe909d5 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H + +#include "utils/overload.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include 
"utils/full_binary_tree/visit.h" +#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" + +namespace FlexFlow { + +template +FullBinaryTreeNodeType get_node_type(FullBinaryTree const &t) { + return visit( + t, + overload { + [](FullBinaryTreeParentNode const &) { + return FullBinaryTreeNodeType::PARENT; + }, + [](LeafLabel const &) { + return FullBinaryTreeNodeType::LEAF; + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h new file mode 100644 index 0000000000..e40f2024a1 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_right_child.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +FullBinaryTree const &get_right_child(FullBinaryTreeParentNode const &t) { + return *t.right_child_ptr; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h new file mode 100644 index 0000000000..a29836f972 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/hash.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::FullBinaryTreeParentNode> { + size_t operator()(::FlexFlow::FullBinaryTreeParentNode const &t) const { + return get_std_hash(t.tie()); + } +}; + +template +struct hash<::FlexFlow::FullBinaryTree> { + size_t operator()(::FlexFlow::FullBinaryTree const &t) const { + return get_std_hash(t.tie()); + } +}; + +} // namespace FlexFlow + 
+#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h new file mode 100644 index 0000000000..0e5ad4914a --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +FullBinaryTreeParentNode const &require_parent_node(FullBinaryTree const &t) { + return std::get>(t.root); +} + +template +LeafLabel const &require_leaf(FullBinaryTree const &t) { + return std::get(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/transform.h b/lib/utils/include/utils/full_binary_tree/transform.h new file mode 100644 index 0000000000..3fef8efd18 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/transform.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/overload.h" +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template , + typename LeafLabel2 = std::invoke_result_t> +FullBinaryTreeParentNode transform(FullBinaryTreeParentNode const &t, F f) { + return FullBinaryTreeParentNode{ + transform(get_left_child(t), f), + transform(get_right_child(t), f), + }; +} + +template , + typename LeafLabel2 = std::invoke_result_t> +FullBinaryTree transform(FullBinaryTree const &t, F f) { + return visit> + ( t, + overload { + [&](FullBinaryTreeParentNode const &parent) { + return FullBinaryTree{ + transform(parent, f), + }; + }, + [&](LeafLabel const &leaf) { + return FullBinaryTree{ + f(leaf), + }; + } + }); 
+} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h new file mode 100644 index 0000000000..93e5bfb504 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +Result visit(FullBinaryTree const &tt, F f) { + if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative(tt.root)) { + return f(std::get(tt.root)); + } else { + throw mk_runtime_error( + "Unexpected case in visit(FullBinaryTree)"); + } +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml index 985fb3089d..0dcae5177a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "BinaryParallelSplit" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", "utils/graph/node/node.dtg.h", ] -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::Node>" +type = "::FlexFlow::LeafOnlyBinaryParallelSplit<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml index c7c89da6d2..45472cb243 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "BinarySeriesSplit" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", "utils/graph/node/node.dtg.h", ] -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::Node>" +type = "::FlexFlow::LeafOnlyBinarySeriesSplit<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml index 1241311150..0000213398 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml 
+++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "BinarySPDecompositionTree" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", "utils/graph/node/node.dtg.h", ] -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h deleted file mode 100644 index 42d71ce54e..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace FlexFlow { - -template -std::string format_as(GenericBinarySeriesSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySeriesSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinaryParallelSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinaryParallelSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return fmt::format("", s); - }, - [](GenericBinaryParallelSplit const &s) { - return fmt::format("", s); - }, - [](T const &t) { - return fmt::format("", t); - }, - }); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySPDecompositionTree const &t) { - return (s << fmt::to_string(t)); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..e3d92c7409 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "GenericBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + 
+template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "label" +type = "ParallelSplitLabel" + +[[fields]] +name = "lhs" +type = "GenericBinarySPDecompositionTree" + +[[fields]] +name = "rhs" +type = "GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml new file mode 100644 index 0000000000..db11340d6e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "GenericBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", +] + + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "label" +type = "SeriesSplitLabel" + +[[fields]] +name = "pre" +type = "GenericBinarySPDecompositionTree" + +[[fields]] +name = "post" +type = "GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h deleted file mode 100644 index 74f5ba5d8a..0000000000 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h +++ /dev/null @@ -1,155 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H - -#include -#include -#include - -namespace FlexFlow { - -template -struct GenericBinarySPDecompositionTree; - -template -struct GenericBinarySeriesSplit { -public: - GenericBinarySeriesSplit() = delete; - explicit GenericBinarySeriesSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; - - bool operator==(GenericBinarySeriesSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySeriesSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySeriesSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinaryParallelSplit { -public: - GenericBinaryParallelSplit() = delete; - explicit GenericBinaryParallelSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; - - bool 
operator==(GenericBinaryParallelSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinaryParallelSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinaryParallelSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinarySPDecompositionTree { -public: - GenericBinarySPDecompositionTree() = delete; - explicit GenericBinarySPDecompositionTree( - GenericBinarySeriesSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree( - GenericBinaryParallelSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} - - GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = - default; - - bool operator==(GenericBinarySPDecompositionTree const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySPDecompositionTree const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySPDecompositionTree const &other) const { - return this->tie() < other.tie(); - } - -public: - std::variant, GenericBinaryParallelSplit, T> - root; - -private: - std::tuple tie() const { - return std::tie(this->root); - } - - friend std::hash; -}; - -} // namespace FlexFlow - -// namespace rc { -// -// template <> -// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { -// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { -// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); -// }; -// -// template <> -// struct 
Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { -// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); -// }; -// -// } // namespace rc - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..236274e617 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", +] + +includes = [ + "utils/full_binary_tree/full_binary_tree.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::FullBinaryTree, LeafLabel>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h deleted file mode 100644 index c6c1186d3d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace 
FlexFlow { - -template -TT const &get(GenericBinarySPDecompositionTree const &t) { - return std::get(t.root); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index 51e1e20bac..cad88d25b2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #include "utils/containers/multiset_union.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" @@ -11,26 +11,26 @@ namespace FlexFlow { -template -std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &tt) { - return visit>( +template +std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &tt) { + return visit>( tt, overload{ - [](T const &t) { return std::unordered_multiset{t}; }, - [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, - 
[](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, + [](LeafLabel const &t) { return std::unordered_multiset{t}; }, + [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, + [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, }); } -template -std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { +template +std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { return multiset_union(get_leaves(get_left_child(s)), get_leaves(get_right_child(s))); } -template -std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { +template +std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { return multiset_union(get_leaves(get_left_child(p)), get_leaves(get_right_child(p))); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h index 46a460b64e..9e857341c6 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h @@ -1,42 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" 
+#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" namespace FlexFlow { -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &s) { - return *s.left_child_ptr; +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &s) { + return s.pre; } -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &p) { - return *p.left_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_left_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_left_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_left_child incorrectly called on leaf node"); - }, - }); +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &p) { + return p.lhs; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index 883acda480..888d3c6627 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -1,27 +1,32 
@@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/full_binary_tree/visit.h" #include "utils/overload.h" namespace FlexFlow { -template +template SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &tt) { + get_node_type(GenericBinarySPDecompositionTree const &tt) { return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](GenericBinaryParallelSplit const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](T const &) { return SPDecompositionTreeNodeType::NODE; }, - }); + tt.raw_tree, + overload { + [](LeafLabel const &) { + return SPDecompositionTreeNodeType::NODE; + }, + [](FullBinaryTreeParentNode, LeafLabel> const &parent) { + if (std::holds_alternative(parent.label)) { + return SPDecompositionTreeNodeType::SERIES; + } else { + assert (std::holds_alternative(parent.label)); + + return SPDecompositionTreeNodeType::PARALLEL; + } + }, + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h index f0bfba43a2..766995b8a9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h @@ -1,42 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" namespace FlexFlow { -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &s) { - return *s.right_child_ptr; +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &s) { + return s.post; } -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &p) { - return *p.right_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySPDecompositionTree 
const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_right_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_right_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_right_child incorrectly called on leaf node"); - }, - }); +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &p) { + return p.rhs; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h deleted file mode 100644 index 983dc4a572..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace std { - -template -struct hash<::FlexFlow::GenericBinarySeriesSplit> { - size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinaryParallelSplit> { - size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { - size_t operator()( - ::FlexFlow::GenericBinarySPDecompositionTree const &s) const { - return 
get_std_hash(s.tie()); - } -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h index 8086f38244..bdaf8bcc2b 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h @@ -1,23 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" namespace FlexFlow { -template -bool is_series_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); +template +bool is_series_split(GenericBinarySPDecompositionTree const &t) { + return get_node_type(t) == SPDecompositionTreeNodeType::SERIES; } -template -bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); +template +bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { + return get_node_type(t) == SPDecompositionTreeNodeType::PARALLEL; } -template -bool is_leaf(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative(t.root); +template +bool is_leaf(GenericBinarySPDecompositionTree const &t) { + return get_node_type(t) == SPDecompositionTreeNodeType::NODE; } } // namespace FlexFlow diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 3ffa63753a..1ec84f194f 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" @@ -9,19 +9,19 @@ namespace FlexFlow { -template +template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &tt) { + GenericBinarySPDecompositionTree const &tt) { return visit( tt, overload{ - [](T const &) { return true; }, - [](GenericBinarySeriesSplit const &s) { + [](LeafLabel const &) { return true; }, + 
[](GenericBinarySeriesSplit const &s) { return !is_series_split(get_right_child(s)) && is_binary_sp_tree_left_associative(get_left_child(s)) && is_binary_sp_tree_left_associative(get_right_child(s)); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return !is_parallel_split(get_right_child(p)) && is_binary_sp_tree_left_associative(get_left_child(p)) && is_binary_sp_tree_left_associative(get_right_child(p)); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index d88459b432..a3ff9d4012 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" @@ -9,19 +9,19 @@ namespace FlexFlow { -template +template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &tt) { + GenericBinarySPDecompositionTree const &tt) { return visit( tt, overload{ - [](T const &t) { return true; }, - [](GenericBinarySeriesSplit const &s) { + [](LeafLabel const &t) { return true; }, + [](GenericBinarySeriesSplit const &s) { return !is_series_split(get_left_child(s)) && is_binary_sp_tree_right_associative(get_left_child(s)) && is_binary_sp_tree_right_associative(get_right_child(s)); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return !is_parallel_split(get_left_child(p)) && is_binary_sp_tree_right_associative(get_left_child(p)) && is_binary_sp_tree_right_associative(get_right_child(p)); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index f55b71146a..e925292b35 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -1,37 +1,49 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { -template -GenericBinarySPDecompositionTree make_generic_binary_series_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - lhs, - rhs, - }, +template +GenericBinarySPDecompositionTree make_generic_binary_series_split( + SeriesLabel const &label, + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel>{ + FullBinaryTreeParentNode, LeafLabel>{ + label, + lhs.raw_tree, + rhs.raw_tree, + } + } }; } -template -GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - lhs, - rhs, - }, +template +GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + SeriesLabel const &label, + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel>{ + FullBinaryTreeParentNode, LeafLabel>{ + label, + lhs.raw_tree, + rhs.raw_tree, + } + } }; } -template -GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { - return GenericBinarySPDecompositionTree{t}; +template +GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(LeafLabel const &leaf) { + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel>{ + leaf, + }, + }; } } // namespace FlexFlow diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index 4137585c1a..1c20de06dc 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -1,26 +1,38 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" namespace FlexFlow { -template -GenericBinarySeriesSplit const & - require_series(GenericBinarySPDecompositionTree const &t) { - return get>(t); +template +GenericBinarySeriesSplit + require_series(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + + return GenericBinarySeriesSplit{ + /*label=*/std::get(parent.label), + /*pre=*/get_left_child(parent), + /*post=*/get_right_child(parent), + }; } -template -GenericBinaryParallelSplit const & - require_parallel(GenericBinarySPDecompositionTree const &t) { - return get>(t); +template 
+GenericBinaryParallelSplit + require_parallel(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + + return GenericBinarySeriesSplit{ + /*label=*/std::get(parent.label), + /*pre=*/get_left_child(parent), + /*post=*/get_right_child(parent), + }; } -template -T const &require_leaf(GenericBinarySPDecompositionTree const &t) { - return get(t); +template +LeafLabel require_leaf(GenericBinarySPDecompositionTree const &t) { + return require_leaf(t.raw_tree); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h index 08ab99a292..c557711a3b 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -1,49 +1,70 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" namespace FlexFlow { -template > -GenericBinarySeriesSplit - transform(GenericBinarySeriesSplit const &s, F f) { - return GenericBinarySeriesSplit{ +template , + typename ParallelLabel2 = std::invoke_result_t, + typename LeafLabel2 = std::invoke_result_t> +GenericBinarySeriesSplit + transform(GenericBinarySeriesSplit const &s, F f) { + return GenericBinarySeriesSplit{ + f(s.label), transform(get_left_child(s), f), transform(get_right_child(s), f), }; }; -template > -GenericBinaryParallelSplit - transform(GenericBinaryParallelSplit const &s, F f) { - return GenericBinaryParallelSplit{ +template , + typename ParallelLabel2 = std::invoke_result_t, + typename LeafLabel2 = std::invoke_result_t> +GenericBinaryParallelSplit + transform(GenericBinaryParallelSplit const &s, F f) { + return GenericBinaryParallelSplit{ + f(s.label), transform(get_left_child(s), f), transform(get_right_child(s), f), }; }; -template > -GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, F f) { - return visit>( +template , + typename ParallelLabel2 = std::invoke_result_t, + typename LeafLabel2 = std::invoke_result_t> +GenericBinarySPDecompositionTree + transform(GenericBinarySPDecompositionTree const &tt, F f) { + return visit>( tt, overload{ - [&](GenericBinarySeriesSplit const &s) { - return GenericBinarySPDecompositionTree{ + [&](GenericBinarySeriesSplit const &s) { + return GenericBinarySPDecompositionTree{ transform(s, f), }; }, - [&](GenericBinaryParallelSplit const &s) { - return GenericBinarySPDecompositionTree{ + [&](GenericBinaryParallelSplit const &s) { + return GenericBinarySPDecompositionTree{ transform(s, f), }; }, - [&](T const &t) { - return GenericBinarySPDecompositionTree{ + [&](LeafLabel const &t) { + return GenericBinarySPDecompositionTree{ f(t), }; }, diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 0d9503e59f..ce4e4ebf55 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -2,34 +2,29 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { -template -Result visit(GenericBinarySPDecompositionTree const &tt, F f) { - if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative(tt.root)) { - return f(std::get(tt.root)); - } else { - throw mk_runtime_error( - "Unexpected case in visit(GenericBinarySPDecompositionTree)"); +template +Result visit(GenericBinarySPDecompositionTree const &tt, F f) { + SPDecompositionTreeNodeType node_type = get_node_type(tt); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + Result result = f(require_series_split(t)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + Result result = f(require_parallel_split(t)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + Result result = f(require_leaf(t)); + return result; + } + default: + throw mk_runtime_error(fmt::format("Unknown 
SPDecompositionTreeNodeType: {}", node_type)); } - - // return std::visit(tt.root, overload { - // [&](GenericBinarySeriesSplit const &s) -> Result { - // return f(s); - // }, - // [&](GenericBinaryParallelSplit const &p) -> Result { - // return f(p); - // }, - // [&](T const &t) -> Result { - // return f(t); - // }, - // }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h new file mode 100644 index 0000000000..628cf89a44 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +template +std::unordered_multiset get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { + return get_leaves(t.raw_tree); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h new file mode 100644 index 0000000000..9d4ce10cb4 --- /dev/null +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_PARALLEL_SPLIT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinaryParallelSplit const &s) { + return s.lhs; +} + +template +LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinaryParallelSplit const &s) { + return s.rhs; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..b92175b16f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", +] + 
+[[fields]] +name = "lhs" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" + +[[fields]] +name = "rhs" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml new file mode 100644 index 0000000000..0506d36227 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinaryParallelSplitLabel" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h new file mode 100644 index 0000000000..853def2c60 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SERIES_SPLIT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinarySeriesSplit const &s) { + return s.pre; +} + +template +LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinarySeriesSplit const &s) { + return s.post; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml new file mode 100644 index 0000000000..a7ff2dcc70 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "pre" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" + +[[fields]] +name = "post" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml new file mode 100644 index 0000000000..b780bfeea6 --- /dev/null +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySeriesSplitLabel" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..dacab0244a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt" +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.dtg.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::LeafOnlyBinarySeriesSplitLabel, ::FlexFlow::LeafOnlyBinaryParallelSplitLabel, LeafLabel>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h new file mode 
100644 index 0000000000..222799dbe9 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_MAKE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_MAKE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, + LeafOnlyBinarySPDecompositionTree const &post) { + return LeafOnlyBinarySPDecompositionTree{ + make_generic_binary_series_split( + LeafOnlyBinaryParallelSplitLabel{}, + pre, + post), + }; +} + +template +LeafOnlyBinarySPDecompositionTree make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, + LeafOnlyBinarySPDecompositionTree const &rhs) { + return LeafOnlyBinarySPDecompositionTree{ + make_generic_binary_series_split( + LeafOnlyBinaryParallelSplitLabel{}, + lhs, + rhs), + }; +} + +template +LeafOnlyBinarySPDecompositionTree make_leaf_node(LeafLabel const &label) { + return LeafOnlyBinarySPDecompositionTree{ + make_generic_binary_sp_leaf< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel>(label), + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h new file mode 100644 index 0000000000..9011fadd78 --- 
/dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -0,0 +1,52 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySeriesSplit + require_series(LeafOnlyBinarySPDecompositionTree const &t) { + GenericBinarySeriesSplit< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel> raw = + require_series(t.raw_tree); + + return LeafOnlyBinarySeriesSplit{ + LeafOnlyBinarySeriesSplitLabel{}, + LeafOnlyBinarySPDecompositionTree{raw.pre}, + LeafOnlyBinarySPDecompositionTree{raw.post}, + }; +} + +template +LeafOnlyBinaryParallelSplit + require_parallel(LeafOnlyBinarySPDecompositionTree const &t) { + GenericBinarySeriesSplit< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel> raw = + require_series(t.raw_tree); + + return LeafOnlyBinarySeriesSplit{ + LeafOnlyBinaryParallelSplitLabel{}, + LeafOnlyBinarySPDecompositionTree{raw.pre}, + 
LeafOnlyBinarySPDecompositionTree{raw.post}, + }; +} + +template +LeafLabel require_leaf(LeafOnlyBinarySPDecompositionTree const &t) { + return require_leaf(t.raw_tree); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h new file mode 100644 index 0000000000..364a3200b1 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -0,0 +1,63 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" + +namespace FlexFlow { + +template > +LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &t, F &&f) { + auto ff = overload { + [&](T const &t) { + return f(t); + }, + [&](auto const &x) { + return x; + }, + }; + + return LeafOnlyBinarySeriesSplit{ + transform(t.pre, f), + transform(t.post, f), + }; +} + +template > +LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &t, F &&f) { + auto ff = overload { + [&](T const &t) { + return f(t); + }, + [&](auto const &x) { + return x; + }, + }; + + return LeafOnlyBinaryParallelSplit{ + transform(t.lhs, f), + 
transform(t.rhs, f), + }; +} + +template > +LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, F &&f) { + auto ff = overload { + [&](T const &t) { + return f(t); + }, + [&](auto const &x) { + return x; + }, + }; + + return LeafOnlyBinarySPDecompositionTree{ + transform(t.raw_tree, ff), + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc new file mode 100644 index 0000000000..41b7b79101 --- /dev/null +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -0,0 +1,35 @@ +#include "utils/containers/flatmap.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatmap(std::unordered_set, F)") { + auto get_chars = [](std::string const &s) { + std::unordered_set result; + for (char c : s) { + result.insert(c); + } + return result; + }; + + SUBCASE("type changing") { + std::unordered_set input = {"hello", " ", "", "world", "!"}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; + + CHECK(result == correct); + } + + SUBCASE("input is empty") { + std::unordered_set input = {}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} From b0475b4c97ee3e3f2bc31bfdb8d8cffa71d70522 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 30 Sep 2024 14:50:03 -0700 Subject: [PATCH 17/29] Settle on ProblemTree/BinaryTreePath-indexed-MachineMappingResult for machine mapping --- flake.lock | 6 +- ...tracted_single_tensor_movement.struct.toml | 6 +- .../abstracted_tensor_set_movement.h | 4 +- ...tracted_tensor_set_movement_across_split.h | 2 +- .../machine_mapping/estimate_layer_cost.h | 5 +- ...easible_machine_mapping_result.struct.toml | 20 ++ .../get_allowed_machine_views_list.h | 26 --- .../get_optimal_machine_mapping.h | 47 ++-- 
.../machine_mapping/machine_mapping_cache.h | 8 +- .../machine_mapping_constraints.h | 13 +- .../machine_mapping_constraints.struct.toml | 4 +- .../machine_mapping_context.struct.toml | 2 +- .../machine_mapping_problem_tree.h | 10 +- .../mm_problem_tree_parallel_split.h | 14 ++ .../unmapped_op_cost_estimate_key.h | 19 ++ .../machine_mapping/machine_mapping_result.h | 17 +- .../machine_mapping_result.struct.toml | 13 +- .../machine_mapping_result_tree.h | 19 -- .../machine_mapping_result_tree.struct.toml | 18 -- .../mm_result_tree_parallel_split.struct.toml | 18 -- ...sult_tree_parallel_split_label.struct.toml | 13 -- .../mm_result_tree_series_split.struct.toml | 18 -- ...result_tree_series_split_label.struct.toml | 13 -- .../parallel_split_transformation.enum.toml | 14 ++ .../pcg_binary_sp_decomposition.h | 7 + .../abstracted_tensor_set_movement.cc | 7 +- ...racted_tensor_set_movement_across_split.cc | 15 +- .../machine_mapping/estimate_layer_cost.cc | 2 +- .../get_allowed_machine_views_list.cc | 81 ------- .../get_optimal_machine_mapping.cc | 155 +++++-------- .../machine_mapping_constraints.cc | 37 +-- .../machine_mapping_problem_tree.cc | 95 -------- .../get_machine_mapping_problem_tree.cc | 9 +- .../machine_mapping_problem_tree.cc | 78 +++++++ .../mm_problem_tree_parallel_split.cc | 20 ++ .../mm_problem_tree_series_split.cc | 21 ++ .../unmapped_op_cost_estimate_key.cc | 33 +++ .../machine_mapping/machine_mapping_result.cc | 115 ++++++++-- .../machine_mapping_result_tree.cc | 59 +++++ .../mm_result_tree_parallel_split.cc | 13 ++ .../mm_result_tree_series_split.cc | 13 ++ .../pcg_binary_sp_decomposition.cc | 34 ++- .../get_machine_mapping_problem_tree.cc | 0 .../utils/full_binary_tree/binary_tree_path.h | 17 ++ .../binary_tree_path.struct.toml | 24 ++ .../binary_tree_path_entry.enum.toml | 16 ++ .../full_binary_tree/find_paths_to_leaf.h | 40 ++++ .../include/utils/full_binary_tree/fmt.h | 10 + .../utils/full_binary_tree/get_child.h | 26 +++ 
.../full_binary_tree/get_subtree_at_path.h | 38 +++ .../include/utils/full_binary_tree/json.h | 75 ++++++ .../include/utils/full_binary_tree/visit.h | 1 + .../binary_sp_decomposition_tree.h | 24 ++ .../find_paths_to_leaf.h | 17 ++ .../generic_binary_parallel_split.struct.toml | 5 +- .../generic_binary_series_split.struct.toml | 5 +- ...c_binary_sp_decomposition_tree.struct.toml | 10 +- .../generic_binary_sp_split_label.h | 20 ++ ...generic_binary_sp_split_label.variant.toml | 21 ++ .../get_node_type.h | 11 +- .../get_num_tree_nodes.h | 20 +- .../get_subtree_at_path.h | 28 +++ .../json.h | 103 --------- .../make.h | 16 +- .../require.h | 29 ++- .../visit.h | 8 +- .../wrap.h | 37 +++ .../find_paths_to_leaf.h | 18 ++ .../get_node_type.h | 16 ++ .../is_binary_sp_tree_left_associative.h | 16 ++ .../is_binary_sp_tree_right_associative.h | 16 ++ ...eaf_only_binary_parallel_split.struct.toml | 2 +- .../make.h | 12 +- .../require.h | 13 +- .../wrap.h | 46 ++++ .../full_binary_tree/binary_tree_path.cc | 32 +++ .../binary_parallel_split.cc | 3 +- .../binary_series_split.cc | 3 +- .../binary_sp_decomposition_tree.cc | 22 +- .../fmt.cc | 1 - .../generic_binary_sp_decomposition_tree.cc | 1 - .../get.cc | 1 - .../hash.cc | 1 - .../json.cc | 1 - ...ft_associative_binary_sp_tree_from_nary.cc | 43 ++-- ...ht_associative_binary_sp_tree_from_nary.cc | 41 ++-- .../intermediate_sp_decomposition_tree.cc | 17 +- .../test/src/utils/containers/flatmap.cc | 1 + .../fmt.cc | 51 ----- .../get_leaves.cc | 151 ++++++------ .../get_left_child.cc | 41 ---- .../get_num_tree_nodes.cc | 151 ++++++------ .../get_right_child.cc | 41 ---- .../hash.cc | 216 +++++++++--------- .../is_binary_sp_tree_left_associative.cc | 182 +++++++-------- .../is_binary_sp_tree_right_associative.cc | 188 +++++++-------- .../json.cc | 131 ----------- .../transform.cc | 36 +-- 98 files changed, 1744 insertions(+), 1474 deletions(-) create mode 100644 
lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml rename lib/compiler/src/compiler/machine_mapping/{ => abstracted_tensor_set_movement}/abstracted_tensor_set_movement.cc (88%) rename lib/compiler/src/compiler/machine_mapping/{ => abstracted_tensor_set_movement}/get_abstracted_tensor_set_movement_across_split.cc (79%) delete mode 100644 lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc rename lib/compiler/src/compiler/machine_mapping/{ => machine_mapping_problem_tree}/get_machine_mapping_problem_tree.cc (77%) create mode 100644 
lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc rename lib/compiler/test/src/compiler/machine_mapping/{ => machine_mapping_problem_tree}/get_machine_mapping_problem_tree.cc (100%) create mode 100644 lib/utils/include/utils/full_binary_tree/binary_tree_path.h create mode 100644 lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml create mode 100644 lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h create mode 100644 lib/utils/include/utils/full_binary_tree/json.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h create mode 100644 
lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h create mode 100644 lib/utils/src/utils/full_binary_tree/binary_tree_path.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc delete mode 100644 
lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc diff --git a/flake.lock b/flake.lock index 1aad68ae29..c5f86c613c 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1722923482, - "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=", + "lastModified": 1727727609, + "narHash": "sha256-BSnh4wZV7LLXDQ4YIhCHz/uJ4N88vv5cBb1LKWJlltM=", "owner": "lockshaw", "repo": "proj", - "rev": "c650b0e52337652ea7190131988c0370e0ee7f25", + "rev": "e17da953eaea9e728e9dfde9c12a2435122253b1", "type": "github" }, "original": { diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml index fcae1e2356..449a448706 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml @@ -8,7 +8,7 @@ features = [ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", - 
"pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", "", ] @@ -23,8 +23,8 @@ type = "::FlexFlow::ParallelTensorShape" [[fields]] name = "src_machine_views" -type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" [[fields]] name = "dst_machine_views" -type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h index 80e91b0f85..5917a8fb26 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -9,8 +9,8 @@ namespace FlexFlow { AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); -std::unordered_set get_src_layers(AbstractedTensorSetMovement const &); -std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &); +std::unordered_set get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &); TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &, MachineMapping const &pre, diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h index 33f44a3a11..3a34e956ad 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h +++ 
b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -3,7 +3,7 @@ #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" -#include "compiler/machine_mapping/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h index dcb8856fe8..a862f0c476 100644 --- a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h +++ b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h @@ -6,9 +6,8 @@ #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" namespace FlexFlow { -float estimate_layer_cost(ParallelComputationGraph const &pcg, - CostEstimator const &cost_estimator, - parallel_layer_guid_t const &layer, +float estimate_layer_cost(CostEstimator const &cost_estimator, + UnmappedOpCostEstimateKey const &key, MachineView const &machine_view); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml new file mode 100644 index 0000000000..c75c968a90 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "FeasibleMachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +[[fields]] +name = "runtime" +type = "float" + +[[fields]] +name = "parallel_layer_guid_oblivious_machine_mapping" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" 
diff --git a/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h b/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h deleted file mode 100644 index 1da08daf1b..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/get_allowed_machine_views_list.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_ALLOWED_MACHINE_MAPPINGS_H_ -#define _FLEXFLOW_ALLOWED_MACHINE_MAPPINGS_H_ - -#include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include - -namespace FlexFlow { - -std::vector> - get_allowed_machine_views_list( - MachineMappingContext const &context, - std::unordered_set const &layers, - MachineSpecification const &resource); - -std::vector> - get_allowed_src_machine_views_list( - MachineMappingContext const &context, - std::unordered_set const &values, - MachineSpecification const &resource); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 3c71d78093..1c52ccc2bb 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -1,62 +1,43 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H -#include "compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" -#include 
"compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "pcg/machine_specification.dtg.h" namespace FlexFlow { -MachineMappingResultTree get_optimal_machine_mapping( - ParallelComputationGraph const &pcg, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - MachineMappingCache &cached_subgraph_results); - -MachineMappingResultTree - get_optimal_machine_mapping_internal(MachineMappingCache &result_cache, - MachineMappingContext const &context, - MachineSpecification const &resources); - -std::optional get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, - MachineMappingProblemTree const &, + MachineMappingProblemTree const &problem_tree, MachineSpecification const &resources, - MachineMappingConstraints const &); + MachineMappingConstraints const &constraints); -std::optional get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, - MMProblemTreeSeriesSplit const &, + MMProblemTreeSeriesSplit const &series_split, MachineSpecification const &resources, - MachineMappingConstraints const &); + MachineMappingConstraints const &constraints); -std::optional get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache 
&result_cache, MachineMappingContext const &context, - MMProblemTreeParallelSplit const &, + MMProblemTreeParallelSplit const &parallel_split, MachineSpecification const &resources, - MachineMappingConstraints const &); + MachineMappingConstraints const &constraints); -std::optional get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &, - parallel_layer_guid_t const &, - MachineSpecification const &, - MachineMappingConstraints const &); + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index b4608a90e0..4e72cc1d76 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H -#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_state.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" #include "utils/optional.h" namespace FlexFlow { @@ -11,11 +11,11 @@ class MachineMappingCache { public: MachineMappingCache() = default; - std::optional load(MachineMappingState const &) const; - void save(MachineMappingState const &, MachineMappingResultTree const &); + std::optional load(MachineMappingState const &) const; + void save(MachineMappingState const &, std::optional const &); private: - std::unordered_map cache; + std::unordered_map cache; }; } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h index 320a840bf6..f0c81f3ecd 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H #include "compiler/machine_mapping/machine_mapping.dtg.h" -#include "compiler/machine_mapping/machine_mapping_context.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/include_unconstrained.dtg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" @@ -10,22 +9,24 @@ namespace FlexFlow { -MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &); +MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &); -std::unordered_set get_all_layers(MachineMappingConstraints const &, - IncludeUnconstrained const &); +std::unordered_set get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); std::optional get_machine_view_for_layer(MachineMappingConstraints const &, - parallel_layer_guid_t const &); + BinaryTreePath const &); MachineMappingConstraints restrict_domain(MachineMappingConstraints const &, - std::unordered_set const &); + BinaryTreePathEntry const &); MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &, MachineMapping const &); MachineMapping require_fully_constrained(MachineMappingConstraints const &); +std::optional require_only_root(MachineMappingConstraints const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml index 7211c773bb..8e13abedb9 100644 --- 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml @@ -7,8 +7,8 @@ features = [ ] includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", "", ] @@ -20,4 +20,4 @@ src_includes = [ [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, std::optional<::FlexFlow::MachineView>>" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, std::optional<::FlexFlow::MachineView>>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index 505141d59f..c4bf1d1ac8 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -14,4 +14,4 @@ type = "::FlexFlow::CostEstimator" [[fields]] name = "allowed_machine_views" -type = "std::function(::FlexFlow::ParallelLayerAttrs const &, ::FlexFlow::MachineSpecification const &)>" +type = "std::function(::FlexFlow::UnmappedOpCostEstimateKey const &, ::FlexFlow::MachineSpecification const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 29b5cf24d5..20e3a11399 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -4,6 +4,7 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include 
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" namespace FlexFlow { @@ -15,15 +16,18 @@ MachineMappingProblemTree MachineMappingProblemTree mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, MachineMappingProblemTree const &rhs); -MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &); +MachineMappingProblemTree mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &); SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &); MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &); -PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &); +UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &); -std::unordered_multiset get_leaves(MachineMappingProblemTree const &); +std::unordered_multiset get_leaves(MachineMappingProblemTree const &); + +std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, + BinaryTreePath const &); template Result visit(MachineMappingProblemTree const &t, F &&f) { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h new file mode 100644 index 0000000000..63e724fa94 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_MM_PROBLEM_TREE_PARALLEL_SPLIT_H +#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_MM_PROBLEM_TREE_PARALLEL_SPLIT_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_lhs_child(MMProblemTreeParallelSplit const &); +MachineMappingProblemTree get_rhs_child(MMProblemTreeParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h new file mode 100644 index 0000000000..e90dcdd94c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + +OpCostEstimateKey map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 0cdd283582..6fb4c70d2e 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -2,12 +2,25 @@ #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H #include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" namespace FlexFlow { -MachineMappingResult get_infinity_machine_mapping_result(); +MachineMappingResult infeasible_machine_mapping_result(); +bool is_infeasible(MachineMappingResult const &); +FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); -void minimize_runtime(MachineMappingResult &m1, MachineMappingResult const &m2); +MachineMappingResult series_combine(float comm_cost, + MachineMappingResult const &pre_result, + MachineMappingResult const &post_result, + std::optional const &parallel_split_transformation); +MachineMappingResult parallel_combine(MachineMappingResult const &lhs_result, + MachineMappingResult const &rhs_result); + +[[nodiscard]] MachineMappingResult minimize_runtime(MachineMappingResult const &m1, MachineMappingResult const &m2); + +MachineMappingResult make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml index f2f2e15e9a..28b124cea3 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -7,13 +7,14 @@ features = [ ] includes = [ - "compiler/machine_mapping/machine_mapping.dtg.h", + "compiler/machine_mapping/feasible_machine_mapping_result.dtg.h", + "", ] -[[fields]] -name = "runtime" 
-type = "float" +src_includes = [ + "utils/fmt/optional.h", +] [[fields]] -name = "machine_mapping" -type = "::FlexFlow::MachineMapping" +name = "raw_result" +type = "std::optional" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h deleted file mode 100644 index 0ddbc08297..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_TREE_MACHINE_MAPPING_RESULT_TREE_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_TREE_MACHINE_MAPPING_RESULT_TREE_H - -#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" - -namespace FlexFlow { - -MachineMappingResultTree make_series_split(float comm_cost, - MachineMappingResultTree const &pre, - MachineMappingResultTree const &post); -MachineMappingResultTree make_parallel_split(MachineMappingResultTree const &lhs, - MachineMappingResultTree const &rhs); -MachineMappingResultTree make_leaf_node(float cost, MachineView const &); - -std::optional minimize_cost(std::optional const &, MachineMappingResultTree const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml deleted file mode 100644 index 69c7a613e0..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingResultTree" -features = [ - "eq", - "hash", - "fmt", -] - 
-includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", - "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", - "pcg/machine_view.dtg.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMResultTreeSeriesSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml deleted file mode 100644 index ceb85e26eb..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "MMResultTreeParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/parallel_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", - "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", - "pcg/machine_view.dtg.h", -] - -[[fields]] -name = "raw_split" -type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml deleted file mode 
100644 index 6bc880e1fb..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = "MMResultTreeParallelSplitLabel" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [] - -[[fields]] -name = "cost" -type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml deleted file mode 100644 index 9210d1c80c..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "MMResultTreeSeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", - "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", - "pcg/machine_view.dtg.h", -] - -[[fields]] -name = "raw_split" -type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMResultTreeSeriesSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml deleted file mode 100644 index 0f0a326fb5..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml +++ /dev/null @@ -1,13 +0,0 @@ -namespace = "FlexFlow" -name = 
"MMResultTreeSeriesSplitLabel" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [] - -[[fields]] -name = "cost" -type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml new file mode 100644 index 0000000000..8247c0cbdc --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParallelSplitTransformation" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LthenR" + +[[values]] +name = "RthenL" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h index eca0cd7d0b..2744393fc2 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h @@ -5,6 +5,7 @@ #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include @@ -21,10 +22,16 @@ PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &); +PCGBinarySPDecomposition wrap_series_split(PCGBinarySeriesSplit const &); +PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &); + PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &); PCGBinaryParallelSplit 
require_parallel(PCGBinarySPDecomposition const &); parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &); +std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &, + parallel_layer_guid_t const &); + template ReturnType visit(PCGBinarySPDecomposition const &d, F &&f) { SPDecompositionTreeNodeType node_type = get_node_type(d); diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc similarity index 88% rename from lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc rename to lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc index 96605fa238..69242e4076 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -1,5 +1,4 @@ -#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" -#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "utils/containers/flatmap.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/transform.h" @@ -10,14 +9,14 @@ AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { return AbstractedTensorSetMovement{{}}; } -std::unordered_set get_src_layers(AbstractedTensorSetMovement const &m) { +std::unordered_set get_src_layers(AbstractedTensorSetMovement const &m) { return flatmap(unordered_set_of(m.single_tensor_movements), [](AbstractedSingleTensorMovement const &s) { return s.src_machine_views; }); } -std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &m) { +std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &m) { return 
flatmap(unordered_set_of(m.single_tensor_movements), [](AbstractedSingleTensorMovement const &s) { return s.dst_machine_views; diff --git a/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc similarity index 79% rename from lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc rename to lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 2c17fc089d..cea2df8073 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,18 +1,25 @@ -#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" #include "utils/containers/values.h" namespace FlexFlow { AbstractedTensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + auto get_path_to_layer = [&](parallel_layer_guid_t const &l) { + return get_only(find_paths_to_leaf(wrap_series_split(split), l)); + }; + std::unordered_set edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - + auto 
get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { std::unordered_set tensor_edges = filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); @@ -31,8 +38,8 @@ AbstractedTensorSetMovement get_tensor_set_movement_across_split(TransitiveReduc return AbstractedSingleTensorMovement{ /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), - /*src_machine_views=*/src_layers, - /*dst_machine_views=*/dst_layers, + /*src_machine_views=*/transform(src_layers, get_path_to_layer), + /*dst_machine_views=*/transform(dst_layers, get_path_to_layer), }; }; diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc index c01354f68b..2df6ddb859 100644 --- a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc +++ b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -4,7 +4,7 @@ namespace FlexFlow { float estimate_layer_cost(CostEstimator const &cost_estimator, - PCGOperatorAttrs const &layer, + UnmappedOpCostEstimateKey const &key, MachineView const &machine_view) { PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, layer).op_attrs; diff --git a/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc b/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc deleted file mode 100644 index 717aa66a9b..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/get_allowed_machine_views_list.cc +++ /dev/null @@ -1,81 +0,0 @@ -#include "compiler/machine_mapping/get_allowed_machine_views_list.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "utils/containers.h" -#include "utils/containers/get_first.h" -#include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" -#include "utils/containers/set_minus.h" - -namespace FlexFlow { - -std::vector> - get_allowed_machine_views_list( - MachineMappingContext 
const &context, - std::unordered_set const &layers, - MachineSpecification const &resource) { - NOT_IMPLEMENTED(); - - // if (layers.empty()) { - // return {{}}; - // } - // parallel_layer_guid_t curr_layer = get_first(layers); - // std::unordered_set other_layers = - // set_minus(layers, {curr_layer}); - // - // std::vector> - // other_machine_views_from_recursion = - // get_allowed_machine_views_list(context, other_layers, resource); - // - // ParallelLayerAttrs curr_layer_attrs = - // get_parallel_layer_attrs(context.pcg, curr_layer); - // std::unordered_set allowed_machine_views_for_curr_layer = - // context.allowed_machine_views(curr_layer_attrs, resource); - // - // std::vector> result; - // - // for (MachineView const &for_curr_node : - // allowed_machine_views_for_curr_layer) { - // for (std::unordered_map const - // &for_other_layers : other_machine_views_from_recursion) { - // result.push_back( - // merge_maps(for_other_layers, - // std::unordered_map{ - // {curr_layer, for_curr_node}})); - // } - // } - // return result; -} - -std::vector> - get_allowed_src_machine_views_list( - MachineMappingContext const &context, - std::unordered_set const &tensors, - MachineSpecification const &resource) { - NOT_IMPLEMENTED(); - - // std::unordered_set layers; - // for (parallel_tensor_guid_t const &tensor : tensors) { - // layers.insert(get_source_layer(tensor)); - // } - // - // std::vector> - // machine_views_for_layers_list = - // get_allowed_machine_views_list(context, layers, resource); - // - // std::vector> result; - // - // for (std::unordered_map - // machine_views_for_layers : machine_views_for_layers_list) { - // std::unordered_map - // machine_views_for_tensors; - // for (parallel_tensor_guid_t const &tensor : tensors) { - // machine_views_for_tensors.emplace( - // tensor, machine_views_for_layers.at(get_source_layer(tensor))); - // } - // result.push_back(machine_views_for_tensors); - // } - // - // return result; -} - -} // namespace FlexFlow diff 
--git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index d24ccaf63e..4d4fe8a7de 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -7,7 +7,9 @@ #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h" #include "compiler/machine_mapping/mm_problem_tree_series_split.h" @@ -38,41 +40,6 @@ namespace FlexFlow { MachineMappingResult get_optimal_machine_mapping( - ParallelComputationGraph const &pcg, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - MachineMappingCache &result_cache) { - - MachineMappingContext context = make_machine_mapping_context( - pcg, - cost_estimator, - allowed_machine_views); - - MachineMappingResult result = - get_optimal_machine_mapping_internal(result_cache, context, resources); - - return result; -} - -MachineMappingResultTree get_optimal_machine_mapping_internal( - MachineMappingCache &result_cache, - MachineMappingContext const &context, - MachineSpecification const &resources) { - - std::unordered_set all_layers = get_parallel_layers(context.transitive_reduced_pcg.full_pcg); 
- - NOT_IMPLEMENTED(); - // return get_optimal_machine_mapping_internal(result_cache, - // context, - // sp_decomposition_tree, - // resources, - // get_unconstrained_solution_for_layers(all_layers)); -} - -MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, MachineMappingProblemTree const &problem_tree, @@ -84,7 +51,7 @@ MachineMappingResultTree get_optimal_machine_mapping_internal( }; { - std::optional cached_result = + std::optional cached_result = result_cache.load(state); if (cached_result) { return cached_result.value(); @@ -92,37 +59,38 @@ MachineMappingResultTree get_optimal_machine_mapping_internal( } MachineMappingResult result = visit( - sp_decomposition_tree, + problem_tree, [&](auto const &decomp_tree_node) { - return get_optimal_machine_mapping_internal(result_cache, context, decomp_tree_node, resources, partial_solution); + return get_optimal_machine_mapping + (result_cache, + context, + decomp_tree_node, + resources, + constraints); }); result_cache.save(state, result); return result; } -std::optional get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeSeriesSplit const &series_split, MachineSpecification const &resource, - MachineMappingConstraints const &partial_solution) { - - std::optional result = std::nullopt; + MachineMappingConstraints const &partial_solution, + std::optional const ¶llel_split_transformation) { - auto is_subgraph_input = [&](std::unordered_set const &subgraph_nodes, - parallel_tensor_guid_t const &input_tensor) { - return !contains(subgraph_nodes, input_tensor.raw_graph_output.node); - }; + MachineMappingResult result = infeasible_machine_mapping_result(); AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); - auto get_boundary_machine_view_assignments = [&](std::unordered_set const 
&layers) - -> std::unordered_set + auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) + -> std::unordered_set { - std::unordered_map> + std::unordered_map> allowed = generate_map(layers, - [&](parallel_layer_guid_t const &l) { + [&](BinaryTreePath const &l) { return get_allowed_machine_views_for_layer(context, l); }); return transform(get_all_assignments(allowed), @@ -136,12 +104,12 @@ std::optional get_optimal_machine_mapping_internal( MachineMappingConstraints pre_candidate = with_additional_constraints( - restrict_domain(partial_solution, get_leaves(get_pre_child(series_split))), + restrict_domain(partial_solution, BinaryTreePathEntry::LEFT_CHILD), assigned_pre_machine_views); - MachineMappingResultTree pre_result = ({ - std::optional returned - = get_optimal_machine_mapping_internal(result_cache, + MachineMappingResult pre_result = ({ + std::optional returned + = get_optimal_machine_mapping(result_cache, context, get_pre_child(series_split), resource, @@ -157,12 +125,12 @@ std::optional get_optimal_machine_mapping_internal( MachineMappingConstraints post_candidate = with_additional_constraints( - restrict_domain(partial_solution, get_leaves(get_post_child(series_split))), + restrict_domain(partial_solution, BinaryTreePathEntry::RIGHT_CHILD), assigned_post_machine_views); - MachineMappingResultTree post_result = ({ - std::optional returned - = get_optimal_machine_mapping_internal(result_cache, + MachineMappingResult post_result = ({ + std::optional returned + = get_optimal_machine_mapping(result_cache, context, get_post_child(series_split), resource, @@ -178,7 +146,8 @@ std::optional get_optimal_machine_mapping_internal( /*post_mapping=*/assigned_post_machine_views); float cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); - result = minimize_cost(result, make_series_split(cost_across_split, pre_result, post_result)); + result = minimize_runtime(result, + series_combine(cost_across_split, pre_result, 
post_result, parallel_split_transformation)); } } @@ -187,47 +156,49 @@ std::optional get_optimal_machine_mapping_internal( -MachineMappingResultTree get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, - MMProblemTreeParallelSplit const ¶llel, + MMProblemTreeParallelSplit const ¶llel_split, MachineSpecification const &resources, - MachineMappingConstraints const &partial_solution) { + MachineMappingConstraints const &constraints) { + + MachineMappingProblemTree lhs = get_lhs_child(parallel_split); + MachineMappingProblemTree rhs = get_rhs_child(parallel_split); MachineMappingResult optimal_result = [&] { - MMProblemTreeSeriesSplit series = MMProblemTreeSeriesSplit{ - MMProblemTreeSeriesSplitLabel{empty_abstracted_tensor_set_movement()}, - parallel.left, - parallel.right, - }; + MMProblemTreeSeriesSplit series_split = require_series_split(\ + mm_problem_tree_make_series_split( + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*pre=*/lhs, + /*post=*/rhs)); - return get_optimal_machine_mapping_internal(result_cache, + return get_optimal_machine_mapping(result_cache, context, - series, + series_split, resources, - partial_solution); + constraints, + ParallelSplitTransformation::LthenR); }(); - MachineMappingConstraints left_sub_solution = restrict_domain(partial_solution, - get_leaves(parallel.left)); - MachineMappingConstraints right_sub_solution = restrict_domain(partial_solution, - get_leaves(parallel.right)); + MachineMappingConstraints left_constraints = restrict_domain(constraints, get_leaves(lhs)); + MachineMappingConstraints right_constraints = restrict_domain(constraints, get_leaves(rhs)); for (auto const &resource_split : get_machine_resource_splits(resources)) { MachineMappingResult left_result = - get_optimal_machine_mapping_internal(result_cache, + get_optimal_machine_mapping(result_cache, context, - parallel.left, + lhs, 
resource_split.first, - left_sub_solution); + left_constraints); MachineMappingResult right_result = - get_optimal_machine_mapping_internal(result_cache, + get_optimal_machine_mapping(result_cache, context, - parallel.right, + rhs, resource_split.second, - right_sub_solution); + right_constraints); - minimize_runtime( + optimal_result = minimize_runtime( optimal_result, parallel_combine(left_result, right_result)); } @@ -235,26 +206,20 @@ MachineMappingResultTree get_optimal_machine_mapping_internal( return optimal_result; } -MachineMappingResultTree get_optimal_machine_mapping_internal( +MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGOperatorAttrs const &layer, + UnmappedOpCostEstimateKey const &leaf, MachineSpecification const &resource, MachineMappingConstraints const &constraints) { - assert (get_all_layers(constraints, IncludeUnconstrained{true}) == std::unordered_set{layer}); - - MachineMapping concrete_mapping = require_fully_constrained(constraints); + MachineView machine_view = require_only_root(constraints).value(); + OpCostEstimateKey mapped = map_unmapped_op_cost_estimate_key(leaf, machine_view); + float cost = context.cost_estimator.estimate_cost(mapped); - float cost = estimate_layer_cost(context.transitive_reduced_pcg.full_pcg, - context.cost_estimator, - layer, - concrete_mapping.machine_views.at(layer)); - - return make_leaf_node( - /*runtime=*/cost, - /*machine_mapping=*/concrete_mapping, - }; + return make_singleton_machine_mapping_result + (/*runtime=*/cost, + /*machine_view=*/machine_view); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 721fa1e32b..720522ac0c 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc 
@@ -1,46 +1,55 @@ -#include "compiler/machine_mapping/partial_machine_mapping.h" -#include "compiler/machine_mapping/machine_mapping_context.h" -#include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "utils/containers/filtermap_keys.h" #include "utils/containers/flatmap.h" +#include "utils/containers/filter.h" #include "utils/containers/generate_map.h" #include "utils/containers/keys.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/map_values.h" +#include "utils/full_binary_tree/binary_tree_path.h" namespace FlexFlow { -MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &layers) { +MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &layers) { return MachineMappingConstraints{ generate_map(layers, - [](parallel_layer_guid_t const &) -> std::optional { + [](BinaryTreePath const &) -> std::optional { return std::nullopt; }), }; } -std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution, +std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution, IncludeUnconstrained const &include_unconstrained) { - std::unordered_set with_unconstrained = keys(partial_solution.machine_views); + std::unordered_set with_unconstrained = keys(partial_solution.machine_views); if (include_unconstrained.raw_bool) { return with_unconstrained; } else { return filter(with_unconstrained, - [&](parallel_layer_guid_t const &l) { return partial_solution.machine_views.at(l).has_value(); }); + [&](BinaryTreePath const &l) { return partial_solution.machine_views.at(l).has_value(); }); } } std::optional get_machine_view_for_layer(MachineMappingConstraints const &partial_solution, - parallel_layer_guid_t const &layer) { + BinaryTreePath const &layer) { return partial_solution.machine_views.at(layer); } 
-MachineMappingConstraints get_sub_solution(MachineMappingConstraints const &partial_solution, - std::unordered_set const &sub_problem) { - +MachineMappingConstraints restrict_domain(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { return MachineMappingConstraints{ - restrict_keys(partial_solution.machine_views, sub_problem), + filtermap_keys(constraints.machine_views, + [&](BinaryTreePath const &path) -> std::optional { + BinaryTreePathEntry head = binary_tree_path_get_top_level(path); + + if (head == prefix) { + BinaryTreePath rest = binary_tree_path_get_non_top_level(path); + return rest; + } else { + return std::nullopt; + } + }) }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc deleted file mode 100644 index 3aace6b332..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc +++ /dev/null @@ -1,95 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_problem_tree.h" -#include "compiler/machine_mapping/full_binary_tree/get_left_child.h" -#include "compiler/machine_mapping/full_binary_tree/get_right_child.h" -#include "compiler/machine_mapping/full_binary_tree/require.h" -#include "compiler/machine_mapping/full_binary_tree/visit.h" -#include "compiler/machine_mapping/full_binary_tree/get_leaves.h" -#include "utils/overload.h" -#include "compiler/machine_mapping/mm_problem_tree_split_label.h" - -namespace FlexFlow { - -MachineMappingProblemTree mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - FullBinaryTree{ - FullBinaryTreeParentNode{ - /*label=*/MMProblemTreeSplitLabel{ - MMProblemTreeSeriesSplitLabel{ - /*tensor_set_movement=*/tensor_set_movement, - }, - }, - /*lhs=*/lhs.raw_tree, - /*rhs=*/rhs.raw_tree, - }, - 
}, - }; -} - -MachineMappingProblemTree mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - FullBinaryTree{ - FullBinaryTreeParentNode{ - /*label=*/MMProblemTreeSplitLabel{ - MMProblemTreeParallelSplitLabel{}, - }, - /*lhs=*/lhs.raw_tree, - /*rhs=*/rhs.raw_tree, - }, - }, - }; -} - -MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &layer) { - return MachineMappingProblemTree{ - FullBinaryTree{ - layer, - }, - }; -} - -SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { - return visit( - tree.raw_tree, - overload { - [](FullBinaryTreeParentNode const &parent) { - return split_label_get_node_type(parent.label); - }, - [](PCGOperatorAttrs const &) { - return SPDecompositionTreeNodeType::NODE; - } - }); -} - - -MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { - FullBinaryTreeParentNode raw_node = require_parent_node(t.raw_tree); - - return MMProblemTreeSeriesSplit{ - /*label=*/raw_node.label.get(), - /*left=*/MachineMappingProblemTree{get_left_child(raw_node)}, - /*right=*/MachineMappingProblemTree{get_right_child(raw_node)}, - }; -} - -MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { - FullBinaryTreeParentNode raw_node = require_parent_node(t.raw_tree); - - return MMProblemTreeParallelSplit{ - /*label=*/raw_node.label.get(), - /*left=*/MachineMappingProblemTree{get_left_child(raw_node)}, - /*right=*/MachineMappingProblemTree{get_right_child(raw_node)}, - }; -} - -PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &t) { - return require_leaf(t.raw_tree); -} - -std::unordered_multiset get_leaves(MachineMappingProblemTree const &t) { - return get_leaves(t.raw_tree); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc 
b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc similarity index 77% rename from lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc rename to lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 8472228534..c9864c2e25 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,6 +1,6 @@ -#include "compiler/machine_mapping/get_machine_mapping_problem_tree.h" -#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg_binary_parallel_split.h" #include "compiler/series_parallel/pcg_binary_series_split.h" @@ -8,6 +8,7 @@ #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/overload.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" namespace FlexFlow { @@ -34,7 +35,7 @@ MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGr to_problem_tree(get_right_child(parallel))); }, [&](parallel_layer_guid_t const &leaf) { - return mm_problem_tree_make_leaf(pcg_get_op_attrs(pcg, leaf)); + return mm_problem_tree_make_leaf(get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf)); } }); }; diff --git 
a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..9d29f573c3 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -0,0 +1,78 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +MachineMappingProblemTree mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + make_generic_binary_series_split( + MMProblemTreeSeriesSplitLabel{tensor_set_movement}, + lhs.raw_tree, + rhs.raw_tree), + }; +} + +MachineMappingProblemTree mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + make_generic_binary_parallel_split( + MMProblemTreeParallelSplitLabel{}, + lhs.raw_tree, + rhs.raw_tree), + }; +} + +MachineMappingProblemTree mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &leaf_label) { + return MachineMappingProblemTree{ + make_generic_binary_sp_leaf< + MMProblemTreeSeriesSplitLabel, + 
MMProblemTreeParallelSplitLabel, + UnmappedOpCostEstimateKey>(leaf_label), + }; +} + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { + return get_node_type(tree.raw_tree); +} + + +MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { + return MMProblemTreeSeriesSplit{ + require_series(t.raw_tree), + }; +} + +MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { + return MMProblemTreeParallelSplit{ + require_parallel(t.raw_tree), + }; +} + +UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) { + return require_leaf(t.raw_tree); +} + +std::unordered_multiset get_leaves(MachineMappingProblemTree const &t) { + return get_leaves(t.raw_tree); +} + +std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, + BinaryTreePath const &path) { + std::optional> raw_subtree = get_subtree_at_path(tree.raw_tree, path); + + if (!raw_subtree.has_value()) { + return std::nullopt; + } else { + return MachineMappingProblemTree{raw_subtree.value()}; + } +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc new file mode 100644 index 0000000000..1b9cd59572 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc @@ -0,0 +1,20 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_lhs_child(MMProblemTreeParallelSplit const &p) { + 
return MachineMappingProblemTree{ + get_left_child(p.raw_split), + }; +} + +MachineMappingProblemTree get_rhs_child(MMProblemTreeParallelSplit const &p) { + return MachineMappingProblemTree{ + get_right_child(p.raw_split), + }; +} + + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc new file mode 100644 index 0000000000..545d06957a --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc @@ -0,0 +1,21 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &s) { + return MachineMappingProblemTree{ + s.raw_split.pre, + }; +} + +MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &s) { + return MachineMappingProblemTree{ + s.raw_split.post, + }; +} + +AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &s) { + return s.raw_split.label.tensor_set_movement; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc new file mode 100644 index 0000000000..2574fb81aa --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -0,0 +1,33 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey 
get_unmapped_op_cost_estimate_key_for_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { + return get_parallel_tensor_shape(pcg, t); + }; + + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/pcg_get_op_attrs(pcg, layer), + /*input_shapes=*/transform(get_incoming_inputs(pcg, layer), get_tensor_shape), + /*weight_shapes=*/transform(get_incoming_weights(pcg, layer), get_tensor_shape), + /*output_shapes=*/transform(get_layer_outputs(pcg, layer), get_tensor_shape), + }; +} + + +OpCostEstimateKey map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view) { + return OpCostEstimateKey{ + /*op_attrs=*/unmapped.op_attrs, + /*input_shapes=*/unmapped.input_shapes, + /*weight_shapes=*/unmapped.weight_shapes, + /*output_shapes=*/unmapped.output_shapes, + /*machine_view=*/machine_view, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 5e630cdef7..73f4bca8db 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,34 +1,117 @@ #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/full_binary_tree/binary_tree_path.h" namespace FlexFlow { -MachineMappingResult sequential_combine(MachineMappingResult const &s1, - float comm_cost, - MachineMappingResult const &s2) { +MachineMappingResult sequential_combine(float comm_cost, + MachineMappingResult const &maybe_pre_result, + MachineMappingResult const &maybe_post_result, + std::optional const ¶llel_split_transformation) { + FeasibleMachineMappingResult pre_result = ({ + if 
(is_infeasible(maybe_pre_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_pre_result); + }); + + FeasibleMachineMappingResult post_result = ({ + if (is_infeasible(maybe_post_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_post_result); + }); + + std::function transform_problem_tree_paths_pre + = nest_inside_left_child; + std::function transform_problem_tree_paths_post + = nest_inside_right_child; + + if (parallel_split_transformation.has_value() + && parallel_split_transformation.value() == ParallelSplitTransformation::RthenL) { + transform_problem_tree_paths_pre = nest_inside_right_child; + transform_problem_tree_paths_post = nest_inside_left_child; + } + return MachineMappingResult{ - s1.runtime + comm_cost + s2.runtime, - combine_disjoint_mappings(s1.machine_mapping, s2.machine_mapping)}; + FeasibleMachineMappingResult{ + /*runtime=*/pre_result.runtime + comm_cost + post_result.runtime, + /*parallel_layer_guid_oblivious_machine_mapping=*/merge_maps( + map_keys(pre_result.parallel_layer_guid_oblivious_machine_mapping, + transform_problem_tree_paths_pre), + map_keys(post_result.parallel_layer_guid_oblivious_machine_mapping, + transform_problem_tree_paths_post)), + }, + }; } -MachineMappingResult parallel_combine(MachineMappingResult const &s1, - MachineMappingResult const &s2) { +MachineMappingResult parallel_combine(MachineMappingResult const &maybe_lhs_result, + MachineMappingResult const &maybe_rhs_result) { + FeasibleMachineMappingResult lhs_result = ({ + if (is_infeasible(maybe_lhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_lhs_result); + }); + + FeasibleMachineMappingResult rhs_result = ({ + if (is_infeasible(maybe_rhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_rhs_result); + }); + return MachineMappingResult{ - std::max(s1.runtime, s2.runtime), - combine_disjoint_mappings(s1.machine_mapping, 
s2.machine_mapping)}; + FeasibleMachineMappingResult{ + /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), + /*parallel_layer_guid_oblivious_machine_mapping=*/merge_maps( + map_keys(lhs_result.parallel_layer_guid_oblivious_machine_mapping, + nest_inside_left_child), + map_keys(rhs_result.parallel_layer_guid_oblivious_machine_mapping, + nest_inside_right_child)), + }, + }; } -MachineMappingResult get_infinity_machine_mapping_result() { - return MachineMappingResult( - std::numeric_limits::infinity(), - MachineMapping(std::unordered_map{})); +MachineMappingResult infeasible_machine_mapping_result() { + return MachineMappingResult{std::nullopt}; } -void minimize_runtime(MachineMappingResult &m1, - MachineMappingResult const &m2) { +MachineMappingResult minimize_runtime(MachineMappingResult const &maybe_m1, + MachineMappingResult const &maybe_m2) { + FeasibleMachineMappingResult m1 = ({ + if (is_infeasible(maybe_m1)) { + return maybe_m2; + } + require_feasible(maybe_m1); + }); + + FeasibleMachineMappingResult m2 = ({ + if (is_infeasible(maybe_m2)) { + return maybe_m1; + } + require_feasible(maybe_m2); + }); + if (m2.runtime < m1.runtime) { - m1 = m2; + return maybe_m2; + } else { + return maybe_m1; } } +MachineMappingResult make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view) { + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/runtime, + /*parallel_layer_guid_oblivious_machine_mapping=*/{ + {binary_tree_root_path(), machine_view}, + }, + }, + }; +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc new file mode 100644 index 0000000000..2e61ca2ca2 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc @@ -0,0 +1,59 @@ +#include
"compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType get_node_type(MachineMappingResultTree const &t) { + return get_node_type(t.raw_tree); +} + +float get_mm_result_tree_cost(MachineMappingResultTree const &t) { + return visit( + t, + overload { + [](MMResultTreeSeriesSplit const &series) { + return get_cost(series); + }, + [](MMResultTreeParallelSplit const &parallel) { + return get_cost(parallel); + }, + [](MMResultTreeLeafLabel const &leaf) { + return leaf.cost; + }, + }); +} + +MachineMappingResultTree make_series_split(float comm_cost, + BinaryTreePathEntry problem_tree_path_entry, + MachineMappingResultTree const &pre, + MachineMappingResultTree const &post) { + MMResultTreeSeriesSplitLabel label = MMResultTreeSeriesSplitLabel{ + /*cost=*/get_mm_result_tree_cost(pre) + comm_cost + get_mm_result_tree_cost(post), + /*problem_tree_path_entry=*/problem_tree_path_entry, + }; + + return MachineMappingResultTree{ + make_generic_binary_series_split(label, pre.raw_tree, post.raw_tree), + }; +} + +MachineMappingResultTree make_parallel_split(MachineMappingResultTree const &lhs, + MachineMappingResultTree const &rhs) { + MMResultTreeParallelSplitLabel label = MMResultTreeParallelSplitLabel{ + /*cost=*/std::max(get_mm_result_tree_cost(lhs), get_mm_result_tree_cost(rhs)), + /*problem_tree_path_entry=*/problem_tree_path_entry, + }; + + return MachineMappingResultTree{ + make_generic_binary_series_split(label, pre.raw_tree, post.raw_tree), + }; +} +
+MachineMappingResultTree make_leaf_node(float cost, MachineView const &) { + +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc new file mode 100644 index 0000000000..bf237f1aaa --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc @@ -0,0 +1,13 @@ +#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.h" + +namespace FlexFlow { + +float get_cost(MMResultTreeParallelSplit const &p) { + return p.raw_split.label.cost; +} + +BinaryTreePathEntry get_problem_tree_path_entry(MMResultTreeParallelSplit const &p) { + return p.raw_split.label.problem_tree_path_entry; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc new file mode 100644 index 0000000000..4e78787a3f --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc @@ -0,0 +1,13 @@ +#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.h" + +namespace FlexFlow { + +float get_cost(MMResultTreeSeriesSplit const &s) { + return s.raw_split.label.cost; +} + +BinaryTreePathEntry get_problem_tree_path_entry(MMResultTreeSeriesSplit const &s) { + return s.raw_split.label.problem_tree_path_entry; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index bdd68da600..78345398b9 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ 
b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -1,8 +1,10 @@ #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" namespace FlexFlow { @@ -22,19 +24,31 @@ SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &d) { PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), + make_series_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + 
make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &l) { return PCGBinarySPDecomposition{ - make_generic_binary_sp_leaf(l), + make_leaf_node(l), + }; +} + +PCGBinarySPDecomposition wrap_series_split(PCGBinarySeriesSplit const &s) { + return PCGBinarySPDecomposition{ + wrap_series_split(s.raw_split), + }; +} + +PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &p) { + return PCGBinarySPDecomposition{ + wrap_parallel_split(p.raw_split), }; } @@ -54,4 +68,10 @@ parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &d) { return require_leaf(d.raw_tree); } +std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &spd, + parallel_layer_guid_t const &l) { + return find_paths_to_leaf(spd.raw_tree, l); +} + + } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc similarity index 100% rename from lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc rename to lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.h b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h new file mode 100644 index 0000000000..e3ed967a23 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path(); +BinaryTreePath nest_inside_left_child(BinaryTreePath const &); +BinaryTreePath nest_inside_right_child(BinaryTreePath const 
&); + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &); +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml new file mode 100644 index 0000000000..08955c2d75 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "BinaryTreePath" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "utils/full_binary_tree/binary_tree_path_entry.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "entries" +type = "std::vector<::FlexFlow::BinaryTreePathEntry>" diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml new file mode 100644 index 0000000000..6c81123dcf --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "BinaryTreePathEntry" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT_CHILD" +key = "left" + +[[values]] +name = "RIGHT_CHILD" +key = "right" diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..4410f06e67 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include 
"utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include +#include "utils/overload.h" +#include "utils/containers/transform.h" +#include "utils/containers/set_union.h" + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf(FullBinaryTree const &tree, + LeafLabel const &leaf) { + return visit>( + tree, + overload { + [&](LeafLabel const &l) -> std::unordered_set { + if (l == leaf) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + [&](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(find_paths_to_leaf(get_left_child(parent), leaf), + nest_inside_left_child), + transform(find_paths_to_leaf(get_right_child(parent), leaf), + nest_inside_right_child)); + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h index 3d94996079..ff4f54e95d 100644 --- a/lib/utils/include/utils/full_binary_tree/fmt.h +++ b/lib/utils/include/utils/full_binary_tree/fmt.h @@ -32,6 +32,16 @@ std::string format_as(FullBinaryTree const &t) { }); } +template +std::ostream &operator<<(std::ostream &s, FullBinaryTreeParentNode const &t) { + return (s << fmt::to_string(t)); +} + +template +std::ostream &operator<<(std::ostream &s, FullBinaryTree const &t) { + return (s << fmt::to_string(t)); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h new file mode 100644 index 0000000000..8f9f76f49d --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include 
"utils/full_binary_tree/binary_tree_path_entry.dtg.h" + +namespace FlexFlow { + +template +FullBinaryTree get_child(FullBinaryTreeParentNode const &t, + BinaryTreePathEntry const &e) { + switch (e) { + case BinaryTreePathEntry::LEFT_CHILD: + return get_left_child(t); + case BinaryTreePathEntry::RIGHT_CHILD: + return get_right_child(t); + default: + throw mk_runtime_error(fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..6909d9e1ef --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::optional> get_subtree_at_path(FullBinaryTree const &t, + BinaryTreePath const &p) { + if (p == binary_tree_root_path()) { + return t; + } + + return visit>>( + t, + overload { + [&](FullBinaryTreeParentNode const &parent) { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, curr), rest); + }, + [&](LeafLabel const &leaf) { + return std::nullopt; + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/json.h b/lib/utils/include/utils/full_binary_tree/json.h new file mode 100644 index 0000000000..0d830890dc --- /dev/null +++ 
b/lib/utils/include/utils/full_binary_tree/json.h @@ -0,0 +1,75 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::FullBinaryTreeParentNode> { + static ::FlexFlow::FullBinaryTreeParentNode from_json(json const &j) { + return ::FlexFlow::FullBinaryTreeParentNode{ + j.at("left_child") + .template get<::FlexFlow::FullBinaryTreeParentNode>(), + j.at("right_child") + .template get<::FlexFlow::FullBinaryTreeParentNode>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::FullBinaryTreeParentNode const &v) { + j["__type"] = "FullBinaryTreeParentNode"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::FullBinaryTree> { + static ::FlexFlow::FullBinaryTree from_json(json const &j) { + std::string key = j.at("type").get(); + + if (key == "parent") { + return ::FlexFlow::FullBinaryTree{ + j.at("value").get<::FlexFlow::FullBinaryTreeParentNode>(), + }; + } else if (key == "leaf") { + return ::FlexFlow::FullBinaryTree{ + j.at("value").get(), + }; + } else { + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown json type key: {}", key)); + } + } + + static void + to_json(json &j, + ::FlexFlow::FullBinaryTree const &v) { + j["__type"] = "FullBinaryTree"; + ::FlexFlow::visit( + v, + ::FlexFlow::overload{ + [&](::FlexFlow::FullBinaryTreeParentNode const &s) { + j["type"] = "parent"; + j["value"] = s; + return std::monostate{}; + }, + [&](LeafLabel const &t) { + j["type"] = "leaf"; + j["value"] = t; + return std::monostate{}; + }, + }); + } +}; + +} // namespace 
nlohmann + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 93e5bfb504..4a1e615830 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H #include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index 023c767313..281d64c6f6 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -4,6 +4,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include namespace FlexFlow { @@ -24,6 +25,29 @@ BinarySeriesSplit require_series(BinarySPDecompositionTree const &); BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &); Node require_leaf(BinarySPDecompositionTree const &); +SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); + +template +Return visit(BinarySPDecompositionTree const &tree, F &&f) { + SPDecompositionTreeNodeType node_type = get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + Return result = f(require_series(tree)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + Return result = 
f(require_parallel(tree)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + Return result = f(require_leaf(tree)); + return result; + } + default: + throw mk_runtime_error(fmt::format("Unhandled SPDecompositionTreeNodeType value: {}", node_type)); + } +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..e89a35bab8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf(GenericBinarySPDecompositionTree const &tree, + LeafLabel const &leaf) { + return find_paths_to_leaf(tree.raw_tree, leaf); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml index e3d92c7409..f613d2f04e 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml @@ -4,6 +4,7 @@ features = [ "eq", "hash", "fmt", + "json", ] template_params = [ @@ -22,8 +23,8 @@ type = "ParallelSplitLabel" [[fields]] name = "lhs" -type = "GenericBinarySPDecompositionTree" +type = "::FlexFlow::GenericBinarySPDecompositionTree" [[fields]] name = "rhs" -type = "GenericBinarySPDecompositionTree" +type = "::FlexFlow::GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml index db11340d6e..025dca1826 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml @@ -4,6 +4,7 @@ features = [ "eq", "hash", "fmt", + "json", ] template_params = [ @@ -23,8 +24,8 @@ type = "SeriesSplitLabel" [[fields]] name = "pre" -type = "GenericBinarySPDecompositionTree" +type = "::FlexFlow::GenericBinarySPDecompositionTree" [[fields]] name = "post" -type = "GenericBinarySPDecompositionTree" +type = "::FlexFlow::GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml index 236274e617..82f93a9197 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml @@ -4,6 +4,7 @@ features = [ "eq", "hash", "fmt", + "json", ] template_params = [ @@ -14,8 +15,15 @@ template_params = [ includes = [ "utils/full_binary_tree/full_binary_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h", +] + +src_includes = [ + "utils/full_binary_tree/hash.h", + "utils/full_binary_tree/fmt.h", + "utils/full_binary_tree/json.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::FullBinaryTree, LeafLabel>" +type = "::FlexFlow::FullBinaryTree<::FlexFlow::GenericBinarySPSplitLabel, LeafLabel>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h new file mode 100644 index 0000000000..c856f35d68 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_SPLIT_LABEL_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_SPLIT_LABEL_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &label) { + return label.template visit(overload { + [](SeriesLabel const &) { return SPDecompositionTreeNodeType::SERIES; }, + [](ParallelLabel const &) { return SPDecompositionTreeNodeType::PARALLEL; }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml new file mode 100644 index 0000000000..17920c180e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "GenericBinarySPSplitLabel" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", +] + +[[values]] +type = "SeriesSplitLabel" +key = "series" + +[[values]] +type = "ParallelSplitLabel" +key = "parallel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index 888d3c6627..9c3ad0daeb 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include "utils/full_binary_tree/visit.h" #include "utils/overload.h" @@ -17,14 +18,8 @@ SPDecompositionTreeNodeType [](LeafLabel const &) { return SPDecompositionTreeNodeType::NODE; }, - [](FullBinaryTreeParentNode, LeafLabel> const &parent) { - if (std::holds_alternative(parent.label)) { - return SPDecompositionTreeNodeType::SERIES; - } else { - assert (std::holds_alternative(parent.label)); - - return SPDecompositionTreeNodeType::PARALLEL; - } + [](FullBinaryTreeParentNode, LeafLabel> const &parent) { + return get_node_type(parent.label); }, }); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index 7c6d28d7b4..cfb2ea1cb2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -1,7 +1,7 @@ #ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" @@ -9,28 +9,28 @@ namespace FlexFlow { -template -int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { +template +int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { return visit(tt, overload{ - [](T const &t) { return 1; }, - [](GenericBinarySeriesSplit const &s) { + [](LeafLabel const &t) { return 1; }, + [](GenericBinarySeriesSplit const &s) { return get_num_tree_nodes(s); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return get_num_tree_nodes(p); }, }); } -template -int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { +template +int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { return 1 + get_num_tree_nodes(get_left_child(s)) + get_num_tree_nodes(get_right_child(s)); } -template -int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { +template +int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { return 1 + get_num_tree_nodes(get_left_child(p)) + get_num_tree_nodes(get_right_child(p)); } diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..e5b3b65ccd --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include + +namespace FlexFlow { + +template +std::optional> + get_subtree_at_path(GenericBinarySPDecompositionTree const &tree, + BinaryTreePath const &path) { + std::optional, LeafLabel>> raw_subtree = get_subtree_at_path(tree.raw_tree, path); + + if (!raw_subtree.has_value()) { + return std::nullopt; + } else { + return GenericBinarySPDecompositionTree{ + raw_subtree.value(), + }; + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h deleted file mode 100644 index 4f1f8266e1..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace nlohmann { - -template -struct adl_serializer<::FlexFlow::GenericBinarySeriesSplit> { - static ::FlexFlow::GenericBinarySeriesSplit from_json(json const &j) { - return ::FlexFlow::GenericBinarySeriesSplit{ - j.at("left_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - j.at("right_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - }; - } - - static void to_json(json &j, - ::FlexFlow::GenericBinarySeriesSplit const &v) { - j["__type"] = "GenericBinarySeriesSplit"; - j["left_child"] = get_left_child(v); - j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::GenericBinaryParallelSplit> { - static ::FlexFlow::GenericBinaryParallelSplit from_json(json const &j) { - return ::FlexFlow::GenericBinaryParallelSplit{ - j.at("left_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - j.at("right_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - }; - } - - static void to_json(json &j, - ::FlexFlow::GenericBinaryParallelSplit const &v) { - j["__type"] = "GenericBinaryParallelSplit"; - j["left_child"] = get_left_child(v); - 
j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::GenericBinarySPDecompositionTree> { - static ::FlexFlow::GenericBinarySPDecompositionTree - from_json(json const &j) { - std::string key = j.at("type").get(); - - if (key == "series") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get<::FlexFlow::GenericBinarySeriesSplit>(), - }; - } else if (key == "parallel") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get<::FlexFlow::GenericBinaryParallelSplit>(), - }; - } else if (key == "leaf") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get(), - }; - } else { - throw ::FlexFlow::mk_runtime_error( - fmt::format("Unknown json type key: {}", key)); - } - } - - static void - to_json(json &j, - ::FlexFlow::GenericBinarySPDecompositionTree const &v) { - j["__type"] = "GenericBinarySPDecompositionTree"; - ::FlexFlow::visit( - v, - ::FlexFlow::overload{ - [&](::FlexFlow::GenericBinarySeriesSplit const &s) { - j["type"] = "series"; - j["value"] = s; - return std::monostate{}; - }, - [&](::FlexFlow::GenericBinaryParallelSplit const &p) { - j["type"] = "parallel"; - j["value"] = p; - return std::monostate{}; - }, - [&](T const &t) { - j["type"] = "leaf"; - j["value"] = t; - return std::monostate{}; - }, - }); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index e925292b35..2ae89462bb 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -11,9 +11,9 @@ GenericBinarySPDecompositionTree make_gen 
GenericBinarySPDecompositionTree const &lhs, GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel>{ - FullBinaryTreeParentNode, LeafLabel>{ - label, + FullBinaryTree, LeafLabel>{ + FullBinaryTreeParentNode, LeafLabel>{ + GenericBinarySPSplitLabel{label}, lhs.raw_tree, rhs.raw_tree, } @@ -23,13 +23,13 @@ GenericBinarySPDecompositionTree make_gen template GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - SeriesLabel const &label, + ParallelLabel const &label, GenericBinarySPDecompositionTree const &lhs, GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel>{ - FullBinaryTreeParentNode, LeafLabel>{ - label, + FullBinaryTree, LeafLabel>{ + FullBinaryTreeParentNode, LeafLabel>{ + GenericBinarySPSplitLabel{label}, lhs.raw_tree, rhs.raw_tree, } @@ -40,7 +40,7 @@ GenericBinarySPDecompositionTree make_gen template GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(LeafLabel const &leaf) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel>{ + FullBinaryTree, LeafLabel>{ leaf, }, }; diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index 1c20de06dc..9a93ae8d6a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -3,30 +3,39 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/full_binary_tree/require.h" namespace FlexFlow { template GenericBinarySeriesSplit require_series(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); return GenericBinarySeriesSplit{ - /*label=*/std::get(parent.label), - /*pre=*/get_left_child(parent), - /*post=*/get_right_child(parent), + /*label=*/parent.label.template get(), + /*pre=*/GenericBinarySPDecompositionTree{ + get_left_child(parent), + }, + /*post=*/GenericBinarySPDecompositionTree{ + get_right_child(parent), + }, }; } template GenericBinaryParallelSplit require_parallel(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); - - return GenericBinarySeriesSplit{ - /*label=*/std::get(parent.label), - /*pre=*/get_left_child(parent), - /*post=*/get_right_child(parent), + FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + + return GenericBinaryParallelSplit{ + /*label=*/parent.label.template get(), + /*lhs=*/GenericBinarySPDecompositionTree{ + get_left_child(parent), + }, + /*rhs=*/GenericBinarySPDecompositionTree{ + get_right_child(parent), + }, }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index ce4e4ebf55..a56ed952e9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -3,6 +3,8 @@ #include 
"utils/exception.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" namespace FlexFlow { @@ -11,15 +13,15 @@ Result visit(GenericBinarySPDecompositionTree +GenericBinarySPDecompositionTree + wrap_series_split(GenericBinarySeriesSplit const &series_split) { + return FullBinaryTree, LeafLabel> { + FullBinaryTreeParentNode, LeafLabel> { + /*label=*/series_split.label, + /*lhs=*/series_split.pre.raw_tree, + /*rhs=*/series_split.post.raw_tree, + }, + }; +} + +template +GenericBinarySPDecompositionTree + wrap_parallel_split(GenericBinaryParallelSplit const ¶llel_split) { + return FullBinaryTree, LeafLabel> { + FullBinaryTreeParentNode, LeafLabel> { + /*label=*/parallel_split.label, + /*lhs=*/parallel_split.lhs.raw_tree, + /*rhs=*/parallel_split.rhs.raw_tree, + }, + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..77b44adc01 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H + +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf(LeafOnlyBinarySPDecompositionTree const &tree, + LeafLabel const &leaf) { + return find_paths_to_leaf(tree.raw_tree, leaf); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h new file mode 100644 index 0000000000..74dc6cd839 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +SPDecompositionTreeNodeType get_node_type(LeafOnlyBinarySPDecompositionTree const &tree) { + return get_node_type(tree.raw_tree); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h new file mode 100644 index 0000000000..7ba5d2998c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_left_associative(LeafOnlyBinarySPDecompositionTree const &t) { + return is_binary_sp_tree_left_associative(t.raw_tree); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h new file mode 100644 index 0000000000..84f6b21602 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_right_associative(LeafOnlyBinarySPDecompositionTree const &t) { + return is_binary_sp_tree_right_associative(t.raw_tree); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml index b92175b16f..d5579fd58c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml @@ -11,7 +11,7 @@ template_params = [ ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h index 
222799dbe9..5f5e7d9f64 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -11,9 +11,9 @@ LeafOnlyBinarySPDecompositionTree make_series_split(LeafOnlyBinarySPD LeafOnlyBinarySPDecompositionTree const &post) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_series_split( - LeafOnlyBinaryParallelSplitLabel{}, - pre, - post), + LeafOnlyBinarySeriesSplitLabel{}, + pre.raw_tree, + post.raw_tree), }; } @@ -21,10 +21,10 @@ template LeafOnlyBinarySPDecompositionTree make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, LeafOnlyBinarySPDecompositionTree const &rhs) { return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_series_split( + make_generic_binary_parallel_split( LeafOnlyBinaryParallelSplitLabel{}, - lhs, - rhs), + lhs.raw_tree, + rhs.raw_tree), }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h index 9011fadd78..65d42eee7c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -6,6 +6,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" namespace FlexFlow { @@ -19,7 +20,6 @@ LeafOnlyBinarySeriesSplit require_series(t.raw_tree); return LeafOnlyBinarySeriesSplit{ - LeafOnlyBinarySeriesSplitLabel{}, LeafOnlyBinarySPDecompositionTree{raw.pre}, LeafOnlyBinarySPDecompositionTree{raw.post}, }; @@ -28,16 +28,15 @@ LeafOnlyBinarySeriesSplit template LeafOnlyBinaryParallelSplit require_parallel(LeafOnlyBinarySPDecompositionTree const &t) { - GenericBinarySeriesSplit< + GenericBinaryParallelSplit< LeafOnlyBinarySeriesSplitLabel, LeafOnlyBinaryParallelSplitLabel, LeafLabel> raw = - require_series(t.raw_tree); + require_parallel(t.raw_tree); - return LeafOnlyBinarySeriesSplit{ - LeafOnlyBinaryParallelSplitLabel{}, - LeafOnlyBinarySPDecompositionTree{raw.pre}, - LeafOnlyBinarySPDecompositionTree{raw.post}, + return LeafOnlyBinaryParallelSplit{ + LeafOnlyBinarySPDecompositionTree{raw.lhs}, + LeafOnlyBinarySPDecompositionTree{raw.rhs}, }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h new file mode 100644 index 0000000000..4a86bc8d49 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_WRAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_WRAP_H + +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySeriesSplit const &split) { + return LeafOnlyBinarySPDecompositionTree{ + wrap_series_split( + GenericBinarySeriesSplit< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel>{ + LeafOnlyBinarySeriesSplitLabel{}, + split.pre, + split.post, + } + ), + }; +} + +template +LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryParallelSplit const &split) { + return LeafOnlyBinarySPDecompositionTree{ + wrap_parallel_split( + GenericBinaryParallelSplit< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel>{ + LeafOnlyBinaryParallelSplitLabel{}, + split.lhs, + split.rhs, + } + ), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc new file mode 100644 index 0000000000..63d083fc5b --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc @@ -0,0 +1,32 @@ +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/subvec.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path() { + return BinaryTreePath{{}}; +} + +BinaryTreePath 
nest_inside_left_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), BinaryTreePathEntry::LEFT_CHILD); + return result; +} + +BinaryTreePath nest_inside_right_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), BinaryTreePathEntry::RIGHT_CHILD); + return result; +} + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &p) { + return p.entries.at(0); +} + +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &p) { + return BinaryTreePath{ + subvec(p.entries, 1, std::nullopt), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc index 88bb9d1acc..6763d9442b 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc @@ -1,6 +1,5 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc index 9b8f0685cd..79bde7899a 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc @@ -1,6 +1,5 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index f683caef48..4aaa657821 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,9 +1,10 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" namespace FlexFlow { @@ -11,7 +12,7 @@ BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{ - make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), + make_series_split(lhs.raw_tree, rhs.raw_tree), }; } @@ -19,13 +20,13 @@ BinarySPDecompositionTree make_parallel_split(BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{ - make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } BinarySPDecompositionTree make_leaf_node(Node const &n) { return BinarySPDecompositionTree{ - make_generic_binary_sp_leaf(n), + make_leaf_node(n), }; } @@ -57,4 +58,9 @@ Node require_leaf(BinarySPDecompositionTree const &tt) { return require_leaf(tt.raw_tree); } +SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tt) { + return get_node_type(tt.raw_tree); +} + + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc deleted file mode 100644 index 4cd7206408..0000000000 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc deleted file mode 100644 index 3a4dbad8ec..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc deleted file mode 100644 index 4ee18af5be..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc deleted file mode 100644 index 75c472c435..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ /dev/null @@ -1 +0,0 @@ -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index b569ff9265..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 02e541b7e4..eb66ce2f68 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -2,50 +2,49 @@ #include "utils/containers/foldl1.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" #include "utils/overload.h" namespace FlexFlow { BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( SeriesParallelDecomposition const &nary) { - std::function( + std::function const &)> from_series_child; - std::function( + std::function const &)> from_parallel_child; - auto from_node = [](Node const &n) -> GenericBinarySPDecompositionTree { - return GenericBinarySPDecompositionTree{n}; + auto from_node 
= [](Node const &n) -> BinarySPDecompositionTree { + return make_leaf_node(n); }; auto from_series = - [&](SeriesSplit const &s) -> GenericBinarySPDecompositionTree { - std::vector> children = + [&](SeriesSplit const &s) -> BinarySPDecompositionTree { + std::vector children = transform(s.children, from_series_child); return foldl1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{accum, x}, - }; + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) { + return make_series_split(accum, x); }); }; auto from_parallel = - [&](ParallelSplit const &s) -> GenericBinarySPDecompositionTree { - std::vector> children = + [&](ParallelSplit const &s) -> BinarySPDecompositionTree { + std::vector children = transform(vector_of(s.children), from_parallel_child); return foldl1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{accum, x}}; + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) { + return make_parallel_split(accum, x); }); }; from_parallel_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit(overload{ [&](Node const &n) { return from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, @@ -54,7 +53,7 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( }; from_series_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit( overload{ [&](Node const &n) { return from_node(n); }, @@ -63,13 +62,11 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( v); }; - return BinarySPDecompositionTree{ - nary.visit>(overload{ + return nary.visit(overload{ [&](Node const &n) { return 
from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, [&](ParallelSplit const &p) { return from_parallel(p); }, - }), - }; + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 673a4118a6..bebb97defc 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -2,47 +2,46 @@ #include "utils/containers/foldr1.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( SeriesParallelDecomposition const &nary) { - std::function( + std::function const &)> from_series_child; - std::function( + std::function const &)> from_parallel_child; auto from_node = [](Node const &n) { - return GenericBinarySPDecompositionTree{n}; + return make_leaf_node(n); }; auto from_series = [&](SeriesSplit const &s) { - std::vector> children = + std::vector children = transform(s.children, from_series_child); return foldr1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{x, accum}}; + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) { + return make_series_split(x, accum); }); }; auto from_parallel = [&](ParallelSplit const &s) { - std::vector> children = + std::vector children = transform(vector_of(s.children), from_parallel_child); return foldr1(children, - 
[](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{x, accum}}; + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) { + return make_parallel_split(x, accum); }); }; from_parallel_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit(overload{ [&](Node const &n) { return from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, @@ -51,7 +50,7 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( }; from_series_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit( overload{ [&](Node const &n) { return from_node(n); }, @@ -60,13 +59,11 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( v); }; - return BinarySPDecompositionTree{ - nary.visit>(overload{ - [&](Node const &n) { return from_node(n); }, - [&](SeriesSplit const &s) { return from_series(s); }, - [&](ParallelSplit const &p) { return from_parallel(p); }, - }), - }; + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 48c936ec39..996803d1ac 100644 --- a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,8 +1,8 @@ #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include 
"utils/containers/extend.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" #include "utils/overload.h" namespace FlexFlow { @@ -50,12 +50,12 @@ std::variant flatten_ast( } std::variant - from_binary_sp_tree(GenericBinarySPDecompositionTree const &binary) { + from_binary_sp_tree(BinarySPDecompositionTree const &binary) { return visit>( binary, overload{ [](Node const &n) { return n; }, - [](GenericBinarySeriesSplit const &s) { + [](BinarySeriesSplit const &s) { return IntermediateSpDecompositionTree{ SplitType::SERIES, { @@ -64,7 +64,7 @@ std::variant }, }; }, - [](GenericBinaryParallelSplit const &p) { + [](BinaryParallelSplit const &p) { return IntermediateSpDecompositionTree{ SplitType::PARALLEL, { @@ -76,9 +76,4 @@ std::variant }); } -std::variant - from_binary_sp_tree(BinarySPDecompositionTree const &binary) { - return from_binary_sp_tree(binary.raw_tree); -} - } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index 41b7b79101..9d7e6439e2 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -1,6 +1,7 @@ #include "utils/containers/flatmap.h" #include #include +#include "test/utils/doctest/fmt/unordered_set.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc deleted file mode 100644 index 66b657eaaa..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("fmt GenericBinarySPDecompositionTree") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - std::string result = fmt::to_string(input); - std::string correct = ""; - - CHECK(result == correct); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - - std::string result = fmt::to_string(input); - std::string correct = (" " - "" - ">" - ">"); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - - std::string result = fmt::to_string(input); - std::string correct = (" " - "" - ">" - ">"); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index abae9286b6..af58bfb777 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -7,80 +7,81 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5}; - - CHECK(result == correct); - } - - SUBCASE("series split") { - SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 6}; - - CHECK(result == correct); - } - - SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 5}; - - CHECK(result == correct); - } - } - - SUBCASE("parallel split") { - SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 6}; - - CHECK(result == correct); - } - - SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 5}; - - CHECK(result == correct); - } - } - - SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - 
make_generic_binary_sp_leaf(4), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5))), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(2))); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {2, 2, 4, 4, 5}; - - CHECK(result == correct); - } + CHECK("TODO"); + // SUBCASE("leaf") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_sp_leaf(5); + // + // std::unordered_multiset result = get_leaves(input); + // std::unordered_multiset correct = {5}; + // + // CHECK(result == correct); + // } + // + // SUBCASE("series split") { + // SUBCASE("children are not the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // + // std::unordered_multiset result = get_leaves(input); + // std::unordered_multiset correct = {5, 6}; + // + // CHECK(result == correct); + // } + // + // SUBCASE("children are the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(5)); + // + // std::unordered_multiset result = get_leaves(input); + // std::unordered_multiset correct = {5, 5}; + // + // CHECK(result == correct); + // } + // } + // + // SUBCASE("parallel split") { + // SUBCASE("children are not the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // + // std::unordered_multiset result = get_leaves(input); + // std::unordered_multiset correct = {5, 6}; + // + // CHECK(result == correct); + // } + // + // SUBCASE("children are the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(5)); + 
// + // std::unordered_multiset result = get_leaves(input); + // std::unordered_multiset correct = {5, 5}; + // + // CHECK(result == correct); + // } + // } + // + // SUBCASE("nested") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(4), + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(2), + // make_generic_binary_sp_leaf(5))), + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(4), + // make_generic_binary_sp_leaf(2))); + // + // std::unordered_multiset result = get_leaves(input); + // std::unordered_multiset correct = {2, 2, 4, 4, 5}; + // + // CHECK(result == correct); + // } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc deleted file mode 100644 index 92c556ad28..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_left_child(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - CHECK_THROWS(get_left_child(input)); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - 
make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(3)); - - GenericBinarySPDecompositionTree result = get_left_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(5); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(7)); - - GenericBinarySPDecompositionTree result = get_left_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(4); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 3de61d3313..3a7d13c2a8 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -6,80 +6,81 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - int result = get_num_tree_nodes(input); - int correct = 1; - - CHECK(result == correct); - } - - SUBCASE("series split") { - SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - - int result = get_num_tree_nodes(input); - int correct = 3; - - CHECK(result == correct); - } - - SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - 
make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); - - int result = get_num_tree_nodes(input); - int correct = 3; - - CHECK(result == correct); - } - } - - SUBCASE("parallel split") { - SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - - int result = get_num_tree_nodes(input); - int correct = 3; - - CHECK(result == correct); - } - - SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); - - int result = get_num_tree_nodes(input); - int correct = 3; - - CHECK(result == correct); - } - } - - SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5))), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(2))); - - int result = get_num_tree_nodes(input); - int correct = 9; - - CHECK(result == correct); - } + FAIL("TODO"); + // SUBCASE("leaf") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_sp_leaf(5); + // + // int result = get_num_tree_nodes(input); + // int correct = 1; + // + // CHECK(result == correct); + // } + // + // SUBCASE("series split") { + // SUBCASE("children are not the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // + // int result = get_num_tree_nodes(input); + // int correct = 3; + // + // CHECK(result == correct); + // } + // + // SUBCASE("children are the same") { + // GenericBinarySPDecompositionTree input = + // 
make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(5)); + // + // int result = get_num_tree_nodes(input); + // int correct = 3; + // + // CHECK(result == correct); + // } + // } + // + // SUBCASE("parallel split") { + // SUBCASE("children are not the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // + // int result = get_num_tree_nodes(input); + // int correct = 3; + // + // CHECK(result == correct); + // } + // + // SUBCASE("children are the same") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(5)); + // + // int result = get_num_tree_nodes(input); + // int correct = 3; + // + // CHECK(result == correct); + // } + // } + // + // SUBCASE("nested") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(4), + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(2), + // make_generic_binary_sp_leaf(5))), + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(4), + // make_generic_binary_sp_leaf(2))); + // + // int result = get_num_tree_nodes(input); + // int correct = 9; + // + // CHECK(result == correct); + // } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc deleted file mode 100644 index 33b5d37955..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_right_child(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - CHECK_THROWS(get_right_child(input)); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(3)); - - GenericBinarySPDecompositionTree result = get_right_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(3); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(7)); - - GenericBinarySPDecompositionTree result = get_right_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(7); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc index e7025dbfad..87d41a0bb6 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -1,4 +1,3 @@ -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" #include @@ -6,112 +5,113 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("std::hash>") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree leaf_5 = - make_generic_binary_sp_leaf(5); - size_t leaf_5_hash = get_std_hash(leaf_5); - - SUBCASE("leaves with same labels hash to the same value") { - GenericBinarySPDecompositionTree also_leaf_5 = - make_generic_binary_sp_leaf(5); - size_t also_leaf_5_hash = get_std_hash(also_leaf_5); - - CHECK(leaf_5_hash == also_leaf_5_hash); - } - - SUBCASE("leaves with different labels hash to different values") { - GenericBinarySPDecompositionTree leaf_6 = - make_generic_binary_sp_leaf(6); - size_t leaf_6_hash = get_std_hash(leaf_6); - - CHECK(leaf_5_hash != leaf_6_hash); - } - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree series_5_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t series_5_6_hash = get_std_hash(series_5_6); - - SUBCASE("same children lead to the same hash") { - GenericBinarySPDecompositionTree also_series_5_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t also_series_5_6_hash = get_std_hash(also_series_5_6); - - CHECK(series_5_6_hash == also_series_5_6_hash); - } - - SUBCASE("hash is order dependent") { - GenericBinarySPDecompositionTree series_6_5 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(6), - make_generic_binary_sp_leaf(5)); - size_t series_6_5_hash = get_std_hash(series_6_5); - - CHECK(series_5_6_hash != series_6_5_hash); - } - - SUBCASE("different left child leads to different hash") { - GenericBinarySPDecompositionTree series_4_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(4), - 
make_generic_binary_sp_leaf(6)); - size_t series_4_6_hash = get_std_hash(series_4_6); - - CHECK(series_5_6_hash != series_4_6_hash); - } - - SUBCASE("different right child leads to different hash") { - GenericBinarySPDecompositionTree series_5_7 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - size_t series_5_7_hash = get_std_hash(series_5_7); - - CHECK(series_5_6_hash != series_5_7_hash); - } - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree parallel_5_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t parallel_5_6_hash = get_std_hash(parallel_5_6); - - SUBCASE("same children lead to the same hash") { - GenericBinarySPDecompositionTree also_parallel_5_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); - - CHECK(parallel_5_6_hash == also_parallel_5_6_hash); - } - - SUBCASE("hash is order dependent") { - GenericBinarySPDecompositionTree parallel_6_5 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), - make_generic_binary_sp_leaf(5)); - size_t parallel_6_5_hash = get_std_hash(parallel_6_5); - - CHECK(parallel_5_6_hash != parallel_6_5_hash); - } - - SUBCASE("different left child leads to different hash") { - GenericBinarySPDecompositionTree parallel_4_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(6)); - size_t parallel_4_6_hash = get_std_hash(parallel_4_6); - - CHECK(parallel_5_6_hash != parallel_4_6_hash); - } - - SUBCASE("different right child leads to different hash") { - GenericBinarySPDecompositionTree parallel_5_7 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - size_t parallel_5_7_hash = get_std_hash(parallel_5_7); - - CHECK(parallel_5_6_hash != parallel_5_7_hash); 
- } - } + FAIL("TODO, probably move over to FullBinaryTree"); + // SUBCASE("leaf") { + // GenericBinarySPDecompositionTree leaf_5 = + // make_generic_binary_sp_leaf(5); + // size_t leaf_5_hash = get_std_hash(leaf_5); + // + // SUBCASE("leaves with same labels hash to the same value") { + // GenericBinarySPDecompositionTree also_leaf_5 = + // make_generic_binary_sp_leaf(5); + // size_t also_leaf_5_hash = get_std_hash(also_leaf_5); + // + // CHECK(leaf_5_hash == also_leaf_5_hash); + // } + // + // SUBCASE("leaves with different labels hash to different values") { + // GenericBinarySPDecompositionTree leaf_6 = + // make_generic_binary_sp_leaf(6); + // size_t leaf_6_hash = get_std_hash(leaf_6); + // + // CHECK(leaf_5_hash != leaf_6_hash); + // } + // } + // + // SUBCASE("series split") { + // GenericBinarySPDecompositionTree series_5_6 = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // size_t series_5_6_hash = get_std_hash(series_5_6); + // + // SUBCASE("same children lead to the same hash") { + // GenericBinarySPDecompositionTree also_series_5_6 = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // size_t also_series_5_6_hash = get_std_hash(also_series_5_6); + // + // CHECK(series_5_6_hash == also_series_5_6_hash); + // } + // + // SUBCASE("hash is order dependent") { + // GenericBinarySPDecompositionTree series_6_5 = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(6), + // make_generic_binary_sp_leaf(5)); + // size_t series_6_5_hash = get_std_hash(series_6_5); + // + // CHECK(series_5_6_hash != series_6_5_hash); + // } + // + // SUBCASE("different left child leads to different hash") { + // GenericBinarySPDecompositionTree series_4_6 = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(4), + // make_generic_binary_sp_leaf(6)); + // size_t series_4_6_hash = get_std_hash(series_4_6); + // + // 
CHECK(series_5_6_hash != series_4_6_hash); + // } + // + // SUBCASE("different right child leads to different hash") { + // GenericBinarySPDecompositionTree series_5_7 = + // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(7)); + // size_t series_5_7_hash = get_std_hash(series_5_7); + // + // CHECK(series_5_6_hash != series_5_7_hash); + // } + // } + // + // SUBCASE("parallel split") { + // GenericBinarySPDecompositionTree parallel_5_6 = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // size_t parallel_5_6_hash = get_std_hash(parallel_5_6); + // + // SUBCASE("same children lead to the same hash") { + // GenericBinarySPDecompositionTree also_parallel_5_6 = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(6)); + // size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); + // + // CHECK(parallel_5_6_hash == also_parallel_5_6_hash); + // } + // + // SUBCASE("hash is order dependent") { + // GenericBinarySPDecompositionTree parallel_6_5 = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), + // make_generic_binary_sp_leaf(5)); + // size_t parallel_6_5_hash = get_std_hash(parallel_6_5); + // + // CHECK(parallel_5_6_hash != parallel_6_5_hash); + // } + // + // SUBCASE("different left child leads to different hash") { + // GenericBinarySPDecompositionTree parallel_4_6 = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + // make_generic_binary_sp_leaf(6)); + // size_t parallel_4_6_hash = get_std_hash(parallel_4_6); + // + // CHECK(parallel_5_6_hash != parallel_4_6_hash); + // } + // + // SUBCASE("different right child leads to different hash") { + // GenericBinarySPDecompositionTree parallel_5_7 = + // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + // make_generic_binary_sp_leaf(7)); + // size_t parallel_5_7_hash = 
get_std_hash(parallel_5_7); + // + // CHECK(parallel_5_6_hash != parallel_5_7_hash); + // } + // } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 7a8756c6cc..481dcd85d3 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1,5 +1,4 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" #include @@ -8,95 +7,96 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_binary_sp_tree_left_associative(" "GenericBinarySPDecompositionTree)") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; - - SUBCASE("input is actually left associative") { - SUBCASE("just node") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(n1); - - bool result = is_binary_sp_tree_left_associative(input); - bool correct = true; - - CHECK(result == correct); - } - - SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); - - bool result = is_binary_sp_tree_left_associative(input); - bool correct = true; - - 
CHECK(result == correct); - } - - SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); - - bool result = is_binary_sp_tree_left_associative(input); - bool correct = true; - - CHECK(result == correct); - } - - SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n3), - make_generic_binary_sp_leaf(n4))); - - bool result = is_binary_sp_tree_left_associative(input); - bool correct = true; - - CHECK(result == correct); - } - } - - SUBCASE("input is not left associative") { - SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); - - bool result = is_binary_sp_tree_left_associative(input); - bool correct = false; - - CHECK(result == correct); - } - - SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); - - bool result = is_binary_sp_tree_left_associative(input); - bool correct = false; - - CHECK(result == correct); - } - } + FAIL("TODO"); + // int n1 = 1; + // int n2 = 2; + // int n3 = 3; + // int n4 = 4; + // + // SUBCASE("input is actually left associative") { + // SUBCASE("just node") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_sp_leaf(n1); + // + // bool result = is_binary_sp_tree_left_associative(input); + // bool correct = true; + // + 
// CHECK(result == correct); + // } + // + // SUBCASE("just series") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_series_split( + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_sp_leaf(n2)), + // make_generic_binary_sp_leaf(n3)); + // + // bool result = is_binary_sp_tree_left_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // + // SUBCASE("just parallel") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_sp_leaf(n2)), + // make_generic_binary_sp_leaf(n3)); + // + // bool result = is_binary_sp_tree_left_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // + // SUBCASE("nested") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_series_split( + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_sp_leaf(n2)), + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n3), + // make_generic_binary_sp_leaf(n4))); + // + // bool result = is_binary_sp_tree_left_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // } + // + // SUBCASE("input is not left associative") { + // SUBCASE("just series") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(n2), + // make_generic_binary_sp_leaf(n3))); + // + // bool result = is_binary_sp_tree_left_associative(input); + // bool correct = false; + // + // CHECK(result == correct); + // } + // + // SUBCASE("just parallel") { + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n1), + // 
make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n2), + // make_generic_binary_sp_leaf(n3))); + // + // bool result = is_binary_sp_tree_left_associative(input); + // bool correct = false; + // + // CHECK(result == correct); + // } + // } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 3cf87368ab..3651eca03a 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1,102 +1,102 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_binary_sp_tree_right_associative(" - "GenericBinarySPDecompositionTree)") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; - - SUBCASE("input is actually right associative") { - SUBCASE("just node") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(n1); - - bool result = 
is_binary_sp_tree_right_associative(input); - bool correct = true; - - CHECK(result == correct); - } - - SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); - - bool result = is_binary_sp_tree_right_associative(input); - bool correct = true; - - CHECK(result == correct); - } - - SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); - - bool result = is_binary_sp_tree_right_associative(input); - bool correct = true; - - CHECK(result == correct); - } - - SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n3), - make_generic_binary_sp_leaf(n4))); - - bool result = is_binary_sp_tree_right_associative(input); - bool correct = true; - - CHECK(result == correct); - } - } - - SUBCASE("input is not right associative") { - SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); - - bool result = is_binary_sp_tree_right_associative(input); - bool correct = false; - - CHECK(result == correct); - } - - SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); - - bool result = 
is_binary_sp_tree_right_associative(input); - bool correct = false; - - CHECK(result == correct); - } - } + "LeafOnlyBinarySPDecompositionTree)") { + FAIL("TODO"); + // int n1 = 1; + // int n2 = 2; + // int n3 = 3; + // int n4 = 4; + // + // SUBCASE("input is actually right associative") { + // SUBCASE("just node") { + // LeafOnlyBinarySPDecompositionTree input = + // make_generic_binary_sp_leaf(n1); + // + // bool result = is_binary_sp_tree_right_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // + // SUBCASE("just series") { + // LeafOnlyBinarySPDecompositionTree input = + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(n2), + // make_generic_binary_sp_leaf(n3))); + // + // bool result = is_binary_sp_tree_right_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // + // SUBCASE("just parallel") { + // LeafOnlyBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n2), + // make_generic_binary_sp_leaf(n3))); + // + // bool result = is_binary_sp_tree_right_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // + // SUBCASE("nested") { + // LeafOnlyBinarySPDecompositionTree input = + // make_generic_binary_series_split( + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_sp_leaf(n2)), + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n3), + // make_generic_binary_sp_leaf(n4))); + // + // bool result = is_binary_sp_tree_right_associative(input); + // bool correct = true; + // + // CHECK(result == correct); + // } + // } + // + // SUBCASE("input is not right associative") { + // SUBCASE("just series") { + // LeafOnlyBinarySPDecompositionTree input 
= + // make_generic_binary_series_split( + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_sp_leaf(n2)), + // make_generic_binary_sp_leaf(n3)); + // + // bool result = is_binary_sp_tree_right_associative(input); + // bool correct = false; + // + // CHECK(result == correct); + // } + // + // SUBCASE("just parallel") { + // LeafOnlyBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_parallel_split( + // make_generic_binary_sp_leaf(n1), + // make_generic_binary_sp_leaf(n2)), + // make_generic_binary_sp_leaf(n3)); + // + // bool result = is_binary_sp_tree_right_associative(input); + // bool correct = false; + // + // CHECK(result == correct); + // } + // } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index cc234bacf8..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1,131 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("adl_serializer>") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree tt = make_generic_binary_sp_leaf(5); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result 
== correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree tt = - make_generic_binary_series_split(make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5)); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "series"}, - { - "value", - { - {"__type", "GenericBinarySeriesSplit"}, - { - "left_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 2}, - }, - }, - { - "right_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }, - }, - }, - }, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree tt = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5)); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "parallel"}, - { - "value", - { - {"__type", "GenericBinaryParallelSplit"}, - { - "left_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 2}, - }, - }, - { - "right_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }, - }, - }, - }, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - } -} diff 
--git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc index 4ede4e84b5..b9021a19ef 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -1,5 +1,4 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" #include @@ -7,22 +6,23 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split(make_generic_binary_sp_leaf(1), - make_generic_binary_sp_leaf(4)), - make_generic_binary_sp_leaf(2)); - - GenericBinarySPDecompositionTree result = - transform(input, [](int x) { return std::to_string(x); }); - - GenericBinarySPDecompositionTree correct = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(std::string{"1"}), - make_generic_binary_sp_leaf(std::string{"4"})), - make_generic_binary_sp_leaf(std::string{"2"})); - - CHECK(result == correct); + FAIL("TODO"); + // GenericBinarySPDecompositionTree input = + // make_generic_binary_parallel_split( + // make_generic_binary_series_split(make_generic_binary_sp_leaf(1), + // make_generic_binary_sp_leaf(4)), + // make_generic_binary_sp_leaf(2)); + // + // GenericBinarySPDecompositionTree result = + // 
transform(input, [](int x) { return std::to_string(x); }); + // + // GenericBinarySPDecompositionTree correct = + // make_generic_binary_parallel_split( + // make_generic_binary_series_split( + // make_generic_binary_sp_leaf(std::string{"1"}), + // make_generic_binary_sp_leaf(std::string{"4"})), + // make_generic_binary_sp_leaf(std::string{"2"})); + // + // CHECK(result == correct); } } From 2bbec5c6ca5cd221bb3590ddbc54d57d472ecb3c Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 30 Sep 2024 21:55:38 -0700 Subject: [PATCH 18/29] More code cleanup and PR prep --- .../abstracted_tensor_set_movement.h | 5 +- .../machine_mapping/estimate_layer_cost.h | 15 --- ...easible_machine_mapping_result.struct.toml | 7 +- .../get_tensor_set_movement_across_split.h | 6 +- .../machine_mapping/machine_mapping_cache.h | 2 +- .../machine_mapping_constraints.h | 9 +- .../machine_mapping/machine_mapping_context.h | 23 ---- .../machine_mapping_context.struct.toml | 1 + .../machine_mapping_problem_tree.h | 4 + .../machine_mapping_result.struct.toml | 2 +- ...lel_layer_guid_oblivious_machine_mapping.h | 20 +++ ...guid_oblivious_machine_mapping.struct.toml | 21 +++ ...omputation_graph_binary_sp_decomposition.h | 4 - ..._graph_binary_sp_decomposition.struct.toml | 4 +- .../pcg_binary_series_split.struct.toml | 1 - .../abstracted_tensor_set_movement.cc | 17 ++- .../machine_mapping/estimate_layer_cost.cc | 30 ----- .../get_optimal_machine_mapping.cc | 50 +++---- .../get_tensor_set_movement_across_split.cc | 9 +- .../machine_mapping/machine_mapping_cache.cc | 8 +- .../machine_mapping_constraints.cc | 27 ++-- .../machine_mapping_context.cc | 36 ----- .../machine_mapping_problem_tree.cc | 5 + .../machine_mapping/machine_mapping_result.cc | 39 +++--- .../machine_mapping_result_tree.cc | 59 --------- .../mm_result_tree_parallel_split.cc | 13 -- .../mm_result_tree_series_split.cc | 13 -- .../mm_problem_tree_series_split.cc | 18 --- .../mm_problem_tree_split_label.cc | 17 --- 
...el_layer_guid_oblivious_machine_mapping.cc | 24 ++++ ...mputation_graph_binary_sp_decomposition.cc | 38 ++---- .../pcg_binary_series_split.cc | 7 +- lib/compiler/src/unity_algorithm.cc | 102 +++++++------- .../machine_mapping/cost_estimator_for_test.h | 3 +- .../machine_mapping/estimate_layer_cost.cc | 124 ------------------ .../get_optimal_machine_mapping.cc | 94 ++++++------- lib/utils/include/utils/containers/try_at.h | 31 +++++ .../full_binary_tree/get_all_leaf_paths.h | 39 ++++++ .../include/utils/full_binary_tree/visit.h | 6 +- ..._sp_decomposition_tree_visitor.struct.toml | 28 ++++ .../get_all_leaf_paths.h | 16 +++ .../get_subtree_at_path.h | 2 +- .../transform.h | 67 +++++----- .../wrap.h | 24 ++-- ..._sp_decomposition_tree_visitor.struct.toml | 16 +++ .../transform.h | 69 +++++----- .../wrap.h | 8 +- lib/utils/src/utils/containers/try_at.cc | 1 + lib/utils/test/src/utils/containers/try_at.cc | 26 ++++ 49 files changed, 527 insertions(+), 663 deletions(-) delete mode 100644 lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h create mode 100644 lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml delete mode 100644 lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc delete mode 100644 
lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc delete mode 100644 lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc create mode 100644 lib/utils/include/utils/containers/try_at.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml create mode 100644 lib/utils/src/utils/containers/try_at.cc create mode 100644 lib/utils/test/src/utils/containers/try_at.cc diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h index 5917a8fb26..7a32b7a694 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -4,6 +4,7 @@ #include "compiler/cost_estimator/tensor_set_movement.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" #include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" 
namespace FlexFlow { @@ -13,8 +14,8 @@ std::unordered_set get_src_layers(AbstractedTensorSetMovement co std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &); TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &, - MachineMapping const &pre, - MachineMapping const &post); + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h b/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h deleted file mode 100644 index a862f0c476..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/estimate_layer_cost.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_LAYER_COST_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_LAYER_COST_H - -#include "compiler/cost_estimator/cost_estimator.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -namespace FlexFlow { - -float estimate_layer_cost(CostEstimator const &cost_estimator, - UnmappedOpCostEstimateKey const &key, - MachineView const &machine_view); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml index c75c968a90..e71cfc540f 100644 --- a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml @@ -7,8 +7,7 @@ features = [ ] includes = [ - "pcg/machine_view.dtg.h", - "utils/full_binary_tree/binary_tree_path.dtg.h", + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", ] [[fields]] @@ 
-16,5 +15,5 @@ name = "runtime" type = "float" [[fields]] -name = "parallel_layer_guid_oblivious_machine_mapping" -type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h index 9becde61c3..770bfe982d 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H #include "compiler/cost_estimator/tensor_set_movement.dtg.h" -#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" @@ -10,8 +10,8 @@ namespace FlexFlow { TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg, PCGBinarySeriesSplit const &split, - PartialMachineMapping const &pre_mapping, - PartialMachineMapping const &post_mapping); + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index 4e72cc1d76..fc00e7f26a 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -12,7 +12,7 @@ class MachineMappingCache { MachineMappingCache() = default; 
std::optional load(MachineMappingState const &) const; - void save(MachineMappingState const &, std::optional const &); + void save(MachineMappingState const &, MachineMappingResult const &); private: std::unordered_map cache; diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h index f0c81f3ecd..3ab879aed3 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -4,6 +4,7 @@ #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" @@ -17,11 +18,13 @@ std::unordered_set get_all_layers(MachineMappingConstraints cons std::optional get_machine_view_for_layer(MachineMappingConstraints const &, BinaryTreePath const &); -MachineMappingConstraints restrict_domain(MachineMappingConstraints const &, - BinaryTreePathEntry const &); +MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &, BinaryTreePathEntry const &); +MachineMappingConstraints restrict_to_left_child(MachineMappingConstraints const &); +MachineMappingConstraints restrict_to_right_child(MachineMappingConstraints const &); + MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &, - MachineMapping const &); + ParallelLayerGuidObliviousMachineMapping const &); MachineMapping require_fully_constrained(MachineMappingConstraints const &); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h deleted file mode 100644 index 894f935015..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONTEXT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONTEXT_H - -#include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -std::unordered_set get_allowed_machine_views_for_tensor(MachineMappingContext const &, - parallel_tensor_guid_t const &); -std::unordered_set get_allowed_machine_views_for_layer(MachineMappingContext const &, - parallel_layer_guid_t const &); - -MachineMappingContext make_machine_mapping_context(ParallelComputationGraph const &pcg, - CostEstimator const &cost_estimator, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index c4bf1d1ac8..81e26f491d 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -6,6 +6,7 @@ includes = [ "compiler/cost_estimator/cost_estimator.h", "pcg/machine_view.dtg.h", "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", ] [[fields]] diff --git 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 20e3a11399..13a7358a6e 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -24,7 +24,11 @@ MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &) MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &); UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &); +MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &); +MachineMappingProblemTree wrap_parallel_split(MMProblemTreeParallelSplit const &); + std::unordered_multiset get_leaves(MachineMappingProblemTree const &); +std::unordered_set get_all_leaf_paths(MachineMappingProblemTree const &); std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, BinaryTreePath const &); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml index 28b124cea3..92a2873af5 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -17,4 +17,4 @@ src_includes = [ [[fields]] name = "raw_result" -type = "std::optional" +type = "std::optional<::FlexFlow::FeasibleMachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h new file mode 100644 index 0000000000..23c589a261 --- /dev/null +++ 
b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping + binary_combine_mappings(ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +ParallelLayerGuidObliviousMachineMapping restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &); +ParallelLayerGuidObliviousMachineMapping restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &); + +std::optional get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml new file mode 100644 index 0000000000..f00fcc8490 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ParallelLayerGuidObliviousMachineMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_mapping" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h 
b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h index 3032e3efe9..b855fbff07 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h @@ -9,10 +9,6 @@ namespace FlexFlow { SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &); -ComputationGraphBinarySPDecomposition - get_left_child(ComputationGraphBinarySPDecomposition const &); -ComputationGraphBinarySPDecomposition - get_right_child(ComputationGraphBinarySPDecomposition const &); layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); std::optional get_computation_graph_left_assoc_binary_sp_decomposition( diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml index 98d0fc5faf..2e6bb0b611 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml @@ -8,9 +8,9 @@ features = [ includes = [ "pcg/layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition/leaf_only_binary_sp_decomposition.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::LeafOnlyBinarySPDecomposition<::FlexFlow::layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml index 
48e19022c9..184b272c55 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml @@ -2,7 +2,6 @@ namespace = "FlexFlow" name = "PCGBinarySeriesSplit" features = [ "eq", - "ord", "hash", "fmt", ] diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc index 69242e4076..8fc7239a45 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -1,4 +1,5 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "utils/containers/flatmap.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/transform.h" @@ -24,18 +25,22 @@ std::unordered_set get_dst_layers(AbstractedTensorSetMovement co } TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &abstracted, - PartialMachineMapping const &pre_mapping, - PartialMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping mapping = + binary_combine_mappings(/*lhs=*/pre_mapping, + /*rhs=*/post_mapping); + auto concretize_tensor_movement = [&](AbstractedSingleTensorMovement const &a) { return SingleTensorMovement{ /*parallel_tensor_shape=*/a.parallel_tensor_shape, /*src_machine_views=*/transform(a.src_machine_views, - [&](parallel_layer_guid_t const &layer) { - return get_machine_view_for_layer(pre_mapping, 
layer).value(); + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(pre_mapping, path).value(); }), /*dst_machine_views=*/transform(a.dst_machine_views, - [&](parallel_layer_guid_t const &layer) { - return get_machine_view_for_layer(post_mapping, layer).value(); + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(post_mapping, path).value(); }), }; }; diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc deleted file mode 100644 index 2df6ddb859..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc +++ /dev/null @@ -1,30 +0,0 @@ -#include "compiler/machine_mapping/estimate_layer_cost.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" - -namespace FlexFlow { - -float estimate_layer_cost(CostEstimator const &cost_estimator, - UnmappedOpCostEstimateKey const &key, - MachineView const &machine_view) { - PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, layer).op_attrs; - - auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { - return get_parallel_tensor_shape(pcg, t); - }; - - std::vector input_tensors = get_incoming_inputs(pcg, layer); - std::vector weight_tensors = get_incoming_weights(pcg, layer); - std::vector output_tensors = get_layer_outputs(pcg, layer); - - OpCostEstimateKey key = OpCostEstimateKey{ - /*op_attrs=*/op_attrs, - /*input_shapes=*/transform(input_tensors, get_tensor_shape), - /*weight_shapes=*/transform(weight_tensors, get_tensor_shape), - /*output_shapes=*/transform(output_tensors, get_tensor_shape), - /*machine_view=*/machine_view, - }; - - return cost_estimator.estimate_cost(key); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 4d4fe8a7de..bf3b2d9f9b 100644 --- 
a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,20 +1,12 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" -#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" -#include "compiler/machine_mapping/get_allowed_machine_views_list.h" #include "compiler/machine_mapping/get_machine_resource_splits.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/machine_mapping_result.h" -#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h" -#include "compiler/machine_mapping/mm_problem_tree_series_split.h" -#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" -#include "compiler/machine_mapping/partial_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" @@ -29,12 +21,7 @@ #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" #include "utils/overload.h" -#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_parallel_split.h" -#include 
"compiler/series_parallel/pcg_binary_series_split.h" -#include "compiler/machine_mapping/machine_mapping_context.h" #include "utils/containers/flatmap.h" -#include "compiler/machine_mapping/estimate_layer_cost.h" namespace FlexFlow { @@ -77,34 +64,37 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resource, - MachineMappingConstraints const &partial_solution, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, std::optional const &parallel_split_transformation) { MachineMappingResult result = infeasible_machine_mapping_result(); AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); - auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) - -> std::unordered_set + auto get_boundary_machine_view_assignments = [&](std::unordered_set const &boundary_layers) + -> std::unordered_set { std::unordered_map> - allowed = generate_map(layers, - [&](BinaryTreePath const &l) { - return get_allowed_machine_views_for_layer(context, l); + allowed = generate_map(boundary_layers, + [&](BinaryTreePath const &l) -> std::unordered_set { + UnmappedOpCostEstimateKey leaf = require_leaf + (mm_problem_tree_get_subtree_at_path + (wrap_series_split(series_split), l).value()); + return context.allowed_machine_views(leaf, resources); }); return transform(get_all_assignments(allowed), - [](std::unordered_map const &m) { - return MachineMapping{m}; + [](std::unordered_map const &m) { + return ParallelLayerGuidObliviousMachineMapping{m}; }); }; - for (MachineMapping const &assigned_pre_machine_views + for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { MachineMappingConstraints pre_candidate = with_additional_constraints( - 
restrict_domain(partial_solution, BinaryTreePathEntry::LEFT_CHILD), + restrict_to_left_child(constraints), assigned_pre_machine_views); MachineMappingResult pre_result = ({ @@ -112,7 +102,7 @@ MachineMappingResult get_optimal_machine_mapping( = get_optimal_machine_mapping(result_cache, context, get_pre_child(series_split), - resource, + resources, pre_candidate); if (!returned.has_value()) { continue; @@ -120,12 +110,12 @@ MachineMappingResult get_optimal_machine_mapping( returned.value(); }); - for (MachineMapping const &assigned_post_machine_views + for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views : get_boundary_machine_view_assignments(get_dst_layers(tensor_movement))) { MachineMappingConstraints post_candidate = with_additional_constraints( - restrict_domain(partial_solution, BinaryTreePathEntry::RIGHT_CHILD), + restrict_to_right_child(constraints), assigned_post_machine_views); MachineMappingResult post_result = ({ @@ -133,7 +123,7 @@ MachineMappingResult get_optimal_machine_mapping( = get_optimal_machine_mapping(result_cache, context, get_post_child(series_split), - resource, + resources, post_candidate); if (!returned.has_value()) { continue; @@ -181,8 +171,8 @@ MachineMappingResult get_optimal_machine_mapping( ParallelSplitTransformation::LthenR); }(); - MachineMappingConstraints left_constraints = restrict_domain(constraints, get_leaves(lhs)); - MachineMappingConstraints right_constraints = restrict_domain(constraints, get_leaves(rhs)); + MachineMappingConstraints left_constraints = restrict_to_left_child(constraints); + MachineMappingConstraints right_constraints = restrict_to_right_child(constraints); for (auto const &resource_split : get_machine_resource_splits(resources)) { MachineMappingResult left_result = diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index f237fba88f..2979947c7c 100644 
--- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,7 +1,6 @@ #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" -#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" -#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" -#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" @@ -15,8 +14,8 @@ namespace FlexFlow { TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split, - PartialMachineMapping const &pre_mapping, - PartialMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { AbstractedTensorSetMovement abstracted = get_abstracted_tensor_set_movement_across_split(tr_pcg, split); return concretize_abstracted_tensor_set_movement(abstracted, pre_mapping, post_mapping); } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc index b2b3fbc8f5..76d7fea7ff 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -1,15 +1,11 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" -#include "utils/containers/contains_key.h" +#include "utils/containers/try_at.h" namespace 
FlexFlow { std::optional MachineMappingCache::load(MachineMappingState const &state) const { - if (contains_key(cache, state)) { - MachineMappingResult result = cache.at(state); - return result; - } - return std::nullopt; + return try_at(this->cache, state); } void MachineMappingCache::save(MachineMappingState const &state, diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 720522ac0c..206bde2e78 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -36,8 +36,8 @@ std::optional get_machine_view_for_layer(MachineMappingConstraints return partial_solution.machine_views.at(layer); } -MachineMappingConstraints restrict_domain(MachineMappingConstraints const &constraints, - BinaryTreePathEntry const &prefix) { +MachineMappingConstraints restrict_to_left_child(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { return MachineMappingConstraints{ filtermap_keys(constraints.machine_views, [&](BinaryTreePath const &path) -> std::optional { @@ -53,11 +53,19 @@ MachineMappingConstraints restrict_domain(MachineMappingConstraints const &const }; } -MachineMappingConstraints with_additional_layer_machine_views(MachineMappingConstraints const &partial_solution, - std::unordered_map const &additional) { - MachineMappingConstraints result = partial_solution; +MachineMappingConstraints restrict_to_left_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::LEFT_CHILD); +} + +MachineMappingConstraints restrict_to_right_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::RIGHT_CHILD); +} + +MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &constraints, + ParallelLayerGuidObliviousMachineMapping const 
&additional) { + MachineMappingConstraints result = constraints; - for (auto const &[layer, machine_view] : additional) { + for (auto const &[layer, machine_view] : additional.raw_mapping) { std::optional current_machine_view = result.machine_views.at(layer); if (!current_machine_view.has_value()) { @@ -75,11 +83,4 @@ MachineMappingConstraints with_additional_layer_machine_views(MachineMappingCons } -MachineMapping require_complete_mapping(MachineMappingConstraints const &partial_mapping) { - return MachineMapping{ - map_values(partial_mapping.machine_views, - [](std::optional const &mv) { return mv.value(); }), - }; -} - } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc deleted file mode 100644 index c45e964a3a..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_context.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_context.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "utils/containers/keys.h" -#include "utils/containers/sum.h" - -namespace FlexFlow { - -std::unordered_set get_allowed_machine_views_for_tensor(MachineMappingContext const &, - parallel_tensor_guid_t const &) { - NOT_IMPLEMENTED(); -} - -std::unordered_set get_allowed_machine_views_for_layer(MachineMappingContext const &, - parallel_layer_guid_t const &) { - NOT_IMPLEMENTED(); -} - -MachineMappingContext make_machine_mapping_context(ParallelComputationGraph const &pcg, - CostEstimator const &cost_estimator, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views) { - NOT_IMPLEMENTED(); -} - -std::unordered_set get_transitively_reduced_predecessors(MachineMappingContext const &ctx, - parallel_layer_guid_t const &l) { - NOT_IMPLEMENTED(); -} - 
-std::unordered_set get_transitively_reduced_successors(MachineMappingContext const &ctx, - parallel_layer_guid_t const &l) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 9d29f573c3..e3f5e582ea 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -4,6 +4,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" namespace FlexFlow { @@ -62,6 +63,10 @@ std::unordered_multiset get_leaves(MachineMappingProb return get_leaves(t.raw_tree); } +std::unordered_set get_all_leaf_paths(MachineMappingProblemTree const &t) { + return get_all_leaf_paths(t.raw_tree); +} + std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, BinaryTreePath const &path) { std::optional transform_problem_tree_paths_pre - = nest_inside_left_child; - std::function transform_problem_tree_paths_post - = nest_inside_right_child; - - if (parallel_split_transformation.has_value() - && parallel_split_transformation.value() == ParallelSplitTransformation::RthenL) { - transform_problem_tree_paths_pre = nest_inside_right_child; - transform_problem_tree_paths_post = nest_inside_left_child; - } + ParallelLayerGuidObliviousMachineMapping mapping = [&] { + if 
(parallel_split_transformation.has_value() + && parallel_split_transformation.value() == ParallelSplitTransformation::RthenL) { + return binary_combine_mappings(/*lhs=*/post_result.machine_mapping, + /*rhs=*/pre_result.machine_mapping); + } else { + return binary_combine_mappings(/*lhs=*/pre_result.machine_mapping, + /*rhs=*/post_result.machine_mapping); + } + }(); return MachineMappingResult{ FeasibleMachineMappingResult{ /*runtime=*/pre_result.runtime + comm_cost + post_result.runtime, - /*parallel_layer_guid_oblivious_machine_mapping=*/merge_maps( - map_keys(pre_result.parallel_layer_guid_oblivious_machine_mapping, - transform_problem_tree_paths_pre), - map_keys(post_result.parallel_layer_guid_oblivious_machine_mapping, - transform_problem_tree_paths_post)), + /*machine_mapping=*/mapping, }, }; } @@ -66,11 +63,9 @@ MachineMappingResult parallel_combine(MachineMappingResult const &maybe_lhs_resu return MachineMappingResult{ FeasibleMachineMappingResult{ /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), - /*parallel_layer_guid_oblivious_machine_mapping=*/merge_maps( - map_keys(lhs_result.parallel_layer_guid_oblivious_machine_mapping, - nest_inside_left_child), - map_keys(rhs_result.parallel_layer_guid_oblivious_machine_mapping, - nest_inside_right_child)), + /*machine_mapping=*/binary_combine_mappings + (/*lhs=*/lhs_result.machine_mapping, + /*rhs=*/rhs_result.machine_mapping), }, }; } @@ -107,9 +102,9 @@ MachineMappingResult make_singleton_machine_mapping_result(float runtime, return MachineMappingResult{ FeasibleMachineMappingResult{ /*runtime=*/runtime, - /*parallel_layer_guid_oblivious_machine_mapping=*/{ + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ {binary_tree_root_path(), machine_view}, - }, + }}, }, }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc deleted 
file mode 100644 index 2e61ca2ca2..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.cc +++ /dev/null @@ -1,59 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h" -#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.h" -#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" - -namespace FlexFlow { - -SPDecompositionTreeNodeType get_node_type(MachineMappingResultTree const &t) { - return get_node_type(t.raw_tree); -} - -float get_mm_result_tree_cost(MachineMappingResultTree const &t) { - return visit( - t, - overload { - [](MMResultTreeSeriesSplit const &series) { - return get_cost(series); - }, - [](MMResultTreeParallelSplit const &parallel) { - return get_cost(parallel); - }, - [](MMResultTreeLeafLabel const &leaf) { - return leaf.cost; - }, - }); -} - -MachineMappingResultTree make_series_split(float comm_cost, - BinaryTreePathEntry problem_tree_path_entry, - MachineMappingResultTree const &pre, - MachineMappingResultTree const &post) { - MMResultTreeSeriesSplitLabel label = MMResultTreeSeriesSplitLabel{ - /*cost=*/get_mm_result_tree_cost(pre) + comm_cost + get_mm_result_tree_cost(post), - /*problem_tree_path_entry=*/problem_tree_path_entry, - }; - - return MachineMappingResultTree{ - make_generic_binary_series_split(label, pre.raw_tree, post.raw_tree), - }; -} - -MachineMappingResultTree make_parallel_split(MachineMappingResultTree const &lhs, - MachineMappingResultTree const &rhs) { - MMResultTreeParallelSplitLabel label = MMResultTreeParallelSplitLabel{ - /*cost=*/std::max(get_mm_result_tree_cost(lhs), 
get_mm_result_tree_cost(rhs)), - 
/*problem_tree_path_entry=*/problem_tree_path_entry, - }; - - return MachineMappingResultTree{ - make_generic_binary_series_split(label, pre.raw_tree, post.raw_tree), - }; -} - -MachineMappingResultTree make_leaf_node(float cost, MachineView const &) { - -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc deleted file mode 100644 index bf237f1aaa..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.h" - -namespace FlexFlow { - -float get_cost(MMResultTreeParallelSplit const &p) { - return p.raw_split.label.cost; -} - -BinaryTreePathEntry get_problem_tree_path_entry(MMResultTreeParallelSplit const &p) { - return p.raw_split.label.problem_tree_path_entry; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc deleted file mode 100644 index 4e78787a3f..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.h" - -namespace FlexFlow { - -float get_cost(MMResultTreeSeriesSplit const &s) { - return s.raw_split.label.cost; -} - -BinaryTreePathEntry get_problem_tree_path_entry(MMResultTreeSeriesSplit const &s) { - return s.raw_split.label.problem_tree_path_entry; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc 
b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc deleted file mode 100644 index 28c6137440..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "compiler/machine_mapping/mm_problem_tree_series_split.h" -#include "compiler/machine_mapping/full_binary_tree/require.h" - -namespace FlexFlow { - -MachineMappingProblemTree const &get_left_child(MMProblemTreeSeriesSplit const &s) { - FullBinaryTree< require_parent(s.problem_tree.raw_tree); -} - -MachineMappingProblemTree const &get_right_child(MMProblemTreeSeriesSplit const &) { - -} - -AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &) { - -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc deleted file mode 100644 index 54b7a4eaf8..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "compiler/machine_mapping/mm_problem_tree_split_label.h" -#include "utils/overload.h" - -namespace FlexFlow { - -SPDecompositionTreeNodeType split_label_get_node_type(MMProblemTreeSplitLabel const &l) { - return l.visit(overload { - [](MMProblemTreeSeriesSplitLabel const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](MMProblemTreeParallelSplitLabel const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - }); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc new file mode 100644 index 0000000000..27922d62dc --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -0,0 +1,24 @@ +#include 
"compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/try_at.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping + binary_combine_mappings(ParallelLayerGuidObliviousMachineMapping const &lhs, + ParallelLayerGuidObliviousMachineMapping const &rhs) { + return ParallelLayerGuidObliviousMachineMapping{ + merge_maps( + map_keys(lhs.raw_mapping, nest_inside_left_child), + map_keys(rhs.raw_mapping, nest_inside_right_child)), + }; +} + +std::optional get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &mapping, + BinaryTreePath const &path) { + return try_at(mapping.raw_mapping, path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc index 63d1231ae7..00d0d74959 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -1,13 +1,11 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" @@ -18,20 +16,6 @@ SPDecompositionTreeNodeType return get_node_type(d.raw_tree); } -ComputationGraphBinarySPDecomposition - get_left_child(ComputationGraphBinarySPDecomposition const &d) { - return ComputationGraphBinarySPDecomposition{ - get_left_child(d.raw_tree), - }; -} - -ComputationGraphBinarySPDecomposition - get_right_child(ComputationGraphBinarySPDecomposition const &d) { - return ComputationGraphBinarySPDecomposition{ - get_right_child(d.raw_tree), - }; -} - layer_guid_t 
require_node(ComputationGraphBinarySPDecomposition const &d) { return require_leaf(d.raw_tree); } @@ -51,8 +35,11 @@ std::optional BinarySPDecompositionTree raw_binary_tree = left_associative_binary_sp_tree_from_nary(sp_decomposition); + auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ + [](Node const &n) { return layer_guid_t{n}; }, + }; return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; + raw_binary_tree.raw_tree, visitor)}; } std::optional @@ -70,8 +57,11 @@ std::optional BinarySPDecompositionTree raw_binary_tree = right_associative_binary_sp_tree_from_nary(sp_decomposition); + auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ + [](Node const &n) { return layer_guid_t{n}; }, + }; return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; + raw_binary_tree.raw_tree, visitor)}; } bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc index 31a90533ff..0b972706d1 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc @@ -5,9 +5,12 @@ namespace FlexFlow { BinarySeriesSplit get_raw_graph_series_split(PCGBinarySeriesSplit const &s) { + auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; } + }; + return BinarySeriesSplit{ - transform(s.raw_split, - [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), + transform(s.raw_split, visitor), }; } diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index caf072fdbc..263450117f 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ 
-32,58 +32,60 @@ GraphOptimizeResult graph_optimize( ParallelLayerAttrs const &, MachineSpecification const &)> const &allowed_machine_views, OptimizerConfig const &opt_config) { - std::vector substitutions = - get_all_applicable_substitutions(pcg); - - MachineMappingCache cached_subgraph_costs; - DeduplicatedPriorityQueue candidates; - - MachineMappingResult original_pcg_cost = - get_optimal_machine_mapping(pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs); - - GraphOptimizeState initial_state = { - GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), - original_pcg_cost.runtime}; - - GraphOptimizeState best_state = initial_state; - candidates.push(initial_state); - - for (int iteration = 0; !candidates.empty() && iteration < opt_config.budget; - ++iteration) { - GraphOptimizeState current_state = candidates.top(); - candidates.pop(); - - if (current_state.runtime < best_state.runtime) { - best_state = current_state; - } else if (current_state.runtime > best_state.runtime * opt_config.alpha) { - continue; - } + NOT_IMPLEMENTED(); - for (Substitution const &substitution : substitutions) { - for (ParallelComputationGraph const &new_pcg : apply_substitution( - current_state.graph_optimize_result.pcg, substitution)) { - MachineMappingResult new_pcg_cost = - get_optimal_machine_mapping(new_pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs); - GraphOptimizeState new_state{ - GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), - new_pcg_cost.runtime}; - if (new_pcg_cost.runtime <= opt_config.threshold && - get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - candidates.push(new_state); - } - } - } - } + // std::vector substitutions = + // get_all_applicable_substitutions(pcg); + // + // MachineMappingCache cached_subgraph_costs; + // DeduplicatedPriorityQueue candidates; + // + // MachineMappingResult original_pcg_cost = + // get_optimal_machine_mapping(pcg, + // 
allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // + // GraphOptimizeState initial_state = { + // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), + // original_pcg_cost.runtime}; + // + // GraphOptimizeState best_state = initial_state; + // candidates.push(initial_state); + // + // for (int iteration = 0; !candidates.empty() && iteration < opt_config.budget; + // ++iteration) { + // GraphOptimizeState current_state = candidates.top(); + // candidates.pop(); + // + // if (current_state.runtime < best_state.runtime) { + // best_state = current_state; + // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) { + // continue; + // } + // + // for (Substitution const &substitution : substitutions) { + // for (ParallelComputationGraph const &new_pcg : apply_substitution( + // current_state.graph_optimize_result.pcg, substitution)) { + // MachineMappingResult new_pcg_cost = + // get_optimal_machine_mapping(new_pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // GraphOptimizeState new_state{ + // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), + // new_pcg_cost.runtime}; + // if (new_pcg_cost.runtime <= opt_config.threshold && + // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { + // candidates.push(new_state); + // } + // } + // } + // } - return best_state.graph_optimize_result; + // return best_state.graph_optimize_result; } } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index 2fa9e6028f..c8e2624c54 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -4,6 +4,7 @@ #include "compiler/cost_estimator/cost_estimator.h" #include 
"compiler/cost_estimator/op_cost_estimate_key.dtg.h" #include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" namespace FlexFlow { @@ -25,7 +26,7 @@ CostEstimator make_fake_cost_estimator( std::function const &get_communication_cost); CostEstimator make_fake_cost_estimator( - std::unordered_map const &op_cost_map, + std::unordered_map> const &op_cost_map, std::unordered_map const &comm_cost_map); } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc deleted file mode 100644 index cd72f74c61..0000000000 --- a/lib/compiler/test/src/compiler/machine_mapping/estimate_layer_cost.cc +++ /dev/null @@ -1,124 +0,0 @@ -#include "compiler/machine_mapping/estimate_layer_cost.h" -#include "./cost_estimator_for_test.h" -#include "op-attrs/ops/linear.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers/get_only.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("estimate_layer_cost") { - ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{8, 2}, - ShardParallelDim{10, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, - }; - - LinearAttrs linear_attrs = LinearAttrs{ - /*out_channels=*/12, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*activation=*/std::nullopt, - /*regularizer=*/std::nullopt, - }; - - ParallelTensorShape projection_shape = throw_if_unexpected(get_projection_shape(linear_attrs, input_shape)); - ParallelTensorShape bias_shape = throw_if_unexpected(get_bias_shape(linear_attrs, input_shape)); - ParallelTensorShape output_shape = 
throw_if_unexpected(get_output_shape(linear_attrs, input_shape)); - - auto make_tensor_attrs = [](ParallelTensorShape const &shape) { - return ParallelTensorAttrs{ - /*shape=*/shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_grad=*/CreateGrad::YES, - }; - }; - - auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { - return ParallelLayerAttrs{ - /*op_attrs=*/op_attrs, - /*name=*/std::nullopt, - }; - }; - - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelLayerAddedResult input = add_parallel_layer(pcg, - /*layer_attrs=*/make_layer_attrs(PCGOperatorAttrs{InputAttrs{}}), - /*inputs=*/{}, - /*output_labels=*/{make_tensor_attrs(input_shape)}); - parallel_tensor_guid_t input_tensor = get_only(input.outputs); - - ParallelLayerAddedResult projection = add_parallel_layer(pcg, - /*layer_attrs=*/make_layer_attrs( - PCGOperatorAttrs{ - WeightAttrs{ - /*tensor_shape=*/get_reduced_shape(projection_shape), - }, - }), - /*inputs=*/{}, - /*output_labels=*/{make_tensor_attrs(projection_shape)}); - parallel_tensor_guid_t projection_tensor = get_only(projection.outputs); - - ParallelLayerAddedResult bias = add_parallel_layer(pcg, - /*layer_attrs=*/make_layer_attrs( - PCGOperatorAttrs{ - WeightAttrs{ - /*tensor_shape=*/get_reduced_shape(bias_shape), - }, - }), - /*inputs=*/{}, - /*output_labels=*/{make_tensor_attrs(bias_shape)}); - parallel_tensor_guid_t bias_tensor = get_only(bias.outputs); - - ParallelLayerAddedResult linear = add_parallel_layer(pcg, - /*layer_attrs=*/make_layer_attrs(PCGOperatorAttrs{linear_attrs}), - /*inputs=*/{ - get_only(input.outputs), - get_only(projection.outputs), - get_only(bias.outputs), - }, - /*output_labels=*/{make_tensor_attrs(output_shape)}); - parallel_tensor_guid_t linear_output = get_only(linear.outputs); - - MachineView machine_view = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); - - - CostEstimator cost_estimator = make_fake_cost_estimator( - { - { - OpCostEstimateKey{ - 
/*op_attrs=*/PCGOperatorAttrs{linear_attrs}, - /*input_shapes=*/{input_shape}, - /*weight_shapes=*/{projection_shape, bias_shape}, - /*output_shapes=*/{output_shape}, - /*machine_view=*/machine_view, - }, - 2.0, - }, - }, - {} - ); - - SUBCASE("returns just the layer cost if the layer exists") { - float result = estimate_layer_cost(pcg, - cost_estimator, - linear.parallel_layer, - machine_view); - float correct = 2.0; - - CHECK(result == correct); - } - } -} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 02b3fe4a03..91ea973245 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,18 +1,23 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "./cost_estimator_for_test.h" #include +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping_internal") { - auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, + MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); + + auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, MachineSpecification const &) { - return std::unordered_set{ - make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + return std::unordered_set{mv1}; }; MachineSpecification machine_spec = MachineSpecification{ @@ -23,64 +28,49 @@ TEST_SUITE(FF_TEST_SUITE) { /*intra_node_bandwidth=*/1, }; - CostEstimator cost_estimator = make_fake_cost_estimator( - 
std::unordered_map{}, - std::unordered_map{}); - - SUBCASE("single layer") { - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - MachineView mv1 = make_1d_machine_view(gpu_id_t{1}, gpu_id_t{2}); - - auto allowed_machine_views = [&](ParallelLayerAttrs const &, - MachineSpecification const &) { - return std::unordered_set{mv1}; - }; + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; - ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ - PCGOperatorAttrs{ - InputAttrs{}, + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map>{ + { + k1, + { + {mv1, 1.0}, + } }, - std::nullopt, - }; + }, + std::unordered_map{}); - ParallelTensorAttrs output_tensor_attrs = ParallelTensorAttrs{ - ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, - }, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::YES, - }; + MachineMappingContext context = MachineMappingContext{ + cost_estimator, + allowed_machine_views1, + }; - ParallelLayerAddedResult added = add_parallel_layer(pcg, - layer_attrs, - {}, - {output_tensor_attrs}); - parallel_layer_guid_t layer = added.parallel_layer; - parallel_tensor_guid_t output_tensor = get_only(added.outputs); + SUBCASE("single layer") { + MachineMappingProblemTree problem_tree = + mm_problem_tree_make_leaf(k1); MachineMappingCache cache; - MachineMappingResult result = get_optimal_machine_mapping(pcg, - allowed_machine_views, - cost_estimator, + MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping(cache, + context, + problem_tree, machine_spec, - cache); + constraints); MachineMappingResult 
correct = MachineMappingResult{ - /*runtime=*/2.0, - /*machine_mapping=*/MachineMapping{{ - {layer, mv1}, - }}, + FeasibleMachineMappingResult{ + /*runtime=*/1.0, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}, + }, }; CHECK(result == correct); diff --git a/lib/utils/include/utils/containers/try_at.h b/lib/utils/include/utils/containers/try_at.h new file mode 100644 index 0000000000..a274c134f7 --- /dev/null +++ b/lib/utils/include/utils/containers/try_at.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H + +#include +#include +#include "utils/containers/contains_key.h" +#include + +namespace FlexFlow { + +template +std::optional try_at(std::unordered_map const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +template +std::optional try_at(std::map const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..926cc0ea9c --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include +#include "utils/overload.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths(FullBinaryTree const 
&tree) { + return visit> + (tree, + overload { + [](LeafLabel const &) { + return std::unordered_set{binary_tree_root_path()}; + }, + [](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(get_all_leaf_paths(get_left_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(get_right_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 4a1e615830..978eba4d74 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -9,9 +9,11 @@ namespace FlexFlow { template Result visit(FullBinaryTree const &tt, F f) { if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); + Result result = f(std::get>(tt.root)); + return result; } else if (std::holds_alternative(tt.root)) { - return f(std::get(tt.root)); + Result result = f(std::get(tt.root)); + return result; } else { throw mk_runtime_error( "Unexpected case in visit(FullBinaryTree)"); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml new file mode 100644 index 0000000000..7c491ad49d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeVisitor" +features = [] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", 
+ "SeriesSplitLabel2", + "ParallelSplitLabel2", + "LeafLabel2", +] + +includes = [ + "", +] + +[[fields]] +name = "series_split_func" +type = "std::function" + +[[fields]] +name = "parallel_split_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..0bb0e08eae --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths(GenericBinarySPDecompositionTree const &tree) { + return get_all_leaf_paths(tree.raw_tree); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h index e5b3b65ccd..fe308ec762 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -12,7 +12,7 @@ template > get_subtree_at_path(GenericBinarySPDecompositionTree const &tree, BinaryTreePath const &path) { - std::optional, LeafLabel>> raw_subtree = get_subtree_at_path(tree.raw_tree, path); + std::optional, LeafLabel>> raw_subtree = get_subtree_at_path(tree.raw_tree, path); if (!raw_subtree.has_value()) { return std::nullopt; diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h index c557711a3b..96c3cd5de8 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -2,9 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" #include "utils/overload.h" namespace FlexFlow { @@ -12,61 +15,65 @@ namespace FlexFlow { template , - typename ParallelLabel2 = std::invoke_result_t, - typename LeafLabel2 = std::invoke_result_t> + typename SeriesLabel2, + typename ParallelLabel2, + typename LeafLabel2> +GenericBinarySPDecompositionTree + transform(GenericBinarySPDecompositionTree const &tt, + GenericBinarySPDecompositionTreeVisitor const &visitor); + +template GenericBinarySeriesSplit - transform(GenericBinarySeriesSplit const &s, F f) { + transform(GenericBinarySeriesSplit const &s, + GenericBinarySPDecompositionTreeVisitor const &visitor) { return GenericBinarySeriesSplit{ - f(s.label), - transform(get_left_child(s), f), - transform(get_right_child(s), f), + visitor.series_split_func(s.label), + transform(get_left_child(s), visitor), + transform(get_right_child(s), visitor), }; }; template , - typename ParallelLabel2 = std::invoke_result_t, - typename LeafLabel2 = std::invoke_result_t> + typename SeriesLabel2, + typename ParallelLabel2, + typename LeafLabel2> GenericBinaryParallelSplit - transform(GenericBinaryParallelSplit const &s, F f) { + transform(GenericBinaryParallelSplit const &s, + GenericBinarySPDecompositionTreeVisitor const &visitor) { return GenericBinaryParallelSplit{ - f(s.label), - transform(get_left_child(s), f), - transform(get_right_child(s), f), + visitor.parallel_split_func(s.label), + transform(get_left_child(s), visitor), + transform(get_right_child(s), visitor), }; }; template , - typename ParallelLabel2 = std::invoke_result_t, - typename LeafLabel2 = std::invoke_result_t> + typename SeriesLabel2, + typename ParallelLabel2, + typename LeafLabel2> GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, F f) { + transform(GenericBinarySPDecompositionTree const &tt, + GenericBinarySPDecompositionTreeVisitor const &visitor) { return visit>( 
tt, overload{ [&](GenericBinarySeriesSplit const &s) { - return GenericBinarySPDecompositionTree{ - transform(s, f), - }; + return wrap_series_split(transform(s, visitor)); }, [&](GenericBinaryParallelSplit const &s) { - return GenericBinarySPDecompositionTree{ - transform(s, f), - }; + return wrap_parallel_split(transform(s, visitor)); }, [&](LeafLabel const &t) { - return GenericBinarySPDecompositionTree{ - f(t), - }; + return make_generic_binary_sp_leaf(visitor.leaf_func(t)); }, }); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h index 509c20ba23..ec0a45f83a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h @@ -10,11 +10,13 @@ namespace FlexFlow { template GenericBinarySPDecompositionTree wrap_series_split(GenericBinarySeriesSplit const &series_split) { - return FullBinaryTree, LeafLabel> { - FullBinaryTreeParentNode, LeafLabel> { - /*label=*/series_split.label, - /*lhs=*/series_split.pre.raw_tree, - /*rhs=*/series_split.post.raw_tree, + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel> { + FullBinaryTreeParentNode, LeafLabel> { + /*label=*/GenericBinarySPSplitLabel{series_split.label}, + /*lhs=*/series_split.pre.raw_tree, + /*rhs=*/series_split.post.raw_tree, + }, }, }; } @@ -22,11 +24,13 @@ GenericBinarySPDecompositionTree template GenericBinarySPDecompositionTree wrap_parallel_split(GenericBinaryParallelSplit const &parallel_split) { - return FullBinaryTree, LeafLabel> { - FullBinaryTreeParentNode, LeafLabel> { - /*label=*/parallel_split.label, - /*lhs=*/parallel_split.lhs.raw_tree, - /*rhs=*/parallel_split.rhs.raw_tree, + return
GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel> { + FullBinaryTreeParentNode, LeafLabel> { + /*label=*/GenericBinarySPSplitLabel{parallel_split.label}, + /*lhs=*/parallel_split.lhs.raw_tree, + /*rhs=*/parallel_split.rhs.raw_tree, + }, }, }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml new file mode 100644 index 0000000000..27203b8b05 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySPDecompositionTreeVisitor" +features = [] + +template_params = [ + "LeafLabel", + "LeafLabel2", +] + +includes = [ + "", +] + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h index 364a3200b1..b4f4239d39 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -3,57 +3,54 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" namespace FlexFlow { -template > -LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &t, F &&f) { - auto ff = overload { - [&](T const &t) { - return f(t); - }, - [&](auto const &x) { - return x; - }, - }; - - return LeafOnlyBinarySeriesSplit{ - transform(t.pre, f), - transform(t.post, f), +template +LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &t, + LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { + return LeafOnlyBinarySeriesSplit{ + transform(t.pre, visitor), + transform(t.post, visitor), }; } -template > -LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &t, F &&f) { - auto ff = overload { - [&](T const &t) { - return f(t); - }, - [&](auto const &x) { - return x; - }, - }; - - return LeafOnlyBinaryParallelSplit{ - transform(t.lhs, f), - transform(t.rhs, f), +template +LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &t, + LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { + return LeafOnlyBinaryParallelSplit{ + transform(t.lhs, visitor), + transform(t.rhs, visitor), }; } -template > -LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, F &&f) { - auto ff = overload { - [&](T const &t) { - return f(t); +template +LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, + LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { + using GenericVisitor = GenericBinarySPDecompositionTreeVisitor + ; + + GenericVisitor generic_visitor = GenericVisitor{ + [&](LeafOnlyBinarySeriesSplitLabel const &x) { + return x; }, - [&](auto const &x) { + [&](LeafOnlyBinaryParallelSplitLabel const &x) { return x; }, + [&](LeafLabel const &t) { + return 
visitor.leaf_func(t); + }, }; - return LeafOnlyBinarySPDecompositionTree{ - transform(t.raw_tree, ff), + return LeafOnlyBinarySPDecompositionTree{ + transform(t.raw_tree, generic_visitor), }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h index 4a86bc8d49..0284f6ba41 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h @@ -18,8 +18,8 @@ LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySer LeafOnlyBinaryParallelSplitLabel, LeafLabel>{ LeafOnlyBinarySeriesSplitLabel{}, - split.pre, - split.post, + split.pre.raw_tree, + split.post.raw_tree, } ), }; @@ -34,8 +34,8 @@ LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryP LeafOnlyBinaryParallelSplitLabel, LeafLabel>{ LeafOnlyBinaryParallelSplitLabel{}, - split.lhs, - split.rhs, + split.lhs.raw_tree, + split.rhs.raw_tree, } ), }; diff --git a/lib/utils/src/utils/containers/try_at.cc b/lib/utils/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..0d1ed3b04a --- /dev/null +++ b/lib/utils/src/utils/containers/try_at.cc @@ -0,0 +1 @@ +#include "utils/containers/try_at.h" diff --git a/lib/utils/test/src/utils/containers/try_at.cc b/lib/utils/test/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..818456f65a --- /dev/null +++ b/lib/utils/test/src/utils/containers/try_at.cc @@ -0,0 +1,26 @@ +#include "utils/containers/try_at.h" +#include +#include +#include "test/utils/doctest/fmt/optional.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("try_at(T, K)", T, std::unordered_map, std::map) { + T m = {{1, "one"}, {2, 
"two"}}; + + SUBCASE("map contains key") { + std::optional result = try_at(m, 1); + std::optional correct = "one"; + + CHECK(result == correct); + } + + SUBCASE("map does not contain key") { + std::optional result = try_at(m, 3); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} From 85fd5b4fde364da21d150a55ecc4256ea75505ea Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 2 Oct 2024 11:12:38 -0700 Subject: [PATCH 19/29] Get tests building again --- .../get_optimal_machine_mapping.h | 4 +- .../machine_mapping_constraints.h | 2 - .../machine_mapping/machine_mapping_result.h | 20 +- ...racted_tensor_set_movement_across_split.cc | 2 +- .../get_optimal_machine_mapping.cc | 25 +- .../machine_mapping_constraints.cc | 12 +- .../machine_mapping_problem_tree.cc | 13 + .../machine_mapping/machine_mapping_result.cc | 18 +- ...racted_tensor_set_movement_across_split.cc | 254 ++++++++++ .../cost_estimator_for_test.cc | 20 +- .../machine_mapping/cost_estimator_for_test.h | 4 + .../get_tensor_set_movement_across_split.cc | 466 +++++++++--------- .../get_machine_mapping_problem_tree.cc | 90 +++- .../machine_mapping/machine_mapping_result.cc | 163 +++--- lib/utils/include/utils/containers/flatmap.h | 17 + .../include/utils/containers/merge_maps.h | 9 +- .../test/src/utils/containers/flatmap.cc | 65 +++ 17 files changed, 822 insertions(+), 362 deletions(-) create mode 100644 lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 1c52ccc2bb..9cc7db4da2 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -7,6 +7,7 @@ #include 
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" #include "pcg/machine_specification.dtg.h" namespace FlexFlow { @@ -23,7 +24,8 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingContext const &context, MMProblemTreeSeriesSplit const &series_split, MachineSpecification const &resources, - MachineMappingConstraints const &constraints); + MachineMappingConstraints const &constraints, + std::optional const ¶llel_split_transformation); MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h index 3ab879aed3..6e46b49c69 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -26,8 +26,6 @@ MachineMappingConstraints restrict_to_right_child(MachineMappingConstraints cons MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &, ParallelLayerGuidObliviousMachineMapping const &); -MachineMapping require_fully_constrained(MachineMappingConstraints const &); - std::optional require_only_root(MachineMappingConstraints const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 6fb4c70d2e..bc723b924c 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -6,21 +6,21 @@ 
namespace FlexFlow { -MachineMappingResult infeasible_machine_mapping_result(); -bool is_infeasible(MachineMappingResult const &); +[[nodiscard]] MachineMappingResult infeasible_machine_mapping_result(); +[[nodiscard]] bool is_infeasible(MachineMappingResult const &); FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); -MachineMappingResult series_combine(float comm_cost, - MachineMappingResult const &pre_result, - MachineMappingResult const &post_result, - std::optional const &parallel_split_transformation); -MachineMappingResult parallel_combine(MachineMappingResult const &lhs_result, - MachineMappingResult const &rhs_result); +[[nodiscard]] MachineMappingResult series_combine(float comm_cost, + MachineMappingResult const &pre_result, + MachineMappingResult const &post_result, + std::optional const &parallel_split_transformation); +[[nodiscard]] MachineMappingResult parallel_combine(MachineMappingResult const &lhs_result, + MachineMappingResult const &rhs_result); [[nodiscard]] MachineMappingResult minimize_runtime(MachineMappingResult const &m1, MachineMappingResult const &m2); -MachineMappingResult make_singleton_machine_mapping_result(float runtime, - MachineView const &machine_view); +[[nodiscard]] MachineMappingResult make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view); } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index cea2df8073..3846644155 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -10,7 +10,7 @@ namespace FlexFlow {
-AbstractedTensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { auto get_path_to_layer = [&](parallel_layer_guid_t const &l) { diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index bf3b2d9f9b..2912dec1ac 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -47,13 +47,24 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingResult result = visit( problem_tree, - [&](auto const &decomp_tree_node) { - return get_optimal_machine_mapping - (result_cache, - context, - decomp_tree_node, - resources, - constraints); + overload { + [&](MMProblemTreeSeriesSplit const &series_split) { + return get_optimal_machine_mapping + (result_cache, + context, + series_split, + resources, + constraints, + /*parallel_split_transformation=*/std::nullopt); + }, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping + (result_cache, + context, + decomp_tree_node, + resources, + constraints); + }, }); result_cache.save(state, result); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 206bde2e78..5d04c99e72 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -36,8 +36,8 @@ std::optional get_machine_view_for_layer(MachineMappingConstraints return partial_solution.machine_views.at(layer); } -MachineMappingConstraints restrict_to_left_child(MachineMappingConstraints const &constraints, - BinaryTreePathEntry const &prefix) { 
+MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { return MachineMappingConstraints{ filtermap_keys(constraints.machine_views, [&](BinaryTreePath const &path) -> std::optional { @@ -82,5 +82,13 @@ MachineMappingConstraints with_additional_constraints(MachineMappingConstraints return result; } +std::optional require_only_root(MachineMappingConstraints const &constraints) { + if (keys(constraints.machine_views) != std::unordered_set{binary_tree_root_path()}) { + throw mk_runtime_error(fmt::format("require_only_root expected constraints to have only a single key (the root path), but received {}", constraints)); + } + + return constraints.machine_views.at(binary_tree_root_path()); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index e3f5e582ea..6b75d3943b 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -5,6 +5,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" namespace FlexFlow { @@ -59,6 +60,18 @@ UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) { return require_leaf(t.raw_tree); } +MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &series) { + 
return MachineMappingProblemTree{ + wrap_series_split(series.raw_split), + }; +} + +MachineMappingProblemTree wrap_parallel_split(MMProblemTreeParallelSplit const &parallel) { + return MachineMappingProblemTree{ + wrap_parallel_split(parallel.raw_split), + }; +} + std::unordered_multiset get_leaves(MachineMappingProblemTree const &t) { return get_leaves(t.raw_tree); } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 7ee49f465b..804e6254e3 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -7,7 +7,19 @@ namespace FlexFlow { -MachineMappingResult sequential_combine(float comm_cost, +MachineMappingResult infeasible_machine_mapping_result() { + return MachineMappingResult{std::nullopt}; +} + +bool is_infeasible(MachineMappingResult const &result) { + return !result.raw_result.has_value(); +} + +FeasibleMachineMappingResult require_feasible(MachineMappingResult const &result) { + return result.raw_result.value(); +} + +MachineMappingResult series_combine(float comm_cost, MachineMappingResult const &maybe_pre_result, MachineMappingResult const &maybe_post_result, std::optional const &parallel_split_transformation) { @@ -70,10 +82,6 @@ MachineMappingResult parallel_combine(MachineMappingResult const &maybe_lhs_resu }; } -MachineMappingResult infeasible_machine_mapping_result() { - return MachineMappingResult{std::nullopt}; -} - MachineMappingResult minimize_runtime(MachineMappingResult const &maybe_m1, MachineMappingResult const &maybe_m2) { FeasibleMachineMappingResult m1 = ({ diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644
index 0000000000..5a2a478214 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,254 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_abstracted_tensor_set_movement_across_split") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAttrs relu_attrs + = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAttrs ew_add_attrs + = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + ParallelLayerAddedResult layer_1 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + 
ParallelLayerAddedResult layer_2 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(layer_1.outputs)}, + {relu_output_attrs}); + + SUBCASE("single edge across split") { + PCGBinarySeriesSplit split = require_series(\ + make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(layer_1.parallel_layer)), + make_pcg_leaf_node(layer_2.parallel_layer))); + + AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split + (pcg_get_transitive_reduction(pcg), + split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") { + ParallelLayerAddedResult layer_3 + = add_parallel_layer(pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(\ + make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_series_split( + make_pcg_leaf_node(layer_1.parallel_layer), + make_pcg_leaf_node(layer_2.parallel_layer))), + make_pcg_leaf_node(layer_3.parallel_layer))); + + AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split + (pcg_get_transitive_reduction(pcg), + split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == 
correct); + } + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult layer_3 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(layer_1.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(layer_1.parallel_layer)), + make_pcg_parallel_split( + make_pcg_leaf_node(layer_2.parallel_layer), + make_pcg_leaf_node(layer_3.parallel_layer)))); + + AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split + (pcg_get_transitive_reduction(pcg), + split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult layer_3 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + + ParallelLayerAddedResult layer_4 + = add_parallel_layer(pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_3.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_parallel_split( + make_pcg_leaf_node(layer_1.parallel_layer), + make_pcg_leaf_node(layer_3.parallel_layer))), + make_pcg_parallel_split( + make_pcg_leaf_node(layer_2.parallel_layer), + make_pcg_leaf_node(layer_4.parallel_layer)))); + + AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split + (pcg_get_transitive_reduction(pcg), + 
split); + + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc index 5a7f56eb79..75808b88b4 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -1,4 +1,7 @@ #include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/map_keys.h" namespace FlexFlow { @@ -36,4 +39,19 @@ CostEstimator make_fake_cost_estimator( }); } -} // namespace FlexFlow +CostEstimator make_fake_cost_estimator( + std::unordered_map> const &op_cost_map, + std::unordered_map const &comm_cost_map) { + + auto de_nest_key = [](UnmappedOpCostEstimateKey const &k1, std::unordered_map const &v) + -> std::unordered_map + { + return map_keys(v, [&](MachineView const &k2) { return map_unmapped_op_cost_estimate_key(k1, k2); }); + }; + + std::unordered_map mapped_costs = flatmap(op_cost_map, de_nest_key); + + return 
make_fake_cost_estimator(mapped_costs, comm_cost_map); +} + +} // namespace FlexFlop diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index c8e2624c54..bdfe0d57cb 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -25,6 +25,10 @@ CostEstimator make_fake_cost_estimator( std::function const &get_operator_cost, std::function const &get_communication_cost); +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map); + CostEstimator make_fake_cost_estimator( std::unordered_map> const &op_cost_map, std::unordered_map const &comm_cost_map); diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index cce5dbb1a2..e75f6626bb 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,233 +1,233 @@ -#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" -#include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "utils/containers/get_only.h" -#include -#include "./cost_estimator_for_test.h" - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_tensor_set_movement_across_split") { - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - ParallelTensorShape input_shape = - ParallelTensorShape{ - 
ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 2}, - ShardParallelDim{12, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, - }; - ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - - ParallelLayerAttrs relu_attrs - = ParallelLayerAttrs{ - /*op_attrs=*/PCGOperatorAttrs{ - ElementUnaryAttrs{ - /*op_type=*/OperatorType::RELU, - /*scalar=*/std::nullopt, - }, - }, - /*name=*/std::nullopt, - }; - - ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ - /*shape=*/input_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::YES, - }; - - ParallelLayerAddedResult relu_1 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - ParallelLayerAddedResult relu_2 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(relu_1.outputs)}, - {relu_output_attrs}); - - MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); - MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); - MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); - MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); - - SUBCASE("single edge across split") { - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(relu_1.parallel_layer)), - make_pcg_leaf_node(relu_2.parallel_layer))); - - PartialMachineMapping pre_mapping = PartialMachineMapping{{ - {relu_1.parallel_layer, pre_mv1}, - }}; - - PartialMachineMapping post_mapping = PartialMachineMapping{{ - {relu_2.parallel_layer, post_mv1}, - }}; - - TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), - split, - pre_mapping, - post_mapping); - TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - 
/*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1}, - }, - }, - }; - - CHECK(result == correct); - } - - SUBCASE("does not include edges removed by transitive reduction") { - - } - - SUBCASE("single tensor, multiple consumers across split") { - ParallelLayerAddedResult relu_3 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(relu_1.outputs)}, - {relu_output_attrs}); - - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(relu_1.parallel_layer)), - make_pcg_parallel_split( - make_pcg_leaf_node(relu_2.parallel_layer), - make_pcg_leaf_node(relu_3.parallel_layer)))); - - SUBCASE("consumers have same view") { - PartialMachineMapping pre_mapping = PartialMachineMapping{{ - {relu_1.parallel_layer, pre_mv1}, - }}; - - PartialMachineMapping post_mapping = PartialMachineMapping{{ - {relu_2.parallel_layer, post_mv1}, - {relu_3.parallel_layer, post_mv1}, - }}; - - TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), - split, - pre_mapping, - post_mapping); - - TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1}, - }, - }, - }; - - CHECK(result == correct); - } - - SUBCASE("consumers have different views") { - PartialMachineMapping pre_mapping = PartialMachineMapping{{ - {relu_1.parallel_layer, pre_mv1}, - }}; - - PartialMachineMapping post_mapping = PartialMachineMapping{{ - {relu_2.parallel_layer, post_mv1}, - {relu_3.parallel_layer, post_mv2}, - }}; - - TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), - split, - pre_mapping, - post_mapping); - - TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - 
/*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1, post_mv2}, - }, - }, - }; - - CHECK(result == correct); - } - } - - SUBCASE("multiple tensors, multiple consumers across split") { - ParallelLayerAddedResult relu_3 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - - ParallelLayerAddedResult relu_4 - = add_parallel_layer(pcg, - relu_attrs, - // relu's don't have two inputs, but for the purposes of this test it's fine. - {get_only(relu_1.outputs), get_only(relu_3.outputs)}, - {relu_output_attrs}); - - PartialMachineMapping pre_mapping = PartialMachineMapping{{ - {relu_1.parallel_layer, pre_mv1}, - {relu_3.parallel_layer, pre_mv2}, - }}; - - PartialMachineMapping post_mapping = PartialMachineMapping{{ - {relu_2.parallel_layer, post_mv1}, - {relu_4.parallel_layer, post_mv2}, - }}; - - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_parallel_split( - make_pcg_leaf_node(relu_1.parallel_layer), - make_pcg_leaf_node(relu_3.parallel_layer))), - make_pcg_parallel_split( - make_pcg_leaf_node(relu_2.parallel_layer), - make_pcg_leaf_node(relu_4.parallel_layer)))); - - TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), - split, - pre_mapping, - post_mapping); - - - TensorSetMovement correct = TensorSetMovement{ - /*single_tensor_movements=*/{ - SingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{pre_mv1}, - /*dst_machine_views=*/{post_mv1, post_mv2}, - }, - SingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{pre_mv2}, - /*dst_machine_views=*/{post_mv2}, - }, - }, - }; - - CHECK(result == correct); - } - } -} +// #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +// #include "compiler/machine_mapping/transitive_reduced_pcg.h" +// 
#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +// #include "pcg/machine_view.h" +// #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +// #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +// #include "utils/containers/get_only.h" +// #include +// #include "./cost_estimator_for_test.h" +// +// using namespace ::FlexFlow; +// +// TEST_SUITE(FF_TEST_SUITE) { +// TEST_CASE("get_tensor_set_movement_across_split") { +// ParallelComputationGraph pcg = empty_parallel_computation_graph(); +// +// ParallelTensorShape input_shape = +// ParallelTensorShape{ +// ParallelTensorDims{ +// FFOrdered{ +// ShardParallelDim{10, 2}, +// ShardParallelDim{12, 1}, +// }, +// ReplicaParallelDimSet{ +// SumDegree{1}, +// DiscardCopyDegree{1}, +// }, +// }, +// DataType::FLOAT, +// }; +// ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); +// +// ParallelLayerAttrs relu_attrs +// = ParallelLayerAttrs{ +// /*op_attrs=*/PCGOperatorAttrs{ +// ElementUnaryAttrs{ +// /*op_type=*/OperatorType::RELU, +// /*scalar=*/std::nullopt, +// }, +// }, +// /*name=*/std::nullopt, +// }; +// +// ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ +// /*shape=*/input_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// /*create_gradients=*/CreateGrad::YES, +// }; +// +// ParallelLayerAddedResult relu_1 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(input.outputs)}, +// {relu_output_attrs}); +// ParallelLayerAddedResult relu_2 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(relu_1.outputs)}, +// {relu_output_attrs}); +// +// MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); +// MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); +// MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); +// MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); +// +// SUBCASE("single edge across split") { +// 
PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_leaf_node(relu_1.parallel_layer)), +// make_pcg_leaf_node(relu_2.parallel_layer))); +// +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// }}; +// +// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("does not include edges removed by transitive reduction") { +// +// } +// +// SUBCASE("single tensor, multiple consumers across split") { +// ParallelLayerAddedResult relu_3 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(relu_1.outputs)}, +// {relu_output_attrs}); +// +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_leaf_node(relu_1.parallel_layer)), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_2.parallel_layer), +// make_pcg_leaf_node(relu_3.parallel_layer)))); +// +// SUBCASE("consumers have same view") { +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_3.parallel_layer, post_mv1}, +// }}; +// +// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// 
TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("consumers have different views") { +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_3.parallel_layer, post_mv2}, +// }}; +// +// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1, post_mv2}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// } +// +// SUBCASE("multiple tensors, multiple consumers across split") { +// ParallelLayerAddedResult relu_3 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(input.outputs)}, +// {relu_output_attrs}); +// +// ParallelLayerAddedResult relu_4 +// = add_parallel_layer(pcg, +// relu_attrs, +// // relu's don't have two inputs, but for the purposes of this test it's fine. 
+// {get_only(relu_1.outputs), get_only(relu_3.outputs)}, +// {relu_output_attrs}); +// +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// {relu_3.parallel_layer, pre_mv2}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_4.parallel_layer, post_mv2}, +// }}; +// +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_1.parallel_layer), +// make_pcg_leaf_node(relu_3.parallel_layer))), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_2.parallel_layer), +// make_pcg_leaf_node(relu_4.parallel_layer)))); +// +// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1, post_mv2}, +// }, +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv2}, +// /*dst_machine_views=*/{post_mv2}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// } +// } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index de4da010e5..0a107c2682 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,5 +1,5 @@ -#include 
"compiler/machine_mapping/get_machine_mapping_problem_tree.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/get_only.h" @@ -42,24 +42,43 @@ TEST_SUITE(FF_TEST_SUITE) { PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; + auto make_input_key = [&](ParallelTensorShape const &parallel_tensor_shape) { + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/input_attrs, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{parallel_tensor_shape}, + }; + }; + SUBCASE("single layer") { - ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input_added = add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); parallel_layer_guid_t input_layer = input_added.parallel_layer; + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + PCGBinarySPDecomposition sp_decomposition = \ make_pcg_leaf_node(input_layer); MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_leaf(input_attrs); + MachineMappingProblemTree correct = mm_problem_tree_make_leaf(input_key); CHECK(result == correct); } SUBCASE("two layers in series") { - ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input_added = add_parallel_layer(pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); parallel_layer_guid_t input_layer =
input_added.parallel_layer; parallel_tensor_guid_t input = get_only(input_added.outputs); + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ ElementUnaryAttrs{ /*op_type=*/OperatorType::RELU, @@ -74,6 +93,13 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t relu_layer = relu_added.parallel_layer; parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + UnmappedOpCostEstimateKey relu_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/relu_attrs, + /*input_shapes=*/{input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{relu_output_shape}, + }; + PCGBinarySPDecomposition sp_decomposition = \ make_pcg_series_split( make_pcg_leaf_node(input_layer), @@ -85,13 +111,17 @@ TEST_SUITE(FF_TEST_SUITE) { mm_problem_tree_make_series_split( AbstractedTensorSetMovement{{ AbstractedSingleTensorMovement{ - input_shape, - {input_layer}, - {relu_layer}, + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{ + BinaryTreePath{{}}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{}}, + }, }, }}, - mm_problem_tree_make_leaf(input_attrs), - mm_problem_tree_make_leaf(relu_attrs)); + mm_problem_tree_make_leaf(input_key), + mm_problem_tree_make_leaf(relu_key)); CHECK(result == correct); } @@ -99,9 +129,11 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("two layers in parallel") { ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); PCGBinarySPDecomposition sp_decomposition = \ make_pcg_series_split( @@ -112,8 +144,8 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingProblemTree correct = \ mm_problem_tree_make_parallel_split( - 
mm_problem_tree_make_leaf(input_attrs), - mm_problem_tree_make_leaf(input_attrs)); + mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)); CHECK(result == correct); } @@ -122,10 +154,12 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input1_layer = input1_added.parallel_layer; parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input2_layer = input2_added.parallel_layer; parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ ElementBinaryAttrs{ @@ -141,6 +175,12 @@ TEST_SUITE(FF_TEST_SUITE) { {input1_tensor, input2_tensor}, {make_output_attrs(ew_op_output_shape)}); parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + UnmappedOpCostEstimateKey ew_op_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/ew_op_attrs, + /*input_shapes=*/{input_shape, input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{ew_op_output_shape}, + }; PCGBinarySPDecomposition sp_decomposition = \ make_pcg_series_split( @@ -156,19 +196,31 @@ TEST_SUITE(FF_TEST_SUITE) { AbstractedTensorSetMovement{{ AbstractedSingleTensorMovement{ /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{input1_layer}, - /*dst_machine_views=*/{ew_op_layer}, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/{ + BinaryTreePath{{}}, + }, }, AbstractedSingleTensorMovement{ /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{input2_layer}, - /*dst_machine_views=*/{ew_op_layer}, + /*src_machine_views=*/{ + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/{ 
+ BinaryTreePath{{}}, + }, }, }}, /*pre=*/mm_problem_tree_make_parallel_split( - mm_problem_tree_make_leaf(input_attrs), - mm_problem_tree_make_leaf(input_attrs)), - /*post=*/mm_problem_tree_make_leaf(ew_op_attrs)); + mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), + /*post=*/mm_problem_tree_make_leaf(ew_op_key)); CHECK(result == correct); } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index ba06265cec..10db2496f6 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -7,91 +7,94 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("sequential_combine") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); - parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); - MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{layer0, machine_view_0}}); - MachineMapping machine_mapping_1({{layer1, machine_view_1}}); - MachineMapping combined( - {{layer0, machine_view_0}, {layer1, machine_view_1}}); - MachineMappingResult s0(0, machine_mapping_empty); - MachineMappingResult s1(1, machine_mapping_0); - MachineMappingResult s2(2, machine_mapping_1); - - float comm_cost = 2.0; - - MachineMappingResult result0 = sequential_combine(s0, comm_cost, s1); - CHECK(result0.runtime == 1); - CHECK(result0.machine_mapping == machine_mapping_0); - - MachineMappingResult result1 = sequential_combine(s0, comm_cost, s2); - CHECK(result1.runtime == 2); - CHECK(result1.machine_mapping == machine_mapping_1); - - MachineMappingResult result2 = sequential_combine(s1, comm_cost, s2); - CHECK(result2.runtime 
== 3); - CHECK(result2.machine_mapping == combined); + FAIL("TODO"); + // MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + // MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + // parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); + // parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); + // MachineMapping machine_mapping_empty( + // std::unordered_map{}); + // MachineMapping machine_mapping_0({{layer0, machine_view_0}}); + // MachineMapping machine_mapping_1({{layer1, machine_view_1}}); + // MachineMapping combined( + // {{layer0, machine_view_0}, {layer1, machine_view_1}}); + // MachineMappingResult s0(0, machine_mapping_empty); + // MachineMappingResult s1(1, machine_mapping_0); + // MachineMappingResult s2(2, machine_mapping_1); + // + // float comm_cost = 2.0; + // + // MachineMappingResult result0 = sequential_combine(s0, comm_cost, s1); + // CHECK(result0.runtime == 1); + // CHECK(result0.machine_mapping == machine_mapping_0); + // + // MachineMappingResult result1 = sequential_combine(s0, comm_cost, s2); + // CHECK(result1.runtime == 2); + // CHECK(result1.machine_mapping == machine_mapping_1); + // + // MachineMappingResult result2 = sequential_combine(s1, comm_cost, s2); + // CHECK(result2.runtime == 3); + // CHECK(result2.machine_mapping == combined); } TEST_CASE("parallel_combine") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); - parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); - MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{layer0, machine_view_0}}); - MachineMapping machine_mapping_1({{layer1, machine_view_1}}); - MachineMapping combined( - {{layer0, machine_view_0}, {layer1, machine_view_1}}); - MachineMappingResult s0(0, 
machine_mapping_empty); - MachineMappingResult s1(1, machine_mapping_0); - MachineMappingResult s2(2, machine_mapping_1); - - MachineMappingResult result0 = parallel_combine(s0, s1); - CHECK(result0.runtime == 1); - CHECK(result0.machine_mapping == machine_mapping_0); - - MachineMappingResult result1 = parallel_combine(s0, s2); - CHECK(result1.runtime == 2); - CHECK(result1.machine_mapping == machine_mapping_1); - - MachineMappingResult result2 = parallel_combine(s1, s2); - CHECK(result2.runtime == 2); - CHECK(result2.machine_mapping == combined); + FAIL("TODO"); + // MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + // MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + // parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); + // parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); + // MachineMapping machine_mapping_empty( + // std::unordered_map{}); + // MachineMapping machine_mapping_0({{layer0, machine_view_0}}); + // MachineMapping machine_mapping_1({{layer1, machine_view_1}}); + // MachineMapping combined( + // {{layer0, machine_view_0}, {layer1, machine_view_1}}); + // MachineMappingResult s0(0, machine_mapping_empty); + // MachineMappingResult s1(1, machine_mapping_0); + // MachineMappingResult s2(2, machine_mapping_1); + // + // MachineMappingResult result0 = parallel_combine(s0, s1); + // CHECK(result0.runtime == 1); + // CHECK(result0.machine_mapping == machine_mapping_0); + // + // MachineMappingResult result1 = parallel_combine(s0, s2); + // CHECK(result1.runtime == 2); + // CHECK(result1.machine_mapping == machine_mapping_1); + // + // MachineMappingResult result2 = parallel_combine(s1, s2); + // CHECK(result2.runtime == 2); + // CHECK(result2.machine_mapping == combined); } TEST_CASE("minimize_runtime") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - 
parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); - parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); - MachineMapping machine_mapping_empty( - std::unordered_map{}); - MachineMapping machine_mapping_0({{layer0, machine_view_0}}); - MachineMapping machine_mapping_1({{layer1, machine_view_1}}); - MachineMapping combined( - {{layer0, machine_view_0}, {layer1, machine_view_1}}); - MachineMappingResult s0(0, machine_mapping_empty); - MachineMappingResult s1(1, machine_mapping_0); - MachineMappingResult s2(2, machine_mapping_1); - - MachineMappingResult _s0 = s0; - MachineMappingResult _s1 = s1; - MachineMappingResult _s2 = s2; - - minimize_runtime(_s0, _s1); - CHECK(_s0 == s0); - minimize_runtime(_s1, _s2); - CHECK(_s1 == s1); - - minimize_runtime(_s1, _s0); - CHECK(_s1 == s0); - - minimize_runtime(_s2, get_infinity_machine_mapping_result()); - CHECK(_s2 == s2); + FAIL("TODO"); + // MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + // MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + // parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); + // parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); + // MachineMapping machine_mapping_empty( + // std::unordered_map{}); + // MachineMapping machine_mapping_0({{layer0, machine_view_0}}); + // MachineMapping machine_mapping_1({{layer1, machine_view_1}}); + // MachineMapping combined( + // {{layer0, machine_view_0}, {layer1, machine_view_1}}); + // MachineMappingResult s0(0, machine_mapping_empty); + // MachineMappingResult s1(1, machine_mapping_0); + // MachineMappingResult s2(2, machine_mapping_1); + // + // MachineMappingResult _s0 = s0; + // MachineMappingResult _s1 = s1; + // MachineMappingResult _s2 = s2; + // + // minimize_runtime(_s0, _s1); + // CHECK(_s0 == s0); + // minimize_runtime(_s1, _s2); + // CHECK(_s1 == s1); + // + // minimize_runtime(_s1, _s0); + // CHECK(_s1 == s0); + // + // minimize_runtime(_s2, 
get_infinity_machine_mapping_result()); + // CHECK(_s2 == s2); } } diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 0f8906f34a..537bb2d177 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -4,6 +4,8 @@ #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" #include +#include +#include "utils/containers/merge_maps.h" namespace FlexFlow { @@ -39,6 +41,21 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } +template ::key_type, + typename OutV = typename std::invoke_result_t::mapped_type> +std::unordered_map flatmap(std::unordered_map const &m, F &&f) { + std::unordered_map result; + + for (auto const &[k, v] : m) { + result = merge_maps(result, f(k, v)); + } + + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index 653c9d24f1..6a3f230d08 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -4,13 +4,20 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" #include +#include "utils/fmt/unordered_map.h" +#include "utils/exception.h" namespace FlexFlow { template std::unordered_map merge_maps(std::unordered_map const &lhs, std::unordered_map const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); + if (!are_disjoint(keys(lhs), keys(rhs))) { + throw mk_runtime_error(fmt::format("Key sets of merge_maps parameters are " + "non-disjoint: lhs = {}, rhs = {}", + lhs, + rhs)); + } std::unordered_map result; for (auto const &kv : lhs) { diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index 9d7e6439e2..dc8d8437a9 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -2,6 +2,10 
@@ #include #include #include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/containers/map_keys.h" +#include "utils/hash/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/pair.h" using namespace ::FlexFlow; @@ -33,4 +37,65 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("flatmap(std::unordered_map, F)") { + auto de_nest_keys = [](int k1, std::unordered_map const &v) { + return map_keys(v, [&](int k2) { return std::pair{k1, k2}; }); + }; + + SUBCASE("input is empty") { + std::unordered_map> input = {}; + + std::unordered_map, std::string> result = flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::unordered_map> input = { + { + 1, + { + {2, "a"}, + {3, "b"}, + }, + }, + { + 2, + {}, + }, + { + 3, + { + {3, "a"}, + }, + }, + }; + + std::unordered_map, std::string> result = flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = { + {{1, 2}, "a"}, + {{1, 3}, "b"}, + {{3, 3}, "a"}, + }; + + CHECK(result == correct); + } + + SUBCASE("duplicate result keys") { + auto always_return_same_map = [](int, std::string const &) { + return std::unordered_map{ + {"mykey", 10000}, + }; + }; + + std::unordered_map input = { + {1, "a"}, + {2, "b"}, + }; + + CHECK_THROWS(flatmap(input, always_return_same_map)); + } + } } From 597e13c0e81b233ef44a1af1bd6ad87ebdeee9f0 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 2 Oct 2024 17:33:59 -0700 Subject: [PATCH 20/29] Pass some basic tests of get_optimal_machine_mapping --- .../machine_mapping/machine_mapping_result.h | 2 + ...racted_tensor_set_movement_across_split.cc | 17 +- .../get_optimal_machine_mapping.cc | 122 +++--- .../machine_mapping/machine_mapping_result.cc | 10 + ...el_layer_guid_oblivious_machine_mapping.cc | 2 +- .../pcg_binary_sp_decomposition.cc | 6 +- ...racted_tensor_set_movement_across_split.cc | 94 ++++- 
.../cost_estimator_for_test.cc | 18 +- .../machine_mapping/cost_estimator_for_test.h | 6 +- .../get_optimal_machine_mapping.cc | 345 +++++++-------- .../machine_mapping/machine_mapping_result.cc | 396 ++++++++++++++---- lib/pcg/src/pcg/computation_graph_builder.cc | 4 +- .../utils/containers/get_all_assignments.h | 6 +- lib/utils/include/utils/containers/get_only.h | 4 +- lib/utils/include/utils/exception.h | 7 +- .../make.h | 6 +- lib/utils/include/utils/sequence.h | 2 +- lib/utils/include/utils/tuple.h | 4 +- lib/utils/src/utils/exception.cc | 8 + .../get_edges_from_subgraph_to_subgraph.cc | 3 +- .../instances/hashmap_undirected_graph.cc | 8 +- .../binary_sp_decomposition_tree.cc | 6 +- .../utils/containers/get_all_assignments.cc | 2 +- 23 files changed, 657 insertions(+), 421 deletions(-) diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index bc723b924c..225c8c6f5c 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -10,6 +10,8 @@ namespace FlexFlow { [[nodiscard]] bool is_infeasible(MachineMappingResult const &); FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime(std::unordered_set const &); + [[nodiscard]] MachineMappingResult series_combine(float comm_cost, MachineMappingResult const &pre_result, MachineMappingResult const &post_result, diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 3846644155..6ec7a545b5 100644 --- 
a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg_binary_series_split.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" @@ -11,12 +12,8 @@ namespace FlexFlow { AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { + PCGBinarySeriesSplit const &split) { - auto get_path_to_layer = [&](parallel_layer_guid_t const &l) { - return get_only(find_paths_to_leaf(wrap_series_split(split), l)); - }; - std::unordered_set edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); @@ -38,8 +35,14 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(Tran return AbstractedSingleTensorMovement{ /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), - /*src_machine_views=*/transform(src_layers, get_path_to_layer), - /*dst_machine_views=*/transform(dst_layers, get_path_to_layer), + /*src_machine_views=*/transform(src_layers, + [&](parallel_layer_guid_t const &l) { + return get_only(find_paths_to_leaf(get_left_child(split), l)); + }), + /*dst_machine_views=*/transform(dst_layers, + [&](parallel_layer_guid_t const &l) { + return get_only(find_paths_to_leaf(get_right_child(split), l)); + }), }; }; diff --git 
a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 2912dec1ac..b8b2f1d19c 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -79,10 +79,6 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingConstraints const &constraints, std::optional const ¶llel_split_transformation) { - MachineMappingResult result = infeasible_machine_mapping_result(); - - AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); - auto get_boundary_machine_view_assignments = [&](std::unordered_set const &boundary_layers) -> std::unordered_set { @@ -100,47 +96,50 @@ MachineMappingResult get_optimal_machine_mapping( }); }; - for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views - : get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { - + auto eval_pre_boundary_mapping = [&](ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views) { MachineMappingConstraints pre_candidate = with_additional_constraints( restrict_to_left_child(constraints), assigned_pre_machine_views); - MachineMappingResult pre_result = ({ - std::optional returned - = get_optimal_machine_mapping(result_cache, + MachineMappingResult pre_result = + get_optimal_machine_mapping(result_cache, context, get_pre_child(series_split), resources, pre_candidate); - if (!returned.has_value()) { - continue; - } - returned.value(); - }); + + return pre_result; + }; + + auto eval_post_boundary_mapping = [&](ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views) { + MachineMappingConstraints post_candidate = + with_additional_constraints( + restrict_to_right_child(constraints), + assigned_post_machine_views); + + MachineMappingResult post_result + = get_optimal_machine_mapping(result_cache, + 
context, + get_post_child(series_split), + resources, + post_candidate); + + return post_result; + }; + + MachineMappingResult result = infeasible_machine_mapping_result(); + AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); + + for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views + : get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + + MachineMappingResult pre_result = eval_pre_boundary_mapping(assigned_pre_machine_views); for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views : get_boundary_machine_view_assignments(get_dst_layers(tensor_movement))) { - MachineMappingConstraints post_candidate = - with_additional_constraints( - restrict_to_right_child(constraints), - assigned_post_machine_views); - - MachineMappingResult post_result = ({ - std::optional returned - = get_optimal_machine_mapping(result_cache, - context, - get_post_child(series_split), - resources, - post_candidate); - if (!returned.has_value()) { - continue; - } - returned.value(); - }); + MachineMappingResult post_result = eval_post_boundary_mapping(assigned_post_machine_views); TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement(tensor_movement, /*pre_mapping=*/assigned_pre_machine_views, @@ -167,7 +166,7 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingProblemTree lhs = get_lhs_child(parallel_split); MachineMappingProblemTree rhs = get_rhs_child(parallel_split); - MachineMappingResult optimal_result = [&] { + MachineMappingResult series_result = [&] { MMProblemTreeSeriesSplit series_split = require_series_split(\ mm_problem_tree_make_series_split( /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), @@ -185,26 +184,28 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingConstraints left_constraints = restrict_to_left_child(constraints); MachineMappingConstraints right_constraints = 
restrict_to_right_child(constraints); - for (auto const &resource_split : get_machine_resource_splits(resources)) { + auto evaluate_resource_split = [&](std::pair const &resource_split) { MachineMappingResult left_result = get_optimal_machine_mapping(result_cache, - context, - lhs, - resource_split.first, - left_constraints); + context, + lhs, + resource_split.first, + left_constraints); MachineMappingResult right_result = get_optimal_machine_mapping(result_cache, - context, - rhs, - resource_split.second, - right_constraints); - - optimal_result = minimize_runtime( - optimal_result, - parallel_combine(left_result, right_result)); - } + context, + rhs, + resource_split.second, + right_constraints); - return optimal_result; + return parallel_combine(left_result, right_result); + }; + + std::unordered_set parallel_results + = transform(get_machine_resource_splits(resources), + evaluate_resource_split); + + return minimize_runtime(series_result, get_mapping_with_minimal_runtime(parallel_results)); } MachineMappingResult get_optimal_machine_mapping( @@ -214,13 +215,26 @@ MachineMappingResult get_optimal_machine_mapping( MachineSpecification const &resource, MachineMappingConstraints const &constraints) { - MachineView machine_view = require_only_root(constraints).value(); - OpCostEstimateKey mapped = map_unmapped_op_cost_estimate_key(leaf, machine_view); - float cost = context.cost_estimator.estimate_cost(mapped); + std::unordered_set candidates = [&] { + std::optional machine_view = require_only_root(constraints); + if (machine_view.has_value()) { + return std::unordered_set{machine_view.value()}; + } else { + return context.allowed_machine_views(leaf, resource); + } + }(); + + auto get_mapping_result = [&](MachineView const &machine_view) { + OpCostEstimateKey mapped = map_unmapped_op_cost_estimate_key(leaf, machine_view); + float cost = context.cost_estimator.estimate_cost(mapped); + + return make_singleton_machine_mapping_result(cost, machine_view); + }; + + 
std::unordered_set candidate_results + = transform(candidates, get_mapping_result); - return make_singleton_machine_mapping_result - (/*runtime=*/cost, - /*machine_view=*/machine_view); + return get_mapping_with_minimal_runtime(candidate_results); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 804e6254e3..1e4de0a929 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -19,6 +19,16 @@ FeasibleMachineMappingResult require_feasible(MachineMappingResult const &result return result.raw_result.value(); } +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime(std::unordered_set const &candidates) { + MachineMappingResult result = infeasible_machine_mapping_result(); + + for (MachineMappingResult const &candidate : candidates) { + result = minimize_runtime(result, candidate); + } + + return result; +} + MachineMappingResult series_combine(float comm_cost, MachineMappingResult const &maybe_pre_result, MachineMappingResult const &maybe_post_result, diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc index 27922d62dc..63035f5801 100644 --- a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -8,7 +8,7 @@ namespace FlexFlow { ParallelLayerGuidObliviousMachineMapping binary_combine_mappings(ParallelLayerGuidObliviousMachineMapping const &lhs, - ParallelLayerGuidObliviousMachineMapping const &rhs) { + ParallelLayerGuidObliviousMachineMapping const &rhs) { return ParallelLayerGuidObliviousMachineMapping{ merge_maps( map_keys(lhs.raw_mapping, 
nest_inside_left_child), diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index 78345398b9..df0245a4d2 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -24,19 +24,19 @@ SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &d) { PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - make_series_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_make_series_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - make_parallel_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &l) { return PCGBinarySPDecomposition{ - make_leaf_node(l), + leaf_only_make_leaf_node(l), }; } diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 5a2a478214..c320900414 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -26,8 +26,6 @@ TEST_SUITE(FF_TEST_SUITE) { }, DataType::FLOAT, }; - ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ /*op_attrs=*/PCGOperatorAttrs{ 
@@ -59,18 +57,40 @@ TEST_SUITE(FF_TEST_SUITE) { /*create_gradients=*/CreateGrad::YES, }; - ParallelLayerAddedResult layer_1 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - ParallelLayerAddedResult layer_2 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(layer_1.outputs)}, - {relu_output_attrs}); + SUBCASE("no edges across split") { + ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); + + PCGBinarySeriesSplit split = require_series(\ + make_pcg_series_split( + make_pcg_leaf_node(input1.parallel_layer), + make_pcg_leaf_node(input2.parallel_layer))); + + AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split + (pcg_get_transitive_reduction(pcg), + split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{}, + }; + + CHECK(result == correct); + } SUBCASE("single edge across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + ParallelLayerAddedResult layer_2 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(layer_1.outputs)}, + {relu_output_attrs}); + PCGBinarySeriesSplit split = require_series(\ make_pcg_series_split( make_pcg_series_split( @@ -102,6 +122,20 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("does not include edges removed by transitive reduction") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(layer_1.outputs)}, + {relu_output_attrs}); + ParallelLayerAddedResult layer_3 = add_parallel_layer(pcg, ew_add_attrs, @@ -142,6 
+176,20 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(layer_1.outputs)}, + {relu_output_attrs}); + ParallelLayerAddedResult layer_3 = add_parallel_layer(pcg, relu_attrs, @@ -185,16 +233,30 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("multiple tensors, multiple consumers across split") { - ParallelLayerAddedResult layer_3 + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(input.outputs)}, + {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer(pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult layer_3 + = add_parallel_layer(pcg, + relu_attrs, + {get_only(layer_1.outputs)}, + {relu_output_attrs}); + ParallelLayerAddedResult layer_4 = add_parallel_layer(pcg, ew_add_attrs, - {get_only(layer_1.outputs), get_only(layer_3.outputs)}, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, {relu_output_attrs}); PCGBinarySeriesSplit split = require_series(make_pcg_series_split( @@ -202,9 +264,9 @@ TEST_SUITE(FF_TEST_SUITE) { make_pcg_leaf_node(input.parallel_layer), make_pcg_parallel_split( make_pcg_leaf_node(layer_1.parallel_layer), - make_pcg_leaf_node(layer_3.parallel_layer))), + make_pcg_leaf_node(layer_2.parallel_layer))), make_pcg_parallel_split( - make_pcg_leaf_node(layer_2.parallel_layer), + make_pcg_leaf_node(layer_3.parallel_layer), make_pcg_leaf_node(layer_4.parallel_layer)))); AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc 
b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc index 75808b88b4..a660bf1db4 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -1,7 +1,6 @@ #include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" -#include "utils/containers/flatmap.h" -#include "utils/containers/map_keys.h" namespace FlexFlow { @@ -39,19 +38,4 @@ CostEstimator make_fake_cost_estimator( }); } -CostEstimator make_fake_cost_estimator( - std::unordered_map> const &op_cost_map, - std::unordered_map const &comm_cost_map) { - - auto de_nest_key = [](UnmappedOpCostEstimateKey const &k1, std::unordered_map const &v) - -> std::unordered_map - { - return map_keys(v, [&](MachineView const &k2) { return map_unmapped_op_cost_estimate_key(k1, k2); }); - }; - - std::unordered_map mapped_costs = flatmap(op_cost_map, de_nest_key); - - return make_fake_cost_estimator(mapped_costs, comm_cost_map); -} - } // namespace FlexFlop diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index bdfe0d57cb..d3cc2e0f03 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -4,7 +4,9 @@ #include "compiler/cost_estimator/cost_estimator.h" #include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" #include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include 
"compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" namespace FlexFlow { @@ -29,10 +31,6 @@ CostEstimator make_fake_cost_estimator( std::unordered_map const &op_cost_map, std::unordered_map const &comm_cost_map); -CostEstimator make_fake_cost_estimator( - std::unordered_map> const &op_cost_map, - std::unordered_map const &comm_cost_map); - } // namespace FlexFlow #endif diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 91ea973245..e813f5efff 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,8 +1,10 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "./cost_estimator_for_test.h" #include +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" @@ -14,20 +16,34 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping_internal") { MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); + MachineView mv2 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(3)); - auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, - MachineSpecification const &) { - return std::unordered_set{mv1}; + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + 
/*intra_node_bandwidth=*/1, }; - MachineSpecification machine_spec = MachineSpecification{ - /*num_nodes=*/2, + MachineSpecification split_machine_spec = MachineSpecification{ + /*num_nodes=*/1, /*num_cpus_per_node=*/1, /*num_gpus_per_node=*/1, /*inter_node_bandwidth=*/1, /*intra_node_bandwidth=*/1, }; + auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, + MachineSpecification const &resources) { + if (resources == full_machine_spec) { + return std::unordered_set{mv1, mv2}; + } else { + return std::unordered_set{mv2}; + } + }; + + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, /*input_shapes=*/{}, @@ -35,34 +51,76 @@ TEST_SUITE(FF_TEST_SUITE) { /*output_shapes=*/{}, }; - CostEstimator cost_estimator = make_fake_cost_estimator( - std::unordered_map>{ - { - k1, - { - {mv1, 1.0}, - } + UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + ParallelTensorShape tensor_shape1 = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, }, - }, - std::unordered_map{}); + }, + DataType::FLOAT, + }; + + AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/tensor_shape1, + /*src_machine_views=*/{}, + /*dst_machine_views=*/{}, + }, + }}; + + ParallelLayerGuidObliviousMachineMapping mm1 = ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + 
{map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, + {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, + {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, + {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, + }}, + std::unordered_map{{ + {TensorSetMovement{{}}, 0.0}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), 0.1}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), 0.2}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), 0.3}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), 0.4}, + }}); MachineMappingContext context = MachineMappingContext{ cost_estimator, allowed_machine_views1, }; + MachineMappingCache cache; + SUBCASE("single layer") { MachineMappingProblemTree problem_tree = mm_problem_tree_make_leaf(k1); - MachineMappingCache cache; - MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); MachineMappingResult result = get_optimal_machine_mapping(cache, context, problem_tree, - machine_spec, + full_machine_spec, constraints); MachineMappingResult correct = MachineMappingResult{ FeasibleMachineMappingResult{ @@ -77,199 +135,80 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("pair of layers in sequence") { - FAIL("TODO"); + MachineMappingProblemTree problem_tree = + mm_problem_tree_make_series_split( + movement1, + mm_problem_tree_make_leaf(k1), + mm_problem_tree_make_leaf(k2)); + + MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping(cache, + context, + problem_tree, + full_machine_spec, + constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0 + 2.0 + 0.1, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, 
+ }}, + mv1, + }, + }}, + }, + }; + + CHECK(result == correct); } SUBCASE("pair of layers in parallel") { - FAIL("TODO"); + MachineMappingProblemTree problem_tree = + mm_problem_tree_make_parallel_split( + mm_problem_tree_make_leaf(k1), + mm_problem_tree_make_leaf(k2)); + + MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping(cache, + context, + problem_tree, + full_machine_spec, + constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.5, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, + }}, + }, + }; + + CHECK(result == correct); } SUBCASE("multiple edges across split") { FAIL("TODO"); } - - // SUBCASE("simple PCG") { - // - // ParallelComputationGraph pcg_simple = [&] { - // ParallelComputationGraphBuilder builder; - // - // ParallelTensorShape input_shape0 = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32, 2}, - // ShardParallelDim{32, 1}, - // ShardParallelDim{16, 1}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1}, - // DiscardCopyDegree{1}, - // }, - // }, - // DataType::FLOAT}; - // - // ParallelTensorShape input_shape1 = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32, 2}, - // ShardParallelDim{16, 1}, - // ShardParallelDim{8, 1}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1}, - // DiscardCopyDegree{1}, - // }, - // }, - // DataType::FLOAT}; - // - // parallel_tensor_guid_t input0 = - // builder.create_input_tensor(input_shape0); - // parallel_tensor_guid_t input1 = - // builder.create_input_tensor(input_shape1); - // parallel_tensor_guid_t dense0 = builder.batch_matmul(input0, input1); - // - // return builder.pcg; - // 
}(); - // - // MachineMappingResult result = - // get_optimal_machine_mapping(pcg_simple, - // allowed_machine_views1, - // estimator1, - // machine_spec1, - // cached_results1); - // - // CHECK(result.runtime == 3); - // } - - // SUBCASE("PCG is a chain") { - // ParallelComputationGraph pcg_chain = [&] { - // ParallelComputationGraphBuilder builder; - // - // ParallelTensorShape input_shape = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32, 2}, - // ShardParallelDim{16, 1}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1}, - // DiscardCopyDegree{1}, - // }, - // }, - // DataType::FLOAT}; - // - // parallel_tensor_guid_t input0 = - // builder.create_input_tensor(input_shape); - // parallel_tensor_guid_t layer1 = builder.identity(input0); - // parallel_tensor_guid_t layer2 = builder.identity(layer1); - // parallel_tensor_guid_t layer3 = builder.identity(layer2); - // parallel_tensor_guid_t layer4 = builder.identity(layer3); - // parallel_tensor_guid_t layer5 = builder.identity(layer4); - // parallel_tensor_guid_t layer6 = builder.identity(layer5); - // - // return builder.pcg; - // }(); - // - // MachineMappingResult result = - // get_optimal_machine_mapping(pcg_chain, - // allowed_machine_views1, - // estimator1, - // machine_spec1, - // cached_results1); - // CHECK(result.runtime == 13); - // } - // - // SUBCASE("PCG has multiple chains") { - // ParallelComputationGraph pcg_multiple_chains = [&] { - // ParallelComputationGraphBuilder builder; - // - // ParallelTensorShape input_shape0 = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32, 2}, - // ShardParallelDim{32, 1}, - // ShardParallelDim{16, 1}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1}, - // DiscardCopyDegree{1}, - // }, - // }, - // DataType::FLOAT}; - // - // ParallelTensorShape input_shape1 = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32, 2}, - // ShardParallelDim{16, 1}, - // 
ShardParallelDim{8, 1}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1}, - // DiscardCopyDegree{1}, - // }, - // }, - // DataType::FLOAT}; - // - // parallel_tensor_guid_t input0 = - // builder.create_input_tensor(input_shape0); - // parallel_tensor_guid_t input1 = - // builder.create_input_tensor(input_shape1); - // parallel_tensor_guid_t relu0 = builder.relu(input0); - // parallel_tensor_guid_t relu1 = builder.relu(input1); - // parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); - // - // return builder.pcg; - // }(); - // - // MachineMappingResult result = - // get_optimal_machine_mapping(pcg_multiple_chains, - // allowed_machine_views1, - // estimator1, - // machine_spec1, - // cached_results1); - // CHECK(result.runtime == 5); - // } - // - // SUBCASE("PCG is not sp-izable due to multiple inputs") { - // ParallelComputationGraph pcg_non_sp = [&] { - // ParallelComputationGraphBuilder builder; - // - // ParallelTensorShape input_shape = - // ParallelTensorShape{ParallelTensorDims{ - // FFOrdered{ - // ShardParallelDim{32, 2}, - // ShardParallelDim{16, 1}, - // }, - // ReplicaParallelDimSet{ - // SumDegree{1}, - // DiscardCopyDegree{1}, - // }, - // }, - // DataType::FLOAT}; - // - // parallel_tensor_guid_t input0 = - // builder.create_input_tensor(input_shape); - // parallel_tensor_guid_t dense0 = builder.dense(input0, 8); - // parallel_tensor_guid_t dense1 = builder.dense(input0, 4); - // parallel_tensor_guid_t dense2 = builder.dense(dense1, 8); - // parallel_tensor_guid_t add0 = builder.add(dense0, dense2); - // - // return builder.pcg; - // }(); - // - // // TODO: Handle this case in compiler - // // TODO: separate testcases for this too that actually check the graph - // // manipulation - // if (false) { - // MachineMappingResult result = - // get_optimal_machine_mapping(pcg_non_sp, - // allowed_machine_views1, - // estimator1, - // machine_spec1, - // cached_results1); - // CHECK(bool(result.runtime > 0)); - // CHECK(result.runtime 
== 7); - // } - // } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index 10db2496f6..3717f164ac 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,100 +1,324 @@ #include "compiler/machine_mapping/machine_mapping_result.h" -#include "cost_estimator_for_test.h" -#include "doctest/doctest.h" +#include #include "pcg/machine_view.h" using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("sequential_combine") { - FAIL("TODO"); - // MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - // MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - // parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); - // parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); - // MachineMapping machine_mapping_empty( - // std::unordered_map{}); - // MachineMapping machine_mapping_0({{layer0, machine_view_0}}); - // MachineMapping machine_mapping_1({{layer1, machine_view_1}}); - // MachineMapping combined( - // {{layer0, machine_view_0}, {layer1, machine_view_1}}); - // MachineMappingResult s0(0, machine_mapping_empty); - // MachineMappingResult s1(1, machine_mapping_0); - // MachineMappingResult s2(2, machine_mapping_1); - // - // float comm_cost = 2.0; - // - // MachineMappingResult result0 = sequential_combine(s0, comm_cost, s1); - // CHECK(result0.runtime == 1); - // CHECK(result0.machine_mapping == machine_mapping_0); - // - // MachineMappingResult result1 = sequential_combine(s0, comm_cost, s2); - // CHECK(result1.runtime == 2); - // CHECK(result1.machine_mapping == machine_mapping_1); - // - // MachineMappingResult result2 = sequential_combine(s1, comm_cost, s2); - // CHECK(result2.runtime == 3); - // CHECK(result2.machine_mapping == combined); + TEST_CASE("series_combine") 
{ + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + float pre_cost = 2.0; + MachineMappingResult pre = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + float post_cost = 4.0; + MachineMappingResult post = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/post_cost, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + float comm_cost = 3.0; + + SUBCASE("pre is infeasbile") { + MachineMappingResult result = series_combine(comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("post is infeasbile") { + MachineMappingResult result = series_combine(comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = series_combine(comm_cost, infeasible, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult no_parallel_split_transform = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + 
BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + SUBCASE("parallel_split_transformation = std::nullopt") { + MachineMappingResult result = series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = LthenR") { + MachineMappingResult result = series_combine(comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = RthenL") { + MachineMappingResult result = series_combine(comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } } TEST_CASE("parallel_combine") { - FAIL("TODO"); - // MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - // MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - // parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); - // parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); - // MachineMapping machine_mapping_empty( - // std::unordered_map{}); - // MachineMapping machine_mapping_0({{layer0, machine_view_0}}); - // MachineMapping 
machine_mapping_1({{layer1, machine_view_1}}); - // MachineMapping combined( - // {{layer0, machine_view_0}, {layer1, machine_view_1}}); - // MachineMappingResult s0(0, machine_mapping_empty); - // MachineMappingResult s1(1, machine_mapping_0); - // MachineMappingResult s2(2, machine_mapping_1); - // - // MachineMappingResult result0 = parallel_combine(s0, s1); - // CHECK(result0.runtime == 1); - // CHECK(result0.machine_mapping == machine_mapping_0); - // - // MachineMappingResult result1 = parallel_combine(s0, s2); - // CHECK(result1.runtime == 2); - // CHECK(result1.machine_mapping == machine_mapping_1); - // - // MachineMappingResult result2 = parallel_combine(s1, s2); - // CHECK(result2.runtime == 2); - // CHECK(result2.machine_mapping == combined); + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + MachineMappingResult lhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult rhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasbile") { + MachineMappingResult result = parallel_combine(infeasible, rhs); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasbile") { + MachineMappingResult result = parallel_combine(lhs, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + 
MachineMappingResult result = parallel_combine(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult result = parallel_combine(lhs, rhs); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } } TEST_CASE("minimize_runtime") { - FAIL("TODO"); - // MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - // MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); - // parallel_layer_guid_t layer0 = parallel_layer_guid_t(Node(0)); - // parallel_layer_guid_t layer1 = parallel_layer_guid_t(Node(1)); - // MachineMapping machine_mapping_empty( - // std::unordered_map{}); - // MachineMapping machine_mapping_0({{layer0, machine_view_0}}); - // MachineMapping machine_mapping_1({{layer1, machine_view_1}}); - // MachineMapping combined( - // {{layer0, machine_view_0}, {layer1, machine_view_1}}); - // MachineMappingResult s0(0, machine_mapping_empty); - // MachineMappingResult s1(1, machine_mapping_0); - // MachineMappingResult s2(2, machine_mapping_1); - // - // MachineMappingResult _s0 = s0; - // MachineMappingResult _s1 = s1; - // MachineMappingResult _s2 = s2; - // - // minimize_runtime(_s0, _s1); - // CHECK(_s0 == s0); - // minimize_runtime(_s1, _s2); - // CHECK(_s1 == s1); - // - // minimize_runtime(_s1, _s0); - // CHECK(_s1 == s0); - // - // minimize_runtime(_s2, get_infinity_machine_mapping_result()); - // CHECK(_s2 == s2); + 
MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + MachineMappingResult faster = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult slower = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasbile") { + MachineMappingResult result = minimize_runtime(infeasible, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasible") { + MachineMappingResult result = minimize_runtime(slower, infeasible); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = minimize_runtime(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + SUBCASE("lhs is faster") { + MachineMappingResult result = minimize_runtime(faster, slower); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("rhs is faster") { + MachineMappingResult result = minimize_runtime(slower, faster); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("lhs and rhs have the same speed") { + MachineMappingResult result = minimize_runtime(slower, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + } } } diff --git 
a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 4a565476bd..fa610ff9c2 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -489,11 +489,11 @@ tensor_guid_t ComputationGraphBuilder::gather( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { - throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " + throw mk_runtime_error(fmt::format("Invalid data type for input tensor 2 for Gather: " "{} (should be {} or {})", this->get_shape(input).data_type, DataType::INT32, - DataType::INT64); + DataType::INT64)); } TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h index 5980996c27..73ac61fcf7 100644 --- a/lib/utils/include/utils/containers/get_all_assignments.h +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -15,14 +15,12 @@ namespace FlexFlow { /** - * @note If \p options_per_key is empty, an empty set is returned from the - * function (not a set containing an empty set, as the "empty" assignment is - * not considered a valid assignment) + * @note If \p options_per_key is empty, an set containing a single empty assignment is returned */ template std::unordered_set> get_all_assignments(std::unordered_map> const &options_per_key) { if (options_per_key.empty()) { - return {}; + return {{}}; } std::vector ordered_keys = vector_of(keys(options_per_key)); diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index fedb87413d..88f33b52b6 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -10,8 +10,8 @@ namespace 
FlexFlow { template typename C::value_type get_only(C const &c) { return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error("Encountered container with size {} in get_only", - c.size()); + throw mk_runtime_error(fmt::format("Encountered container with size {} in get_only", + c.size())); }); } diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index 20a8098040..080cbb3611 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -34,12 +34,7 @@ T throw_if_unexpected(tl::expected const &r) { } } -template -std::runtime_error mk_runtime_error(fmt::format_string fmt_str, - T &&...args) { - return std::runtime_error( - fmt::vformat(fmt_str, fmt::make_format_args(args...))); -} +std::runtime_error mk_runtime_error(std::string const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h index 5f5e7d9f64..a9dcb17f0d 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -7,7 +7,7 @@ namespace FlexFlow { template -LeafOnlyBinarySPDecompositionTree make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, +LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, LeafOnlyBinarySPDecompositionTree const &post) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_series_split( @@ -18,7 +18,7 @@ LeafOnlyBinarySPDecompositionTree make_series_split(LeafOnlyBinarySPD } template -LeafOnlyBinarySPDecompositionTree make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, 
+LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, LeafOnlyBinarySPDecompositionTree const &rhs) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_parallel_split( @@ -29,7 +29,7 @@ LeafOnlyBinarySPDecompositionTree make_parallel_split(LeafOnlyBinaryS } template -LeafOnlyBinarySPDecompositionTree make_leaf_node(LeafLabel const &label) { +LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(LeafLabel const &label) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_sp_leaf< LeafOnlyBinarySeriesSplitLabel, diff --git a/lib/utils/include/utils/sequence.h b/lib/utils/include/utils/sequence.h index 6c66949fd8..07e4554299 100644 --- a/lib/utils/include/utils/sequence.h +++ b/lib/utils/include/utils/sequence.h @@ -135,7 +135,7 @@ auto seq_get(F const &f, int i, seq const &s) template auto seq_get(F const &f, int i, seq<> const &) -> decltype(f(std::declval>())) { - throw mk_runtime_error("Failed seq_get for index {}", i); + throw mk_runtime_error(fmt::format("Failed seq_get for index {}", i)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index afc16d4c4b..0296e365a3 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -67,8 +67,8 @@ template std::any get(std::tuple const &t, int idx) { size_t tuple_size = std::tuple_size::value; if (idx < 0 || idx >= tuple_size) { - throw mk_runtime_error( - "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size); + throw mk_runtime_error(fmt::format( + "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size)); } std::any result; visit_tuple(t, tuple_get_visitor{idx, result}); diff --git a/lib/utils/src/utils/exception.cc b/lib/utils/src/utils/exception.cc index 9bbf780fd8..c645f241aa 100644 --- a/lib/utils/src/utils/exception.cc +++ b/lib/utils/src/utils/exception.cc @@ -1 +1,9 @@ #include "utils/exception.h" + +namespace FlexFlow { 
+ +std::runtime_error mk_runtime_error(std::string const &s) { + return std::runtime_error(s); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc index da6cd1d493..72200ec483 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -7,8 +7,7 @@ std::unordered_set get_edges_from_subgraph_to_subgraph(DiGraphView std::unordered_set const &src_subgraph, std::unordered_set const &dst_subgraph) { if (!are_disjoint(src_subgraph, dst_subgraph)) { - throw mk_runtime_error(fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) expected src_subgraph and dst_subgraph to be disjoint, " - "but found src_subgraph={}, dst_subgraph={}", src_subgraph, dst_subgraph)); + throw mk_runtime_error(fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) 
expected src_subgraph and dst_subgraph to be disjoint, but found src_subgraph={}, dst_subgraph={}", src_subgraph, dst_subgraph)); } return g.query_edges(DirectedEdgeQuery{ diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 92a6d0b9eb..61c4f80763 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -23,12 +23,12 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.bigger); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.bigger)); } if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.smaller); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.smaller)); } this->adjacency.at(e.bigger).insert(e.smaller); diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 4aaa657821..2f51762db2 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -12,7 +12,7 @@ BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{ - make_series_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_make_series_split(lhs.raw_tree, rhs.raw_tree), }; } @@ -20,13 +20,13 @@ 
BinarySPDecompositionTree make_parallel_split(BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{ - make_parallel_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } BinarySPDecompositionTree make_leaf_node(Node const &n) { return BinarySPDecompositionTree{ - make_leaf_node(n), + leaf_only_make_leaf_node(n), }; } diff --git a/lib/utils/test/src/utils/containers/get_all_assignments.cc b/lib/utils/test/src/utils/containers/get_all_assignments.cc index 2b2810efe5..17a4e6e749 100644 --- a/lib/utils/test/src/utils/containers/get_all_assignments.cc +++ b/lib/utils/test/src/utils/containers/get_all_assignments.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_map> input = {}; std::unordered_set> result = get_all_assignments(input); - std::unordered_set> correct = {}; + std::unordered_set> correct = {{}}; CHECK(result == correct); } From 0c2ab052033062dd710f49b51e4e489dc7c05519 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 2 Oct 2024 22:06:04 -0700 Subject: [PATCH 21/29] Migrate over to use type-erased binary tree --- .../get_optimal_machine_mapping.h | 3 +- .../machine_mapping/machine_mapping_cache.h | 21 ++-- .../machine_mapping_cache.struct.toml | 22 ++++ .../machine_mapping_state.struct.toml | 2 +- .../get_optimal_machine_mapping.cc | 5 +- .../machine_mapping/machine_mapping_cache.cc | 20 ++-- .../get_optimal_machine_mapping.cc | 7 +- .../machine_mapping/machine_mapping_cache.cc | 91 ---------------- .../get_machine_mapping_problem_tree.cc | 2 +- .../full_binary_tree/find_paths_to_leaf.h | 26 +---- .../include/utils/full_binary_tree/fmt.h | 47 -------- .../utils/full_binary_tree/full_binary_tree.h | 87 --------------- .../full_binary_tree.struct.toml | 20 ++++ .../full_binary_tree_parent_node.struct.toml | 20 ++++ .../full_binary_tree_visitor.struct.toml | 22 ++++ .../full_binary_tree/get_all_leaf_paths.h | 27 +---- 
.../utils/full_binary_tree/get_child.h | 18 ++-- .../utils/full_binary_tree/get_label.h | 15 +++ .../utils/full_binary_tree/get_leaves.h | 25 ++--- .../utils/full_binary_tree/get_left_child.h | 9 +- .../utils/full_binary_tree/get_node_type.h | 17 +-- .../utils/full_binary_tree/get_right_child.h | 9 +- .../full_binary_tree/get_subtree_at_path.h | 30 ++---- .../include/utils/full_binary_tree/hash.h | 26 ----- .../include/utils/full_binary_tree/json.h | 2 +- .../include/utils/full_binary_tree/make.h | 26 +++++ .../raw_full_binary_tree/algorithms.h | 21 ++++ .../raw_full_binary_tree/any_value_type.h | 69 ++++++++++++ .../raw_full_binary_tree/raw_binary_tree.h | 62 +++++++++++ .../include/utils/full_binary_tree/require.h | 21 +++- .../utils/full_binary_tree/transform.h | 2 +- .../include/utils/full_binary_tree/visit.h | 29 +++-- ...ic_binary_parallel_split_label.struct.toml | 16 +++ ...eric_binary_series_split_label.struct.toml | 16 +++ ...c_binary_sp_decomposition_tree.struct.toml | 4 +- ...generic_binary_sp_split_label.variant.toml | 5 + .../get_node_type.h | 20 ++-- .../make.h | 24 ++--- .../require.h | 5 +- .../wrap.h | 22 ++-- .../raw_full_binary_tree/algorithms.cc | 83 ++++++++++++++ .../raw_full_binary_tree/any_value_type.cc | 33 ++++++ .../raw_full_binary_tree/raw_binary_tree.cc | 101 ++++++++++++++++++ .../get_num_tree_nodes.cc | 1 - 44 files changed, 666 insertions(+), 467 deletions(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml delete mode 100644 lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc delete mode 100644 lib/utils/include/utils/full_binary_tree/fmt.h delete mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml create mode 100644 
lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/get_label.h delete mode 100644 lib/utils/include/utils/full_binary_tree/hash.h create mode 100644 lib/utils/include/utils/full_binary_tree/make.h create mode 100644 lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h create mode 100644 lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h create mode 100644 lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml create mode 100644 lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc create mode 100644 lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc create mode 100644 lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 9cc7db4da2..fc33845320 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H -#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include 
"compiler/machine_mapping/machine_mapping_context.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" #include "pcg/machine_specification.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index fc00e7f26a..20cf75e69a 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -1,22 +1,13 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H -#include "compiler/machine_mapping/machine_mapping_state.dtg.h" -#include "compiler/machine_mapping/machine_mapping_result.dtg.h" -#include "utils/optional.h" +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" namespace FlexFlow { -class MachineMappingCache { -public: - MachineMappingCache() = default; - - std::optional load(MachineMappingState const &) const; - void save(MachineMappingState const &, MachineMappingResult const &); - -private: - std::unordered_map cache; -}; +MachineMappingCache empty_machine_mapping_cache(); +std::optional machine_mapping_cache_load(MachineMappingCache const &, MachineMappingState const &); +void machine_mapping_cache_save(MachineMappingCache &, MachineMappingState const &, MachineMappingResult const &); } // namespace FlexFlow 
diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml new file mode 100644 index 0000000000..a76ff26eb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "MachineMappingCache" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/machine_mapping_result.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml index 4d4a29eac7..1346f6ebe7 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -17,7 +17,7 @@ name = "problem_tree" type = "::FlexFlow::MachineMappingProblemTree" [[fields]] -name = "resource" +name = "resources" type = "::FlexFlow::MachineSpecification" [[fields]] diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index b8b2f1d19c..0adf43681e 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/get_machine_resource_splits.h" +#include 
"compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h" @@ -39,7 +40,7 @@ MachineMappingResult get_optimal_machine_mapping( { std::optional cached_result = - result_cache.load(state); + machine_mapping_cache_load(result_cache, state); if (cached_result) { return cached_result.value(); } @@ -67,7 +68,7 @@ MachineMappingResult get_optimal_machine_mapping( }, }); - result_cache.save(state, result); + machine_mapping_cache_save(result_cache, state, result); return result; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc index 76d7fea7ff..c78f7fbf56 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -1,17 +1,23 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" #include "utils/containers/try_at.h" +#include "utils/containers/contains_key.h" namespace FlexFlow { -std::optional - MachineMappingCache::load(MachineMappingState const &state) const { - return try_at(this->cache, state); +MachineMappingCache empty_machine_mapping_cache() { + return MachineMappingCache{{}}; } -void MachineMappingCache::save(MachineMappingState const &state, - MachineMappingResult const &result) { - assert(!contains_key(cache, state)); - cache.emplace(state, result); +std::optional machine_mapping_cache_load(MachineMappingCache const &cache, MachineMappingState const &k) { + return try_at(cache.raw_map, k); +} + +void machine_mapping_cache_save(MachineMappingCache &cache, MachineMappingState const &k, MachineMappingResult const &v) { + if (contains_key(cache.raw_map, k)) { + throw 
mk_runtime_error(fmt::format("machine_mapping_cache_save expected key to not already exist, but received existing key {}", k)); + } + + cache.raw_map.emplace(k, v); } } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index e813f5efff..b33e0e344d 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -2,6 +2,7 @@ #include "./cost_estimator_for_test.h" #include #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" @@ -109,7 +110,7 @@ TEST_SUITE(FF_TEST_SUITE) { allowed_machine_views1, }; - MachineMappingCache cache; + MachineMappingCache cache = empty_machine_mapping_cache(); SUBCASE("single layer") { MachineMappingProblemTree problem_tree = @@ -206,9 +207,5 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - - SUBCASE("multiple edges across split") { - FAIL("TODO"); - } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc deleted file mode 100644 index fc521e110c..0000000000 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_cache.cc +++ /dev/null @@ -1,91 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_cache.h" -#include "./cost_estimator_for_test.h" -#include "doctest/doctest.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" - -using namespace FlexFlow; - 
-TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MachineMappingCache") { - ParallelComputationGraph pcg = [&] { - ParallelComputationGraphBuilder builder; - - ParallelTensorShape input_shape0 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{32, 1}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - ParallelTensorShape input_shape1 = - ParallelTensorShape{ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{32, 2}, - ShardParallelDim{16, 1}, - ShardParallelDim{8, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT}; - - parallel_tensor_guid_t input0 = builder.create_input_tensor(input_shape0); - parallel_tensor_guid_t input1 = builder.create_input_tensor(input_shape1); - parallel_tensor_guid_t relu0 = builder.relu(input0); - parallel_tensor_guid_t relu1 = builder.relu(input1); - parallel_tensor_guid_t matmul0 = builder.batch_matmul(relu0, relu1); - - return builder.pcg; - }(); - - FAIL("TODO"); - // SerialParallelDecomposition subgraph0 = - // get_serial_parallel_decomposition(pcg.raw_graph).value(); - // auto [subgraph1, subgraph2] = - // split_sp_decomposition(subgraph0.get()); - // - // MachineSpecification machine_spec(1, 1, 1, 1, 1); - // MachineMappingState state0(subgraph0, machine_spec, {}); - // MachineMappingState state1(subgraph1, machine_spec, {}); - // MachineMappingState state2(subgraph2, machine_spec, {}); - // - // MachineMappingResult result0( - // 2, - // MachineMapping( - // std::unordered_map{})); - // MachineMappingResult result1( - // 1, - // MachineMapping( - // std::unordered_map{})); - // MachineMappingResult result2( - // 1, - // MachineMapping( - // std::unordered_map{})); - // - // MachineMappingCache cache; - // - // cache.save(state0, result0); - // CHECK(cache.load(state0).value() == result0); - // CHECK(!cache.load(state1)); - // CHECK(!cache.load(state2)); - // 
- // cache.save(state1, result1); - // CHECK(cache.load(state0).value() == result0); - // CHECK(cache.load(state1).value() == result1); - // CHECK(!cache.load(state2)); - // - // cache.save(state2, result2); - // CHECK(cache.load(state0).value() == result0); - // CHECK(cache.load(state1).value() == result1); - // CHECK(cache.load(state2).value() == result2); - } -} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 0a107c2682..c828a9c164 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -136,7 +136,7 @@ TEST_SUITE(FF_TEST_SUITE) { UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); PCGBinarySPDecomposition sp_decomposition = \ - make_pcg_series_split( + make_pcg_parallel_split( make_pcg_leaf_node(input1_layer), make_pcg_leaf_node(input2_layer)); diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 4410f06e67..833013d6f6 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -2,37 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/visit.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" #include -#include "utils/overload.h" -#include "utils/containers/transform.h" -#include "utils/containers/set_union.h" namespace FlexFlow { template 
std::unordered_set find_paths_to_leaf(FullBinaryTree const &tree, LeafLabel const &leaf) { - return visit>( - tree, - overload { - [&](LeafLabel const &l) -> std::unordered_set { - if (l == leaf) { - return {binary_tree_root_path()}; - } else { - return {}; - } - }, - [&](FullBinaryTreeParentNode const &parent) { - return set_union( - transform(find_paths_to_leaf(get_left_child(parent), leaf), - nest_inside_left_child), - transform(find_paths_to_leaf(get_right_child(parent), leaf), - nest_inside_right_child)); - } - }); + return find_paths_to_leaf(tree.raw_tree, make_any_value_type(leaf)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h deleted file mode 100644 index ff4f54e95d..0000000000 --- a/lib/utils/include/utils/full_binary_tree/fmt.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H - -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_left_child.h" -#include "utils/full_binary_tree/get_right_child.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" -#include - -namespace FlexFlow { - -template -std::string format_as(FullBinaryTreeParentNode const &t) { - return fmt::format("<{} ({} {})>", - t.label, - get_left_child(t), - get_right_child(t)); -} - -template -std::string format_as(FullBinaryTree const &t) { - return visit( - t, - overload{ - [](FullBinaryTreeParentNode const &parent) { - return fmt::to_string(parent); - }, - [](LeafLabel const &leaf) { - return fmt::format("{}", leaf); - }, - }); -} - -template -std::ostream &operator<<(std::ostream &s, FullBinaryTreeParentNode const &t) { - return (s << fmt::to_string(t)); -} - -template -std::ostream &operator<<(std::ostream &s, FullBinaryTree const &t) { - return (s << fmt::to_string(t)); -} - -} // namespace FlexFlow - -#endif diff --git 
a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h deleted file mode 100644 index f90ffb88c4..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H - -#include -#include -#include - -namespace FlexFlow { - -template -struct FullBinaryTree; - -template -struct FullBinaryTreeParentNode { - explicit FullBinaryTreeParentNode( - ParentLabel const &label, - FullBinaryTree const &lhs, - FullBinaryTree const &rhs) - : label(label), - left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) - { } - - FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; - - bool operator==(FullBinaryTreeParentNode const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(FullBinaryTreeParentNode const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(FullBinaryTreeParentNode const &other) const { - return this->tie() < other.tie(); - } -public: - ParentLabel label; - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; -private: - std::tuple const &, - FullBinaryTree const &> - tie() const { - return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct FullBinaryTree { -public: - FullBinaryTree() = delete; - explicit FullBinaryTree(FullBinaryTreeParentNode const &t) - : root{t} {} - - explicit FullBinaryTree(LeafLabel const &t) - : root{t} {} - - bool operator==(FullBinaryTree const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(FullBinaryTree const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(FullBinaryTree const &other) const { - return this->tie() < other.tie(); - } 
-public: - std::variant, LeafLabel> root; -private: - std::tuple tie() const { - return std::tie(this->root); - } - - friend std::hash; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml new file mode 100644 index 0000000000..aa9a1d8574 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "FullBinaryTree" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "ParentLabel", + "LeafLabel", +] + +includes = [ + "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::RawBinaryTree" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml new file mode 100644 index 0000000000..277405a23c --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeParentNode" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "ParentLabel", + "LeafLabel", +] + +includes = [ + "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::RawBinaryTree" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml new file mode 100644 index 0000000000..0849ba2683 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeVisitor" +features = [] + +template_params = [ + "Result", + "ParentLabel", + "LeafLabel", +] + +includes = [ + "", + 
"utils/full_binary_tree/full_binary_tree_parent_node.dtg.h", +] + +[[fields]] +name = "parent_func" +type = "std::function const &)>" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 926cc0ea9c..4076447f57 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -2,36 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/visit.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" #include -#include "utils/overload.h" -#include "utils/containers/set_union.h" -#include "utils/containers/transform.h" +#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" namespace FlexFlow { template std::unordered_set get_all_leaf_paths(FullBinaryTree const &tree) { - return visit> - (tree, - overload { - [](LeafLabel const &) { - return std::unordered_set{binary_tree_root_path()}; - }, - [](FullBinaryTreeParentNode const &parent) { - return set_union( - transform(get_all_leaf_paths(get_left_child(parent)), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(get_all_leaf_paths(get_right_child(parent)), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - } - }); + return get_all_leaf_paths(tree.raw_tree); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index 8f9f76f49d..675e385ca3 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -1,24 +1,18 @@ #ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_left_child.h" -#include "utils/full_binary_tree/get_right_child.h" -#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" namespace FlexFlow { template FullBinaryTree get_child(FullBinaryTreeParentNode const &t, BinaryTreePathEntry const &e) { - switch (e) { - case BinaryTreePathEntry::LEFT_CHILD: - return get_left_child(t); - case BinaryTreePathEntry::RIGHT_CHILD: - return get_right_child(t); - default: - throw mk_runtime_error(fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); - } + return FullBinaryTreeParentNode{ + get_child(t.raw_tree, e), + }; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_label.h b/lib/utils/include/utils/full_binary_tree/get_label.h new file mode 100644 index 0000000000..9f0099e609 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_label.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H + +#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" + +namespace FlexFlow { + +template +ParentLabel get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &p) { + return p.raw_tree.label.template get(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index c58a850a6d..41fea3c5c2 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -1,28 +1,15 @@ 
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" -#include -#include "utils/containers/multiset_union.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" namespace FlexFlow { -template -std::unordered_multiset - get_leaves(FullBinaryTree const &t) { - return visit>( - t, - overload { - [](FullBinaryTreeParentNode const &parent) { - return multiset_union(get_leaves(get_left_child(parent)), - get_leaves(get_right_child(parent))); - }, - [](ChildLabel const &leaf) { - return std::unordered_multiset{leaf}; - } - }); +template +std::unordered_multiset + get_leaves(FullBinaryTree const &t) { + return transform(get_leaves(t.raw_tree), [](any_value_type const &v) { return v.get(); }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h index 163503abfd..394b9042fe 100644 --- a/lib/utils/include/utils/full_binary_tree/get_left_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_left_child.h @@ -1,13 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" namespace FlexFlow { template -FullBinaryTree const &get_left_child(FullBinaryTreeParentNode const &t) { - return *t.left_child_ptr; +FullBinaryTree get_left_child(FullBinaryTreeParentNode const &t) { + return FullBinaryTree{ + t.raw_tree.left_child(), + }; } } // namespace FlexFlow diff --git 
a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h index e1cbe909d5..5d2c613101 100644 --- a/lib/utils/include/utils/full_binary_tree/get_node_type.h +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -1,25 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H -#include "utils/overload.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" namespace FlexFlow { template FullBinaryTreeNodeType get_node_type(FullBinaryTree const &t) { - return visit( - t, - overload { - [](FullBinaryTreeParentNode const &) { - return FullBinaryTreeNodeType::PARENT; - }, - [](LeafLabel const &) { - return FullBinaryTreeNodeType::LEAF; - } - }); + return get_node_type(t.raw_tree); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h index e40f2024a1..957ddbede8 100644 --- a/lib/utils/include/utils/full_binary_tree/get_right_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_right_child.h @@ -1,13 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" namespace FlexFlow { template -FullBinaryTree const &get_right_child(FullBinaryTreeParentNode const &t) { - return *t.right_child_ptr; +FullBinaryTree get_right_child(FullBinaryTreeParentNode const &t) { + return 
FullBinaryTree{ + t.raw_tree.right_child(), + }; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index 6909d9e1ef..59d24b6aad 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -2,35 +2,19 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_child.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" -#include +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" namespace FlexFlow { template std::optional> get_subtree_at_path(FullBinaryTree const &t, BinaryTreePath const &p) { - if (p == binary_tree_root_path()) { - return t; - } - - return visit>>( - t, - overload { - [&](FullBinaryTreeParentNode const &parent) { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); - - return get_subtree_at_path(get_child(parent, curr), rest); - }, - [&](LeafLabel const &leaf) { - return std::nullopt; - } - }); + return transform(get_subtree_at_path(t.raw_tree, p), + [](RawBinaryTree const &raw) { + return FullBinaryTree{raw}; + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h deleted file mode 100644 index a29836f972..0000000000 --- a/lib/utils/include/utils/full_binary_tree/hash.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H -#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H - -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace std { - -template -struct hash<::FlexFlow::FullBinaryTreeParentNode> { - size_t operator()(::FlexFlow::FullBinaryTreeParentNode const &t) const { - return get_std_hash(t.tie()); - } -}; - -template -struct hash<::FlexFlow::FullBinaryTree> { - size_t operator()(::FlexFlow::FullBinaryTree const &t) const { - return get_std_hash(t.tie()); - } -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/json.h b/lib/utils/include/utils/full_binary_tree/json.h index 0d830890dc..585c05813e 100644 --- a/lib/utils/include/utils/full_binary_tree/json.h +++ b/lib/utils/include/utils/full_binary_tree/json.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H #include "utils/exception.h" -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" #include "utils/full_binary_tree/get_left_child.h" #include "utils/full_binary_tree/get_right_child.h" #include "utils/full_binary_tree/visit.h" diff --git a/lib/utils/include/utils/full_binary_tree/make.h b/lib/utils/include/utils/full_binary_tree/make.h new file mode 100644 index 0000000000..ac458f0f4d --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/make.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H + +#include "utils/full_binary_tree/full_binary_tree.dtg.h" + +namespace FlexFlow { + +template +FullBinaryTree make_full_binary_tree_parent(ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) { + return FullBinaryTree{ + raw_binary_tree_make_parent(make_any_value_type(label), lhs.raw_tree, rhs.raw_tree), + }; +} + +template +FullBinaryTree make_full_binary_tree_leaf(LeafLabel const 
&label) { + return FullBinaryTree{ + raw_binary_tree_make_leaf(make_any_value_type(label)), + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h new file mode 100644 index 0000000000..6d0d77caa9 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ALGORITHMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ALGORITHMS_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" +#include "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h" +#include + +namespace FlexFlow { + +RawBinaryTree get_child(RawBinaryTree const &, BinaryTreePathEntry const &); +std::unordered_set get_all_leaf_paths(RawBinaryTree const &); +std::unordered_set find_paths_to_leaf(RawBinaryTree const &, any_value_type const &leaf); +std::unordered_multiset get_leaves(RawBinaryTree const &); +FullBinaryTreeNodeType get_node_type(RawBinaryTree const &); +std::optional get_subtree_at_path(RawBinaryTree const &, BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h new file mode 100644 index 0000000000..8cd7d62101 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h @@ -0,0 +1,69 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H + +#include +#include +#include 
+#include + +namespace FlexFlow { + +struct any_value_type { +public: + any_value_type(std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string); + + bool operator==(any_value_type const &other) const; + bool operator!=(any_value_type const &other) const; + + template + T get() const { + return std::any_cast(value); + } + + friend std::string format_as(any_value_type const &); +private: + std::any value; + std::function eq; + std::function neq; + std::function hash; + std::function to_string; + + friend std::hash; +}; + + +template +any_value_type make_any_value_type(T const &t) { + return any_value_type{ + std::make_any(t), + [](std::any const &l, std::any const &r) { + return std::any_cast(l) == std::any_cast(r); + }, + [](std::any const &l, std::any const &r) { + return std::any_cast(l) != std::any_cast(r); + }, + [](std::any const &v) { + return std::hash{}(std::any_cast(v)); + }, + [](std::any const &v) { + return fmt::to_string(std::any_cast(v)); + }, + }; +} + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::any_value_type> { + size_t operator()(::FlexFlow::any_value_type const &) const; +}; + +} + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h new file mode 100644 index 0000000000..0bebe12109 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h @@ -0,0 +1,62 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_RAW_BINARY_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_RAW_BINARY_TREE_H + +#include +#include +#include "utils/full_binary_tree/raw_full_binary_tree/any_value_type.h" +#include + +namespace FlexFlow { + +struct RawBinaryTree { + explicit RawBinaryTree( + any_value_type const 
&label, + RawBinaryTree const &lhs, + RawBinaryTree const &rhs); + explicit RawBinaryTree( + any_value_type const &label); + + RawBinaryTree(RawBinaryTree const &) = default; + + bool operator==(RawBinaryTree const &) const; + bool operator!=(RawBinaryTree const &) const; + + RawBinaryTree const &left_child() const; + RawBinaryTree const &right_child() const; + + bool is_leaf() const; +public: + any_value_type label; + std::shared_ptr left_child_ptr; + std::shared_ptr right_child_ptr; +private: + std::tuple, + std::optional> + value_tie() const; + std::tuple const &, + std::shared_ptr const &> + ptr_tie() const; + + friend std::hash; +}; + +std::string format_as(RawBinaryTree const &); +std::ostream &operator<<(std::ostream &, RawBinaryTree const &); + +RawBinaryTree raw_binary_tree_make_leaf(any_value_type const &label); +RawBinaryTree raw_binary_tree_make_parent(any_value_type const &label, RawBinaryTree const &lhs, RawBinaryTree const &rhs); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::RawBinaryTree> { + size_t operator()(::FlexFlow::RawBinaryTree const &) const; +}; + +} + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h index 0e5ad4914a..f897908c86 100644 --- a/lib/utils/include/utils/full_binary_tree/require.h +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -1,18 +1,29 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" namespace FlexFlow { template -FullBinaryTreeParentNode const &require_parent_node(FullBinaryTree const &t) { - return std::get>(t.root); +FullBinaryTreeParentNode require_parent_node(FullBinaryTree const &t) { + if (t.raw_tree.is_leaf()) { + throw 
mk_runtime_error(fmt::format("require_parent_node called on leaf node {}", t)); + } + + return FullBinaryTreeParentNode{ + t.raw_tree, + }; } template -LeafLabel const &require_leaf(FullBinaryTree const &t) { - return std::get(t.root); +LeafLabel require_leaf(FullBinaryTree const &t) { + if (!t.raw_tree.is_leaf()) { + throw mk_runtime_error(fmt::format("require_leaf called on non-leaf node {}", t)); + } + + return t.raw_tree.label.template get(); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/transform.h b/lib/utils/include/utils/full_binary_tree/transform.h index 3fef8efd18..52ed07f7ba 100644 --- a/lib/utils/include/utils/full_binary_tree/transform.h +++ b/lib/utils/include/utils/full_binary_tree/transform.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" #include "utils/full_binary_tree/get_left_child.h" #include "utils/full_binary_tree/get_right_child.h" #include "utils/overload.h" diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 978eba4d74..ea5729bd6c 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -1,22 +1,31 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_node_type.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/exception.h" namespace FlexFlow { template Result visit(FullBinaryTree const &tt, F f) { - if (std::holds_alternative>(tt.root)) { - Result result = f(std::get>(tt.root)); - return result; - } else if (std::holds_alternative(tt.root)) { - Result 
result = f(std::get(tt.root)); - return result; - } else { - throw mk_runtime_error( - "Unexpected case in visit(FullBinaryTree)"); + auto visitor = FullBinaryTreeVisitor{ + f, f + }; + + return visit(tt, visitor); +} + +template +Result visit(FullBinaryTree const &t, FullBinaryTreeVisitor const &v) { + FullBinaryTreeNodeType node_type = get_node_type(t); + switch (node_type) { + case FullBinaryTreeNodeType::PARENT: + return v.parent_func(require_parent_node(t)); + case FullBinaryTreeNodeType::LEAF: + return v.leaf_func(require_leaf(t)); + default: + throw mk_runtime_error(fmt::format("Unhandled FullBinaryTreeNodeType value: {}", node_type)); } } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml new file mode 100644 index 0000000000..d187b7c93a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "GenericBinaryParallelSplitLabel" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +template_params = [ + "ParallelSplitLabel" +] + +[[fields]] +name = "raw_label" +type = "ParallelSplitLabel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml new file mode 100644 index 0000000000..74e00ada81 --- /dev/null +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "GenericBinarySeriesSplitLabel" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +template_params = [ + "SeriesSplitLabel" +] + +[[fields]] +name = "raw_label" +type = "SeriesSplitLabel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml index 82f93a9197..9734912f35 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml @@ -14,13 +14,11 @@ template_params = [ ] includes = [ - "utils/full_binary_tree/full_binary_tree.h", + "utils/full_binary_tree/full_binary_tree.dtg.h", "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h", ] src_includes = [ - "utils/full_binary_tree/hash.h", - "utils/full_binary_tree/fmt.h", "utils/full_binary_tree/json.h", ] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml index 17920c180e..c50a7b878b 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml @@ -12,6 +12,11 @@ template_params = [ "ParallelSplitLabel", ] +# includes = [ +# "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.dtg.h", +# "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.dtg.h", +# ] + [[values]] type = "SeriesSplitLabel" key = "series" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index 9c3ad0daeb..c46be1c651 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -12,16 +12,16 @@ namespace FlexFlow { template SPDecompositionTreeNodeType get_node_type(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt.raw_tree, - overload { - [](LeafLabel const &) { - return SPDecompositionTreeNodeType::NODE; - }, - [](FullBinaryTreeParentNode, LeafLabel> const &parent) { - return get_node_type(parent.label); - }, - }); + auto visitor = FullBinaryTreeVisitor, LeafLabel>{ + [](FullBinaryTreeParentNode, LeafLabel> const &parent) { + return get_node_type(get_full_binary_tree_parent_label(parent)); + }, + [](LeafLabel const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }; + + return visit(tt.raw_tree, visitor); } } // namespace 
FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index 2ae89462bb..20ea7e744e 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/full_binary_tree/make.h" namespace FlexFlow { @@ -11,13 +12,10 @@ GenericBinarySPDecompositionTree make_gen GenericBinarySPDecompositionTree const &lhs, GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel>{ - FullBinaryTreeParentNode, LeafLabel>{ + make_full_binary_tree_parent( GenericBinarySPSplitLabel{label}, lhs.raw_tree, - rhs.raw_tree, - } - } + rhs.raw_tree), }; } @@ -27,22 +25,18 @@ GenericBinarySPDecompositionTree make_gen GenericBinarySPDecompositionTree const &lhs, GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel>{ - FullBinaryTreeParentNode, LeafLabel>{ - GenericBinarySPSplitLabel{label}, - lhs.raw_tree, - rhs.raw_tree, - } - } + make_full_binary_tree_parent( + GenericBinarySPSplitLabel{label}, + lhs.raw_tree, + rhs.raw_tree), }; } template GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(LeafLabel const &leaf) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel>{ - leaf, - }, + make_full_binary_tree_leaf>( + leaf), }; } diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index 9a93ae8d6a..b8b18c4125 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -4,6 +4,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" #include "utils/full_binary_tree/require.h" +#include "utils/full_binary_tree/get_label.h" namespace FlexFlow { @@ -13,7 +14,7 @@ GenericBinarySeriesSplit FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); return GenericBinarySeriesSplit{ - /*label=*/parent.label.template get(), + /*label=*/get_full_binary_tree_parent_label(parent).template get(), /*pre=*/GenericBinarySPDecompositionTree{ get_left_child(parent), }, @@ -29,7 +30,7 @@ GenericBinaryParallelSplit FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); return GenericBinaryParallelSplit{ - /*label=*/parent.label.template get(), + /*label=*/get_full_binary_tree_parent_label(parent).template get(), /*lhs=*/GenericBinarySPDecompositionTree{ get_left_child(parent), }, diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h index ec0a45f83a..0b8e53bab6 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h @@ -11,13 +11,10 @@ template GenericBinarySPDecompositionTree wrap_series_split(GenericBinarySeriesSplit const &series_split) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel> { - FullBinaryTreeParentNode, LeafLabel> { - /*label=*/GenericBinarySPSplitLabel{series_split.label}, - /*lhs=*/series_split.pre.raw_tree, - /*rhs=*/series_split.post.raw_tree, - }, - }, + make_full_binary_tree_parent( + /*label=*/GenericBinarySPSplitLabel{series_split.label}, + /*lhs=*/series_split.pre.raw_tree, + /*rhs=*/series_split.post.raw_tree), }; } @@ -25,13 +22,10 @@ template GenericBinarySPDecompositionTree wrap_parallel_split(GenericBinaryParallelSplit const ¶llel_split) { return GenericBinarySPDecompositionTree{ - FullBinaryTree, LeafLabel> { - FullBinaryTreeParentNode, LeafLabel> { - /*label=*/GenericBinarySPSplitLabel{parallel_split.label}, - /*lhs=*/parallel_split.lhs.raw_tree, - /*rhs=*/parallel_split.rhs.raw_tree, - }, - }, + make_full_binary_tree_parent( + /*label=*/GenericBinarySPSplitLabel{parallel_split.label}, + /*lhs=*/parallel_split.lhs.raw_tree, + /*rhs=*/parallel_split.rhs.raw_tree), }; } diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc new file mode 100644 index 0000000000..bc833f95d4 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc @@ -0,0 +1,83 @@ +#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/transform.h" +#include "utils/containers/set_union.h" +#include "utils/containers/multiset_union.h" + +namespace FlexFlow { + +RawBinaryTree 
get_child(RawBinaryTree const &t, BinaryTreePathEntry const &e) { + if (e == BinaryTreePathEntry::LEFT_CHILD) { + return t.left_child(); + } else { + assert (e == BinaryTreePathEntry::RIGHT_CHILD); + return t.right_child(); + } +} + +std::unordered_set get_all_leaf_paths(RawBinaryTree const &t) { + if (t.is_leaf()) { + return {binary_tree_root_path()}; + } else { + return set_union( + transform(get_all_leaf_paths(t.left_child()), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(t.right_child()), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + } +} + +std::unordered_set find_paths_to_leaf(RawBinaryTree const &t, any_value_type const &leaf) { + if (t.is_leaf()) { + if (t.label == leaf) { + return {binary_tree_root_path()}; + } else { + return {}; + } + } else { + return set_union( + transform(find_paths_to_leaf(t.left_child(), leaf), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(find_paths_to_leaf(t.right_child(), leaf), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + } +} + +std::unordered_multiset get_leaves(RawBinaryTree const &t) { + if (t.is_leaf()) { + return {t.label}; + } else { + return multiset_union(get_leaves(t.left_child()), get_leaves(t.right_child())); + } +} + +FullBinaryTreeNodeType get_node_type(RawBinaryTree const &t) { + if (t.is_leaf()) { + return FullBinaryTreeNodeType::LEAF; + } else { + return FullBinaryTreeNodeType::PARENT; + } +} + +std::optional get_subtree_at_path(RawBinaryTree const &t, BinaryTreePath const &p) { + if (p == binary_tree_root_path()) { + return t; + } else if (t.is_leaf()) { + return std::nullopt; + } else { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(t, curr), rest); + } +} + +} // namespace FlexFlow diff --git 
a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc new file mode 100644 index 0000000000..d54796ae49 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc @@ -0,0 +1,33 @@ +#include "utils/full_binary_tree/raw_full_binary_tree/any_value_type.h" + +namespace FlexFlow { + +any_value_type::any_value_type(std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string) + : value(value), eq(eq), neq(neq), hash(hash), to_string(to_string) +{} + +bool any_value_type::operator==(any_value_type const &other) const { + return this->eq(this->value, other.value); +} + +bool any_value_type::operator!=(any_value_type const &other) const { + return this->neq(this->value, other.value); +} + +std::string format_as(any_value_type const &v) { + return v.to_string(v.value); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::any_value_type>::operator()(::FlexFlow::any_value_type const &v) const { + return v.hash(v); +} + +} // namespace std diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc new file mode 100644 index 0000000000..d432d32eb9 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc @@ -0,0 +1,101 @@ +#include "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace FlexFlow { + +RawBinaryTree::RawBinaryTree( + any_value_type const &label, + RawBinaryTree const &lhs, + RawBinaryTree const &rhs) + : label(label), + left_child_ptr(std::make_shared(lhs)), + right_child_ptr(std::make_shared(rhs)) +{ } + +RawBinaryTree::RawBinaryTree( + any_value_type const &label) + : label(label), 
left_child_ptr(nullptr), right_child_ptr(nullptr) +{ } + +bool RawBinaryTree::operator==(RawBinaryTree const &other) const { + if (this->ptr_tie() == other.ptr_tie()) { + return true; + } + + return (this->value_tie() == other.value_tie()); +} + +bool RawBinaryTree::operator!=(RawBinaryTree const &other) const { + if (this->ptr_tie() == other.ptr_tie()) { + return false; + } + + return (this->value_tie() != other.value_tie()); +} + +RawBinaryTree const &RawBinaryTree::left_child() const { + return *this->left_child_ptr; +} + +RawBinaryTree const &RawBinaryTree::right_child() const { + return *this->right_child_ptr; +} + +bool RawBinaryTree::is_leaf() const { + return this->left_child_ptr == nullptr && this->right_child_ptr == nullptr; +} + +std::tuple, + std::optional> + RawBinaryTree::value_tie() const { + + auto ptr_to_optional = [](std::shared_ptr const &ptr) + -> std::optional { + if (ptr == nullptr) { + return std::nullopt; + } else { + return *ptr; + } + }; + + return {this->label, ptr_to_optional(this->left_child_ptr), ptr_to_optional(this->right_child_ptr)}; +} + +std::tuple const &, + std::shared_ptr const &> + RawBinaryTree::ptr_tie() const { + return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); +} + +std::string format_as(RawBinaryTree const &t) { + if (t.is_leaf()) { + return fmt::to_string(t.label); + } else { + return fmt::format("({} {} {})", t.label, t.left_child(), t.right_child()); + } +} + +std::ostream &operator<<(std::ostream &s, RawBinaryTree const &t) { + return (s << fmt::to_string(t)); +} + +RawBinaryTree raw_binary_tree_make_leaf(any_value_type const &label) { + return RawBinaryTree{label}; +} + +RawBinaryTree raw_binary_tree_make_parent(any_value_type const &label, RawBinaryTree const &lhs, RawBinaryTree const &rhs) { + return RawBinaryTree{label, lhs, rhs}; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::RawBinaryTree>::operator()(::FlexFlow::RawBinaryTree const &t) const { + return 
::FlexFlow::get_std_hash(t.value_tie()); +} + +} // namespace std diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 3a7d13c2a8..0d36ccbe92 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -6,7 +6,6 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { - FAIL("TODO"); // SUBCASE("leaf") { // GenericBinarySPDecompositionTree input = // make_generic_binary_sp_leaf(5); From e4073bc28e4c54c649a2f1b7bf73a6a37fbab266 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 3 Oct 2024 13:48:25 -0700 Subject: [PATCH 22/29] Move back to templated FullBinaryTree --- .../machine_mapping_problem_tree.cc | 6 +- ...mputation_graph_binary_sp_decomposition.cc | 2 +- .../pcg_binary_sp_decomposition.cc | 6 +- .../any_value_type.h | 6 +- lib/utils/include/utils/fmt/monostate.h | 32 ++++++ .../full_binary_tree/find_paths_to_leaf.h | 26 ++++- .../include/utils/full_binary_tree/fmt.h | 47 ++++++++ .../utils/full_binary_tree/full_binary_tree.h | 102 ++++++++++++++++++ .../full_binary_tree.struct.toml | 20 ---- .../full_binary_tree_parent_node.struct.toml | 20 ---- .../full_binary_tree_visitor.struct.toml | 2 +- .../full_binary_tree/get_all_leaf_paths.h | 27 ++++- .../utils/full_binary_tree/get_child.h | 20 ++-- .../utils/full_binary_tree/get_label.h | 4 +- .../utils/full_binary_tree/get_leaves.h | 25 +++-- .../utils/full_binary_tree/get_left_child.h | 9 +- .../utils/full_binary_tree/get_node_type.h | 13 ++- 
.../utils/full_binary_tree/get_right_child.h | 9 +- .../full_binary_tree/get_subtree_at_path.h | 30 ++++-- .../include/utils/full_binary_tree/hash.h | 26 +++++ .../include/utils/full_binary_tree/json.h | 2 +- .../include/utils/full_binary_tree/make.h | 10 +- .../raw_full_binary_tree/algorithms.h | 21 ---- .../raw_full_binary_tree/raw_binary_tree.h | 62 ----------- .../include/utils/full_binary_tree/require.h | 21 +--- .../include/utils/full_binary_tree/visit.h | 5 +- ...c_binary_sp_decomposition_tree.struct.toml | 4 +- .../generic_binary_sp_split_label.h | 32 +++++- ...generic_binary_sp_split_label.variant.toml | 12 +-- .../get_node_type.h | 1 + .../make.h | 5 +- .../require.h | 17 +-- .../visit.h | 6 +- .../wrap.h | 6 +- ...ly_binary_parallel_split_label.struct.toml | 12 --- ...only_binary_series_split_label.struct.toml | 12 --- ...y_binary_sp_decomposition_tree.struct.toml | 9 +- .../make.h | 8 +- .../require.h | 20 ++-- .../transform.h | 12 +-- .../wrap.h | 12 +-- .../any_value_type.cc | 2 +- lib/utils/src/utils/fmt/monostate.cc | 9 ++ lib/utils/src/utils/full_binary_tree/fmt.cc | 10 ++ .../src/utils/full_binary_tree/get_label.cc | 8 ++ .../utils/full_binary_tree/get_node_type.cc | 7 ++ lib/utils/src/utils/full_binary_tree/make.cc | 12 +++ .../raw_full_binary_tree/algorithms.cc | 83 -------------- .../raw_full_binary_tree/raw_binary_tree.cc | 101 ----------------- .../src/utils/full_binary_tree/require.cc | 11 ++ lib/utils/src/utils/full_binary_tree/visit.cc | 8 ++ .../binary_sp_decomposition_tree.cc | 6 +- .../find_paths_to_leaf.cc | 9 ++ .../generic_binary_sp_split_label.cc | 16 +++ .../get_all_leaf_paths.cc | 8 ++ .../get_leaves.cc | 12 +++ .../get_left_child.cc | 11 ++ .../get_node_type.cc | 8 ++ .../get_num_tree_nodes.cc | 11 ++ .../get_right_child.cc | 11 ++ .../get_subtree_at_path.cc | 10 ++ .../is.cc | 11 ++ .../is_binary_sp_tree_left_associative.cc | 8 ++ .../is_binary_sp_tree_right_associative.cc | 8 ++ .../make.cc | 17 +++ .../require.cc | 13 +++ 
.../transform.cc | 9 ++ .../wrap.cc | 12 +++ .../get_leaves.cc | 8 ++ .../get_node_type.cc | 8 ++ .../is_binary_sp_tree_left_associative.cc | 9 ++ .../is_binary_sp_tree_right_associative.cc | 8 ++ .../leaf_only_binary_parallel_split.cc | 10 ++ .../leaf_only_binary_series_split.cc | 10 ++ .../make.cc | 14 +++ .../require.cc | 9 ++ .../transform.cc | 16 +++ .../wrap.cc | 10 ++ 78 files changed, 792 insertions(+), 462 deletions(-) rename lib/utils/include/utils/{full_binary_tree/raw_full_binary_tree => any_value_type}/any_value_type.h (89%) create mode 100644 lib/utils/include/utils/fmt/monostate.h create mode 100644 lib/utils/include/utils/full_binary_tree/fmt.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.h delete mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml delete mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/hash.h delete mode 100644 lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h delete mode 100644 lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml rename lib/utils/src/utils/{full_binary_tree/raw_full_binary_tree => any_value_type}/any_value_type.cc (93%) create mode 100644 lib/utils/src/utils/fmt/monostate.cc create mode 100644 lib/utils/src/utils/full_binary_tree/fmt.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_label.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_node_type.cc create mode 100644 lib/utils/src/utils/full_binary_tree/make.cc 
delete mode 100644 lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc delete mode 100644 lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc create mode 100644 lib/utils/src/utils/full_binary_tree/require.cc create mode 100644 lib/utils/src/utils/full_binary_tree/visit.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc create mode 100644 
lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 6b75d3943b..6d14fbe3cf 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -46,18 +46,18 @@ SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { return MMProblemTreeSeriesSplit{ - require_series(t.raw_tree), + require_generic_binary_series_split(t.raw_tree), }; } MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { return MMProblemTreeParallelSplit{ - require_parallel(t.raw_tree), + require_generic_binary_parallel_split(t.raw_tree), }; } UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) { - return require_leaf(t.raw_tree); + return require_generic_binary_leaf(t.raw_tree); } MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &series) { diff --git 
a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc index 00d0d74959..e1c118f891 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -17,7 +17,7 @@ SPDecompositionTreeNodeType } layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { - return require_leaf(d.raw_tree); + return require_leaf_only_binary_leaf(d.raw_tree); } std::optional diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index df0245a4d2..f15bf0fe53 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -54,18 +54,18 @@ PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &p) { PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &d) { return PCGBinarySeriesSplit{ - require_series(d.raw_tree), + require_leaf_only_binary_series_split(d.raw_tree), }; } PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &d) { return PCGBinaryParallelSplit{ - require_parallel(d.raw_tree), + require_leaf_only_binary_parallel_split(d.raw_tree), }; } parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &d) { - return require_leaf(d.raw_tree); + return require_leaf_only_binary_leaf(d.raw_tree); } std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &spd, diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h b/lib/utils/include/utils/any_value_type/any_value_type.h similarity index 89% rename from lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h rename to 
lib/utils/include/utils/any_value_type/any_value_type.h index 8cd7d62101..eb211b1a1b 100644 --- a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h +++ b/lib/utils/include/utils/any_value_type/any_value_type.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H #include #include @@ -64,6 +64,6 @@ struct hash<::FlexFlow::any_value_type> { size_t operator()(::FlexFlow::any_value_type const &) const; }; -} +} // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/fmt/monostate.h b/lib/utils/include/utils/fmt/monostate.h new file mode 100644 index 0000000000..b03609171f --- /dev/null +++ b/lib/utils/include/utils/fmt/monostate.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H + +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::monostate, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::std::monostate const &, FormatContext &ctx) + -> decltype(ctx.out()) { + std::string result = ""; + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &, std::monostate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 833013d6f6..4410f06e67 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -2,15 +2,37 @@ #define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" #include +#include "utils/overload.h" +#include "utils/containers/transform.h" +#include "utils/containers/set_union.h" namespace FlexFlow { template std::unordered_set find_paths_to_leaf(FullBinaryTree const &tree, LeafLabel const &leaf) { - return find_paths_to_leaf(tree.raw_tree, make_any_value_type(leaf)); + return visit>( + tree, + overload { + [&](LeafLabel const &l) -> std::unordered_set { + if (l == leaf) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + [&](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(find_paths_to_leaf(get_left_child(parent), leaf), + nest_inside_left_child), + transform(find_paths_to_leaf(get_right_child(parent), leaf), + nest_inside_right_child)); + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h new file mode 100644 index 0000000000..96d384c3ae --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/fmt.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::string format_as(FullBinaryTreeParentNode const &t) { + return fmt::format("<{} ({} {})>", + t.label, + get_left_child(t), + get_right_child(t)); +} + +template +std::string format_as(FullBinaryTree const &t) { + auto visitor = 
FullBinaryTreeVisitor{ + [](FullBinaryTreeParentNode const &parent) { + return fmt::to_string(parent); + }, + [](LeafLabel const &leaf) { + return fmt::format("{}", leaf); + }, + }; + + return visit(t, visitor); +} + +template +std::ostream &operator<<(std::ostream &s, FullBinaryTreeParentNode const &t) { + return (s << fmt::to_string(t)); +} + +template +std::ostream &operator<<(std::ostream &s, FullBinaryTree const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h new file mode 100644 index 0000000000..45d0c5f151 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h @@ -0,0 +1,102 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct FullBinaryTree; + +template +struct FullBinaryTreeParentNode { + explicit FullBinaryTreeParentNode( + ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) + : label(label), + left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) + { } + + FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; + + bool operator==(FullBinaryTreeParentNode const &other) const { + if (this->tie_ptr() == other.tie_ptr()) { + return true; + } + + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTreeParentNode const &other) const { + if (this->tie_ptr() == other.tie_ptr()) { + return false; + } + + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTreeParentNode const &other) const { + return this->tie() < other.tie(); + } +public: + ParentLabel label; + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; +private: + std::tuple> const &, + std::shared_ptr> 
const &> + tie_ptr() const { + return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); + } + + std::tuple const &, + FullBinaryTree const &> + tie() const { + return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct FullBinaryTree { +public: + FullBinaryTree() = delete; + explicit FullBinaryTree(FullBinaryTreeParentNode const &t) + : root{t} {} + + explicit FullBinaryTree(LeafLabel const &t) + : root{t} {} + + bool operator==(FullBinaryTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTree const &other) const { + return this->tie() < other.tie(); + } +public: + std::variant, LeafLabel> root; +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml deleted file mode 100644 index aa9a1d8574..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTree" -features = [ - "eq", - "hash", - "fmt", -] - -template_params = [ - "ParentLabel", - "LeafLabel", -] - -includes = [ - "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::RawBinaryTree" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml deleted file mode 100644 index 277405a23c..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = 
"FullBinaryTreeParentNode" -features = [ - "eq", - "hash", - "fmt", -] - -template_params = [ - "ParentLabel", - "LeafLabel", -] - -includes = [ - "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::RawBinaryTree" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml index 0849ba2683..cb637057db 100644 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -10,7 +10,7 @@ template_params = [ includes = [ "", - "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h", + "utils/full_binary_tree/full_binary_tree.h", ] [[fields]] diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 4076447f57..926cc0ea9c 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -2,15 +2,36 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" #include -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/overload.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" namespace FlexFlow { template std::unordered_set get_all_leaf_paths(FullBinaryTree const &tree) { - return get_all_leaf_paths(tree.raw_tree); + return visit> + (tree, + overload { + [](LeafLabel const &) { + return std::unordered_set{binary_tree_root_path()}; + }, + [](FullBinaryTreeParentNode 
const &parent) { + return set_union( + transform(get_all_leaf_paths(get_left_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(get_right_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index 675e385ca3..e9ceddff6d 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -1,18 +1,26 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/exception.h" +#include namespace FlexFlow { template FullBinaryTree get_child(FullBinaryTreeParentNode const &t, BinaryTreePathEntry const &e) { - return FullBinaryTreeParentNode{ - get_child(t.raw_tree, e), - }; + switch (e) { + case BinaryTreePathEntry::LEFT_CHILD: + return get_left_child(t); + case BinaryTreePathEntry::RIGHT_CHILD: + return get_right_child(t); + default: + throw mk_runtime_error(fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + } } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_label.h b/lib/utils/include/utils/full_binary_tree/get_label.h index 9f0099e609..1b48965b01 100644 --- a/lib/utils/include/utils/full_binary_tree/get_label.h +++ b/lib/utils/include/utils/full_binary_tree/get_label.h @@ -1,13 
+1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template ParentLabel get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &p) { - return p.raw_tree.label.template get(); + return p.label; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index 41fea3c5c2..c58a850a6d 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -1,15 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include +#include "utils/containers/multiset_union.h" namespace FlexFlow { -template -std::unordered_multiset - get_leaves(FullBinaryTree const &t) { - return transform(get_leaves(t.raw_tree), [](any_value_type const &v) { return v.get(); }); +template +std::unordered_multiset + get_leaves(FullBinaryTree const &t) { + return visit>( + t, + overload { + [](FullBinaryTreeParentNode const &parent) { + return multiset_union(get_leaves(get_left_child(parent)), + get_leaves(get_right_child(parent))); + }, + [](ChildLabel const &leaf) { + return std::unordered_multiset{leaf}; + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h index 394b9042fe..163503abfd 100644 --- 
a/lib/utils/include/utils/full_binary_tree/get_left_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_left_child.h @@ -1,16 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template -FullBinaryTree get_left_child(FullBinaryTreeParentNode const &t) { - return FullBinaryTree{ - t.raw_tree.left_child(), - }; +FullBinaryTree const &get_left_child(FullBinaryTreeParentNode const &t) { + return *t.left_child_ptr; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h index 5d2c613101..0ee8eea6d8 100644 --- a/lib/utils/include/utils/full_binary_tree/get_node_type.h +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -1,14 +1,21 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" namespace FlexFlow { template FullBinaryTreeNodeType get_node_type(FullBinaryTree const &t) { - return get_node_type(t.raw_tree); + if (std::holds_alternative(t.root)) { + return FullBinaryTreeNodeType::LEAF; + } else { + bool is_parent = std::holds_alternative>(t.root); + assert (is_parent); + + return FullBinaryTreeNodeType::PARENT; + } } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h index 957ddbede8..e40f2024a1 
100644 --- a/lib/utils/include/utils/full_binary_tree/get_right_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_right_child.h @@ -1,16 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template -FullBinaryTree get_right_child(FullBinaryTreeParentNode const &t) { - return FullBinaryTree{ - t.raw_tree.right_child(), - }; +FullBinaryTree const &get_right_child(FullBinaryTreeParentNode const &t) { + return *t.right_child_ptr; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index 59d24b6aad..6909d9e1ef 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -2,19 +2,35 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/containers/transform.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include namespace FlexFlow { template std::optional> get_subtree_at_path(FullBinaryTree const &t, BinaryTreePath const &p) { - return transform(get_subtree_at_path(t.raw_tree, p), - [](RawBinaryTree const &raw) { - return FullBinaryTree{raw}; - }); + if (p == binary_tree_root_path()) { + return t; + } + + return visit>>( + t, + overload { + 
[&](FullBinaryTreeParentNode const &parent) { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, curr), rest); + }, + [&](LeafLabel const &leaf) { + return std::nullopt; + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h new file mode 100644 index 0000000000..a29836f972 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/hash.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::FullBinaryTreeParentNode> { + size_t operator()(::FlexFlow::FullBinaryTreeParentNode const &t) const { + return get_std_hash(t.tie()); + } +}; + +template +struct hash<::FlexFlow::FullBinaryTree> { + size_t operator()(::FlexFlow::FullBinaryTree const &t) const { + return get_std_hash(t.tie()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/json.h b/lib/utils/include/utils/full_binary_tree/json.h index 585c05813e..0d830890dc 100644 --- a/lib/utils/include/utils/full_binary_tree/json.h +++ b/lib/utils/include/utils/full_binary_tree/json.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H #include "utils/exception.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/get_left_child.h" #include "utils/full_binary_tree/get_right_child.h" #include "utils/full_binary_tree/visit.h" diff --git a/lib/utils/include/utils/full_binary_tree/make.h b/lib/utils/include/utils/full_binary_tree/make.h index ac458f0f4d..a4ef47c7df 100644 --- 
a/lib/utils/include/utils/full_binary_tree/make.h +++ b/lib/utils/include/utils/full_binary_tree/make.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { @@ -10,14 +10,18 @@ FullBinaryTree make_full_binary_tree_parent(ParentLabel FullBinaryTree const &lhs, FullBinaryTree const &rhs) { return FullBinaryTree{ - raw_binary_tree_make_parent(make_any_value_type(label), lhs.raw_tree, rhs.raw_tree), + FullBinaryTreeParentNode{ + label, + lhs, + rhs, + }, }; } template FullBinaryTree make_full_binary_tree_leaf(LeafLabel const &label) { return FullBinaryTree{ - raw_binary_tree_make_leaf(make_any_value_type(label)), + label, }; } diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h deleted file mode 100644 index 6d0d77caa9..0000000000 --- a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ALGORITHMS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ALGORITHMS_H - -#include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h" -#include - -namespace FlexFlow { - -RawBinaryTree get_child(RawBinaryTree const &, BinaryTreePathEntry const &); -std::unordered_set get_all_leaf_paths(RawBinaryTree const &); -std::unordered_set find_paths_to_leaf(RawBinaryTree const &, any_value_type const &leaf); -std::unordered_multiset get_leaves(RawBinaryTree const &); -FullBinaryTreeNodeType 
get_node_type(RawBinaryTree const &); -std::optional get_subtree_at_path(RawBinaryTree const &, BinaryTreePath const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h deleted file mode 100644 index 0bebe12109..0000000000 --- a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_RAW_BINARY_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_RAW_BINARY_TREE_H - -#include -#include -#include "utils/full_binary_tree/raw_full_binary_tree/any_value_type.h" -#include - -namespace FlexFlow { - -struct RawBinaryTree { - explicit RawBinaryTree( - any_value_type const &label, - RawBinaryTree const &lhs, - RawBinaryTree const &rhs); - explicit RawBinaryTree( - any_value_type const &label); - - RawBinaryTree(RawBinaryTree const &) = default; - - bool operator==(RawBinaryTree const &) const; - bool operator!=(RawBinaryTree const &) const; - - RawBinaryTree const &left_child() const; - RawBinaryTree const &right_child() const; - - bool is_leaf() const; -public: - any_value_type label; - std::shared_ptr left_child_ptr; - std::shared_ptr right_child_ptr; -private: - std::tuple, - std::optional> - value_tie() const; - std::tuple const &, - std::shared_ptr const &> - ptr_tie() const; - - friend std::hash; -}; - -std::string format_as(RawBinaryTree const &); -std::ostream &operator<<(std::ostream &, RawBinaryTree const &); - -RawBinaryTree raw_binary_tree_make_leaf(any_value_type const &label); -RawBinaryTree raw_binary_tree_make_parent(any_value_type const &label, RawBinaryTree const &lhs, RawBinaryTree const &rhs); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::RawBinaryTree> { - size_t 
operator()(::FlexFlow::RawBinaryTree const &) const; -}; - -} - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h index f897908c86..f7be417945 100644 --- a/lib/utils/include/utils/full_binary_tree/require.h +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -1,29 +1,18 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template -FullBinaryTreeParentNode require_parent_node(FullBinaryTree const &t) { - if (t.raw_tree.is_leaf()) { - throw mk_runtime_error(fmt::format("require_parent_node called on leaf node {}", t)); - } - - return FullBinaryTreeParentNode{ - t.raw_tree, - }; +FullBinaryTreeParentNode const &require_full_binary_tree_parent_node(FullBinaryTree const &t) { + return std::get>(t.root); } template -LeafLabel require_leaf(FullBinaryTree const &t) { - if (!t.raw_tree.is_leaf()) { - throw mk_runtime_error(fmt::format("require_leaf called on non-leaf node {}", t)); - } - - return t.raw_tree.label.template get(); +LeafLabel const &require_full_binary_tree_leaf(FullBinaryTree const &t) { + return std::get(t.root); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index ea5729bd6c..860e60fcca 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -4,6 +4,7 @@ #include "utils/full_binary_tree/get_node_type.h" #include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/exception.h" +#include "utils/full_binary_tree/require.h" namespace FlexFlow { @@ -21,9 +22,9 @@ Result visit(FullBinaryTree const &t, FullBinaryTreeVisi 
FullBinaryTreeNodeType node_type = get_node_type(t); switch (node_type) { case FullBinaryTreeNodeType::PARENT: - return v.parent_func(require_parent_node(t)); + return v.parent_func(require_full_binary_tree_parent_node(t)); case FullBinaryTreeNodeType::LEAF: - return v.leaf_func(require_leaf(t)); + return v.leaf_func(require_full_binary_tree_leaf(t)); default: throw mk_runtime_error(fmt::format("Unhandled FullBinaryTreeNodeType value: {}", node_type)); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml index 9734912f35..00c49992ef 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml @@ -14,12 +14,14 @@ template_params = [ ] includes = [ - "utils/full_binary_tree/full_binary_tree.dtg.h", + "utils/full_binary_tree/full_binary_tree.h", "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h", ] src_includes = [ "utils/full_binary_tree/json.h", + "utils/full_binary_tree/hash.h", + "utils/full_binary_tree/fmt.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h index c856f35d68..0c08a0462b 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h @@ -10,11 +10,39 @@ namespace FlexFlow { template SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &label) { return label.template visit(overload { - [](SeriesLabel const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](ParallelLabel const &) { return SPDecompositionTreeNodeType::PARALLEL; }, + [](GenericBinarySeriesSplitLabel const &) { return SPDecompositionTreeNodeType::SERIES; }, + [](GenericBinaryParallelSplitLabel const &) { return SPDecompositionTreeNodeType::PARALLEL; }, }); } +template +GenericBinarySPSplitLabel make_generic_binary_series_split_label(SeriesLabel const &label) { + return GenericBinarySPSplitLabel{ + GenericBinarySeriesSplitLabel{ + label, + }, + }; +} + +template +GenericBinarySPSplitLabel make_generic_binary_parallel_split_label(ParallelLabel const &label) { + return GenericBinarySPSplitLabel{ + GenericBinaryParallelSplitLabel{ + label, + }, + }; +} + +template +SeriesLabel require_generic_binary_series_split_label(GenericBinarySPSplitLabel const &label) { + return label.template get>().raw_label; +} + +template +ParallelLabel require_generic_binary_parallel_split_label(GenericBinarySPSplitLabel const &label) { + return label.template get>().raw_label; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml index c50a7b878b..c528c61f37 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml @@ -12,15 +12,15 @@ template_params = [ "ParallelSplitLabel", ] -# includes = [ -# "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.dtg.h", -# "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.dtg.h", -# ] +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.dtg.h", +] [[values]] -type = "SeriesSplitLabel" +type = "::FlexFlow::GenericBinarySeriesSplitLabel" key = "series" [[values]] -type = "ParallelSplitLabel" +type = "::FlexFlow::GenericBinaryParallelSplitLabel" key = "parallel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index c46be1c651..1dedf581fe 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -4,6 +4,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/full_binary_tree/get_label.h" #include "utils/full_binary_tree/visit.h" #include "utils/overload.h" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index 20ea7e744e..98382c78c8 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -3,6 +3,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/full_binary_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" namespace FlexFlow { @@ -13,7 +14,7 @@ GenericBinarySPDecompositionTree make_gen GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - GenericBinarySPSplitLabel{label}, + make_generic_binary_series_split_label(label), lhs.raw_tree, rhs.raw_tree), }; @@ -26,7 +27,7 @@ GenericBinarySPDecompositionTree make_gen GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - GenericBinarySPSplitLabel{label}, + make_generic_binary_parallel_split_label(label), lhs.raw_tree, rhs.raw_tree), }; diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index b8b18c4125..4961dc7b61 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -3,6 +3,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/full_binary_tree/require.h" #include "utils/full_binary_tree/get_label.h" @@ -10,11 +11,11 @@ namespace FlexFlow { template GenericBinarySeriesSplit - require_series(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + require_generic_binary_series_split(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_full_binary_tree_parent_node(t.raw_tree); return GenericBinarySeriesSplit{ - /*label=*/get_full_binary_tree_parent_label(parent).template get(), + /*label=*/require_generic_binary_series_split_label(get_full_binary_tree_parent_label(parent)), /*pre=*/GenericBinarySPDecompositionTree{ get_left_child(parent), }, @@ -26,11 +27,11 @@ GenericBinarySeriesSplit template GenericBinaryParallelSplit - require_parallel(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + require_generic_binary_parallel_split(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = 
require_full_binary_tree_parent_node(t.raw_tree); return GenericBinaryParallelSplit{ - /*label=*/get_full_binary_tree_parent_label(parent).template get(), + /*label=*/require_generic_binary_parallel_split_label(get_full_binary_tree_parent_label(parent)), /*lhs=*/GenericBinarySPDecompositionTree{ get_left_child(parent), }, @@ -41,8 +42,8 @@ GenericBinaryParallelSplit } template -LeafLabel require_leaf(GenericBinarySPDecompositionTree const &t) { - return require_leaf(t.raw_tree); +LeafLabel require_generic_binary_leaf(GenericBinarySPDecompositionTree const &t) { + return require_full_binary_tree_leaf(t.raw_tree); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index a56ed952e9..a1ac10a6a0 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -13,15 +13,15 @@ Result visit(GenericBinarySPDecompositionTree wrap_series_split(GenericBinarySeriesSplit const &series_split) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - /*label=*/GenericBinarySPSplitLabel{series_split.label}, + /*label=*/make_generic_binary_series_split_label(series_split.label), /*lhs=*/series_split.pre.raw_tree, /*rhs=*/series_split.post.raw_tree), }; @@ -23,7 +25,7 @@ GenericBinarySPDecompositionTree wrap_parallel_split(GenericBinaryParallelSplit const ¶llel_split) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - /*label=*/GenericBinarySPSplitLabel{parallel_split.label}, + /*label=*/make_generic_binary_parallel_split_label(parallel_split.label), /*lhs=*/parallel_split.lhs.raw_tree, /*rhs=*/parallel_split.rhs.raw_tree), 
}; diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml deleted file mode 100644 index 0506d36227..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinaryParallelSplitLabel" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml deleted file mode 100644 index b780bfeea6..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySeriesSplitLabel" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml index dacab0244a..bf52ecc6df 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml @@ -12,10 +12,13 @@ template_params = [ includes = [ "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/monostate.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::LeafOnlyBinarySeriesSplitLabel, ::FlexFlow::LeafOnlyBinaryParallelSplitLabel, LeafLabel>" +type = "::FlexFlow::GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h index a9dcb17f0d..3297a30ec7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -11,7 +11,7 @@ LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnl LeafOnlyBinarySPDecompositionTree const &post) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_series_split( - LeafOnlyBinarySeriesSplitLabel{}, + std::monostate{}, pre.raw_tree, post.raw_tree), }; @@ -22,7 +22,7 @@ 
LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafO LeafOnlyBinarySPDecompositionTree const &rhs) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_parallel_split( - LeafOnlyBinaryParallelSplitLabel{}, + std::monostate{}, lhs.raw_tree, rhs.raw_tree), }; @@ -32,8 +32,8 @@ template LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(LeafLabel const &label) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_sp_leaf< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel>(label), }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h index 65d42eee7c..400b6be1de 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -12,12 +12,12 @@ namespace FlexFlow { template LeafOnlyBinarySeriesSplit - require_series(LeafOnlyBinarySPDecompositionTree const &t) { + require_leaf_only_binary_series_split(LeafOnlyBinarySPDecompositionTree const &t) { GenericBinarySeriesSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel> raw = - require_series(t.raw_tree); + require_generic_binary_series_split(t.raw_tree); return LeafOnlyBinarySeriesSplit{ LeafOnlyBinarySPDecompositionTree{raw.pre}, @@ -27,12 +27,12 @@ LeafOnlyBinarySeriesSplit template LeafOnlyBinaryParallelSplit - require_parallel(LeafOnlyBinarySPDecompositionTree const &t) { + require_leaf_only_binary_parallel_split(LeafOnlyBinarySPDecompositionTree const &t) { GenericBinaryParallelSplit< - LeafOnlyBinarySeriesSplitLabel, - 
LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel> raw = - require_parallel(t.raw_tree); + require_generic_binary_parallel_split(t.raw_tree); return LeafOnlyBinaryParallelSplit{ LeafOnlyBinarySPDecompositionTree{raw.lhs}, @@ -41,8 +41,8 @@ LeafOnlyBinaryParallelSplit } template -LeafLabel require_leaf(LeafOnlyBinarySPDecompositionTree const &t) { - return require_leaf(t.raw_tree); +LeafLabel require_leaf_only_binary_leaf(LeafOnlyBinarySPDecompositionTree const &t) { + return require_generic_binary_leaf(t.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h index b4f4239d39..4cbd2b26bd 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -30,18 +30,18 @@ template LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { using GenericVisitor = GenericBinarySPDecompositionTreeVisitor - ; GenericVisitor generic_visitor = GenericVisitor{ - [&](LeafOnlyBinarySeriesSplitLabel const &x) { + [&](std::monostate const &x) { return x; }, - [&](LeafOnlyBinaryParallelSplitLabel const &x) { + [&](std::monostate const &x) { return x; }, [&](LeafLabel const &t) { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h index 0284f6ba41..21fae97633 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h @@ -14,10 +14,10 @@ LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySer return LeafOnlyBinarySPDecompositionTree{ wrap_series_split( GenericBinarySeriesSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel>{ - LeafOnlyBinarySeriesSplitLabel{}, + std::monostate{}, split.pre.raw_tree, split.post.raw_tree, } @@ -30,10 +30,10 @@ LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryP return LeafOnlyBinarySPDecompositionTree{ wrap_parallel_split( GenericBinaryParallelSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel>{ - LeafOnlyBinaryParallelSplitLabel{}, + std::monostate{}, split.lhs.raw_tree, split.rhs.raw_tree, } diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc b/lib/utils/src/utils/any_value_type/any_value_type.cc similarity index 93% rename from lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc rename to lib/utils/src/utils/any_value_type/any_value_type.cc index d54796ae49..b3a72dafa9 100644 --- a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc +++ b/lib/utils/src/utils/any_value_type/any_value_type.cc @@ -1,4 +1,4 @@ -#include "utils/full_binary_tree/raw_full_binary_tree/any_value_type.h" +#include "utils/any_value_type/any_value_type.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/fmt/monostate.cc b/lib/utils/src/utils/fmt/monostate.cc new file mode 100644 index 0000000000..55988cdce0 --- /dev/null +++ b/lib/utils/src/utils/fmt/monostate.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/monostate.h" + +namespace FlexFlow { + +std::ostream 
&operator<<(std::ostream &s, std::monostate const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/fmt.cc b/lib/utils/src/utils/full_binary_tree/fmt.cc new file mode 100644 index 0000000000..82bf382821 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/fmt.cc @@ -0,0 +1,10 @@ +#include "utils/full_binary_tree/fmt.h" + +namespace FlexFlow { + +template std::string format_as(FullBinaryTreeParentNode const &); +template std::string format_as(FullBinaryTree const &); +template std::ostream &operator<<(std::ostream &, FullBinaryTreeParentNode const &); +template std::ostream &operator<<(std::ostream &, FullBinaryTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_label.cc b/lib/utils/src/utils/full_binary_tree/get_label.cc new file mode 100644 index 0000000000..25ed6cf3f6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_label.cc @@ -0,0 +1,8 @@ +#include "utils/full_binary_tree/get_label.h" + +namespace FlexFlow { + +template + int get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_node_type.cc b/lib/utils/src/utils/full_binary_tree/get_node_type.cc new file mode 100644 index 0000000000..a4c88a03f3 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_node_type.cc @@ -0,0 +1,7 @@ +#include "utils/full_binary_tree/get_node_type.h" + +namespace FlexFlow { + +template FullBinaryTreeNodeType get_node_type(FullBinaryTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/make.cc b/lib/utils/src/utils/full_binary_tree/make.cc new file mode 100644 index 0000000000..da48d2a2c4 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/make.cc @@ -0,0 +1,12 @@ +#include "utils/full_binary_tree/make.h" + +namespace FlexFlow { + +template + FullBinaryTree make_full_binary_tree_parent(int const &, + 
FullBinaryTree const &, + FullBinaryTree const &); +template + FullBinaryTree make_full_binary_tree_leaf(int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc deleted file mode 100644 index bc833f95d4..0000000000 --- a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" -#include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/containers/transform.h" -#include "utils/containers/set_union.h" -#include "utils/containers/multiset_union.h" - -namespace FlexFlow { - -RawBinaryTree get_child(RawBinaryTree const &t, BinaryTreePathEntry const &e) { - if (e == BinaryTreePathEntry::LEFT_CHILD) { - return t.left_child(); - } else { - assert (e == BinaryTreePathEntry::RIGHT_CHILD); - return t.right_child(); - } -} - -std::unordered_set get_all_leaf_paths(RawBinaryTree const &t) { - if (t.is_leaf()) { - return {binary_tree_root_path()}; - } else { - return set_union( - transform(get_all_leaf_paths(t.left_child()), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(get_all_leaf_paths(t.right_child()), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - } -} - -std::unordered_set find_paths_to_leaf(RawBinaryTree const &t, any_value_type const &leaf) { - if (t.is_leaf()) { - if (t.label == leaf) { - return {binary_tree_root_path()}; - } else { - return {}; - } - } else { - return set_union( - transform(find_paths_to_leaf(t.left_child(), leaf), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(find_paths_to_leaf(t.right_child(), leaf), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - } -} - -std::unordered_multiset get_leaves(RawBinaryTree const &t) { - if 
(t.is_leaf()) { - return {t.label}; - } else { - return multiset_union(get_leaves(t.left_child()), get_leaves(t.right_child())); - } -} - -FullBinaryTreeNodeType get_node_type(RawBinaryTree const &t) { - if (t.is_leaf()) { - return FullBinaryTreeNodeType::LEAF; - } else { - return FullBinaryTreeNodeType::PARENT; - } -} - -std::optional get_subtree_at_path(RawBinaryTree const &t, BinaryTreePath const &p) { - if (p == binary_tree_root_path()) { - return t; - } else if (t.is_leaf()) { - return std::nullopt; - } else { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); - - return get_subtree_at_path(get_child(t, curr), rest); - } -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc deleted file mode 100644 index d432d32eb9..0000000000 --- a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc +++ /dev/null @@ -1,101 +0,0 @@ -#include "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace FlexFlow { - -RawBinaryTree::RawBinaryTree( - any_value_type const &label, - RawBinaryTree const &lhs, - RawBinaryTree const &rhs) - : label(label), - left_child_ptr(std::make_shared(lhs)), - right_child_ptr(std::make_shared(rhs)) -{ } - -RawBinaryTree::RawBinaryTree( - any_value_type const &label) - : label(label), left_child_ptr(nullptr), right_child_ptr(nullptr) -{ } - -bool RawBinaryTree::operator==(RawBinaryTree const &other) const { - if (this->ptr_tie() == other.ptr_tie()) { - return true; - } - - return (this->value_tie() == other.value_tie()); -} - -bool RawBinaryTree::operator!=(RawBinaryTree const &other) const { - if (this->ptr_tie() == other.ptr_tie()) { - return false; - } - - return (this->value_tie() != other.value_tie()); -} - -RawBinaryTree 
const &RawBinaryTree::left_child() const { - return *this->left_child_ptr; -} - -RawBinaryTree const &RawBinaryTree::right_child() const { - return *this->right_child_ptr; -} - -bool RawBinaryTree::is_leaf() const { - return this->left_child_ptr == nullptr && this->right_child_ptr == nullptr; -} - -std::tuple, - std::optional> - RawBinaryTree::value_tie() const { - - auto ptr_to_optional = [](std::shared_ptr const &ptr) - -> std::optional { - if (ptr == nullptr) { - return std::nullopt; - } else { - return *ptr; - } - }; - - return {this->label, ptr_to_optional(this->left_child_ptr), ptr_to_optional(this->right_child_ptr)}; -} - -std::tuple const &, - std::shared_ptr const &> - RawBinaryTree::ptr_tie() const { - return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); -} - -std::string format_as(RawBinaryTree const &t) { - if (t.is_leaf()) { - return fmt::to_string(t.label); - } else { - return fmt::format("({} {} {})", t.label, t.left_child(), t.right_child()); - } -} - -std::ostream &operator<<(std::ostream &s, RawBinaryTree const &t) { - return (s << fmt::to_string(t)); -} - -RawBinaryTree raw_binary_tree_make_leaf(any_value_type const &label) { - return RawBinaryTree{label}; -} - -RawBinaryTree raw_binary_tree_make_parent(any_value_type const &label, RawBinaryTree const &lhs, RawBinaryTree const &rhs) { - return RawBinaryTree{label, lhs, rhs}; -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::RawBinaryTree>::operator()(::FlexFlow::RawBinaryTree const &t) const { - return ::FlexFlow::get_std_hash(t.value_tie()); -} - -} // namespace std diff --git a/lib/utils/src/utils/full_binary_tree/require.cc b/lib/utils/src/utils/full_binary_tree/require.cc new file mode 100644 index 0000000000..d6b0bbeb68 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/require.cc @@ -0,0 +1,11 @@ +#include "utils/full_binary_tree/require.h" + +namespace FlexFlow { + +template + FullBinaryTreeParentNode const & + 
require_full_binary_tree_parent_node(FullBinaryTree const &); +template + int const &require_full_binary_tree_leaf(FullBinaryTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc new file mode 100644 index 0000000000..b43eb5bce6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -0,0 +1,8 @@ +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template + int visit(FullBinaryTree const &, FullBinaryTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 2f51762db2..92a46c030d 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -44,18 +44,18 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { BinarySeriesSplit require_series(BinarySPDecompositionTree const &tt) { return BinarySeriesSplit{ - require_series(tt.raw_tree), + require_leaf_only_binary_series_split(tt.raw_tree), }; } BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &tt) { return BinaryParallelSplit{ - require_parallel(tt.raw_tree), + require_leaf_only_binary_parallel_split(tt.raw_tree), }; } Node require_leaf(BinarySPDecompositionTree const &tt) { - return require_leaf(tt.raw_tree); + return require_leaf_only_binary_leaf(tt.raw_tree); } SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tt) { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..2ecd4c94d2 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" + +namespace FlexFlow { + +template + std::unordered_set find_paths_to_leaf(GenericBinarySPDecompositionTree const &, + int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc new file mode 100644 index 0000000000..10bbc60c6d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc @@ -0,0 +1,16 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" + +namespace FlexFlow { + +template + SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &); +template + GenericBinarySPSplitLabel make_generic_binary_series_split_label(int const &); +template + GenericBinarySPSplitLabel make_generic_binary_parallel_split_label(int const &); +template + int require_generic_binary_series_split_label(GenericBinarySPSplitLabel const &); +template + int require_generic_binary_parallel_split_label(GenericBinarySPSplitLabel const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..31e664b726 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" + +namespace FlexFlow { + +template + std::unordered_set get_all_leaf_paths(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 71b67acc54..20ba3fa5d7 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1 +1,13 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +template + std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &); +template + std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &); +template + std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc index 227e5bd79c..783a7a974b 100644 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &); +template + GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc index 1618128226..9d652d44da 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc @@ -1 +1,9 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" + +namespace FlexFlow { + +template + SPDecompositionTreeNodeType + get_node_type(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 05ec6b5925..6c67fdc244 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" + +namespace FlexFlow { + +template + int get_num_tree_nodes(GenericBinarySPDecompositionTree const &); +template + int get_num_tree_nodes(GenericBinarySeriesSplit const &); +template + int get_num_tree_nodes(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc index f168ba1e2f..03c154fb67 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &); +template + GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..6bfb573359 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -0,0 +1,10 
@@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" + +namespace FlexFlow { + +template + std::optional> + get_subtree_at_path(GenericBinarySPDecompositionTree const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc index 3da024743c..5e5b768ed7 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" + +namespace FlexFlow { + +template + bool is_series_split(GenericBinarySPDecompositionTree const &); +template + bool is_parallel_split(GenericBinarySPDecompositionTree const &); +template + bool is_leaf(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 8fe9397003..87ae55b900 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1 +1,9 @@ #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_left_associative( + GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index d202f55964..5a40a3b6bf 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1 +1,9 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_right_associative( + GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc index fb1532b3ef..a36ccce359 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree 
make_generic_binary_series_split( + int const &, + GenericBinarySPDecompositionTree const &, + GenericBinarySPDecompositionTree const &); +template + GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + int const &label, + GenericBinarySPDecompositionTree const &, + GenericBinarySPDecompositionTree const &); +template + GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc index 3fee45fcf5..8305a1243e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -1 +1,14 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" + +namespace FlexFlow { + +template + GenericBinarySeriesSplit + require_generic_binary_series_split(GenericBinarySPDecompositionTree const &); +template + GenericBinaryParallelSplit + require_generic_binary_parallel_split(GenericBinarySPDecompositionTree const &); +template + int require_generic_binary_leaf(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc index cabd66cff7..4495a60f92 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -1 +1,10 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" + +namespace FlexFlow { + +template + GenericBinarySeriesSplit + transform(GenericBinarySeriesSplit const &, + GenericBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc new file mode 100644 index 0000000000..0b3189b47b --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc @@ -0,0 +1,12 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree + wrap_series_split(GenericBinarySeriesSplit const &); +template + GenericBinarySPDecompositionTree + wrap_parallel_split(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..41accc79d0 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +template + std::unordered_multiset get_leaves(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace 
FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc new file mode 100644 index 0000000000..0959a42f01 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" + +namespace FlexFlow { + +template + SPDecompositionTreeNodeType get_node_type(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..dd94936997 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_left_associative(LeafOnlyBinarySPDecompositionTree const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..46b89aa98f --- /dev/null +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_right_associative(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc new file mode 100644 index 0000000000..5690ebe8a8 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc @@ -0,0 +1,10 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinaryParallelSplit const &); +template + LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc new file mode 100644 index 0000000000..ed0e5892da --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc @@ -0,0 +1,10 @@ +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinarySeriesSplit const &); +template + LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinarySeriesSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc new file mode 100644 index 0000000000..602aebc7e8 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc @@ -0,0 +1,14 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree const &); +template + LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree const &); +template + LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc new file mode 100644 index 0000000000..1a1cd9909d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" + 
+namespace FlexFlow { + +template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split(LeafOnlyBinarySPDecompositionTree const &); +template LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split(LeafOnlyBinarySPDecompositionTree const &); +template int require_leaf_only_binary_leaf(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..22dd5e0db5 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1,16 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); +template + LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); + +template + LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc new file mode 100644 index 0000000000..3836124eb6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc @@ -0,0 +1,10 @@ +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySeriesSplit const &); +template + LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryParallelSplit const &); + +} // namespace FlexFlow From 5d22c6dd3267613ad0f6d40b68652e04a1d64fc5 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 3 Oct 2024 17:05:27 -0700 Subject: [PATCH 23/29] Get all existing tests passing again --- .../pcg_binary_sp_decomposition.cc | 6 +- .../make.h | 9 +- .../binary_sp_decomposition_tree.cc | 6 +- .../make.cc | 6 +- .../hash.cc | 117 ------------------ .../is_binary_sp_tree_left_associative.cc | 102 --------------- .../is_binary_sp_tree_right_associative.cc | 102 --------------- .../transform.cc | 28 ----- .../is_binary_sp_tree_left_associative.cc | 114 +++++++++++++++++ .../is_binary_sp_tree_right_associative.cc | 114 +++++++++++++++++ 10 files changed, 241 insertions(+), 363 deletions(-) delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc create mode 100644 
lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index f15bf0fe53..9a2fc43a37 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -24,19 +24,19 @@ SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &d) { PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - leaf_only_make_series_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_binary_sp_tree_make_series_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - leaf_only_make_parallel_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_binary_sp_tree_make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &l) { return PCGBinarySPDecomposition{ - leaf_only_make_leaf_node(l), + leaf_only_binary_sp_tree_make_leaf(l), }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h index 3297a30ec7..0eb05dc867 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -7,7 +7,7 @@ namespace FlexFlow { template 
-LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, +LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, LeafOnlyBinarySPDecompositionTree const &post) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_series_split( @@ -18,7 +18,7 @@ LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnl } template -LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, +LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, LeafOnlyBinarySPDecompositionTree const &rhs) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_parallel_split( @@ -29,12 +29,11 @@ LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafO } template -LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(LeafLabel const &label) { +LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_leaf(LeafLabel const &label) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_sp_leaf< std::monostate, - std::monostate, - LeafLabel>(label), + std::monostate, LeafLabel>(label), }; } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 92a46c030d..a4bd8b1ba7 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -12,7 +12,7 @@ BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{ - leaf_only_make_series_split(lhs.raw_tree, rhs.raw_tree), + 
leaf_only_binary_sp_tree_make_series_split(lhs.raw_tree, rhs.raw_tree), }; } @@ -20,13 +20,13 @@ BinarySPDecompositionTree make_parallel_split(BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{ - leaf_only_make_parallel_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_binary_sp_tree_make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } BinarySPDecompositionTree make_leaf_node(Node const &n) { return BinarySPDecompositionTree{ - leaf_only_make_leaf_node(n), + leaf_only_binary_sp_tree_make_leaf(n), }; } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc index 602aebc7e8..112a14c206 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc @@ -3,12 +3,12 @@ namespace FlexFlow { template - LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_series_split(LeafOnlyBinarySPDecompositionTree const &, LeafOnlyBinarySPDecompositionTree const &); template - LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &, LeafOnlyBinarySPDecompositionTree const &); template - LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(int const &); + LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_leaf(int const &); } // namespace FlexFlow diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc deleted file mode 100644 index 87d41a0bb6..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ /dev/null @@ -1,117 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("std::hash>") { - FAIL("TODO, probably move over to FullBinaryTree"); - // SUBCASE("leaf") { - // GenericBinarySPDecompositionTree leaf_5 = - // make_generic_binary_sp_leaf(5); - // size_t leaf_5_hash = get_std_hash(leaf_5); - // - // SUBCASE("leaves with same labels hash to the same value") { - // GenericBinarySPDecompositionTree also_leaf_5 = - // make_generic_binary_sp_leaf(5); - // size_t also_leaf_5_hash = get_std_hash(also_leaf_5); - // - // CHECK(leaf_5_hash == also_leaf_5_hash); - // } - // - // SUBCASE("leaves with different labels hash to different values") { - // GenericBinarySPDecompositionTree leaf_6 = - // make_generic_binary_sp_leaf(6); - // size_t leaf_6_hash = get_std_hash(leaf_6); - // - // CHECK(leaf_5_hash != leaf_6_hash); - // } - // } - // - // SUBCASE("series split") { - // GenericBinarySPDecompositionTree series_5_6 = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // size_t series_5_6_hash = get_std_hash(series_5_6); - // - // SUBCASE("same children lead to the same hash") { - // GenericBinarySPDecompositionTree also_series_5_6 = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // size_t also_series_5_6_hash = get_std_hash(also_series_5_6); - // - // CHECK(series_5_6_hash == 
also_series_5_6_hash); - // } - // - // SUBCASE("hash is order dependent") { - // GenericBinarySPDecompositionTree series_6_5 = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(6), - // make_generic_binary_sp_leaf(5)); - // size_t series_6_5_hash = get_std_hash(series_6_5); - // - // CHECK(series_5_6_hash != series_6_5_hash); - // } - // - // SUBCASE("different left child leads to different hash") { - // GenericBinarySPDecompositionTree series_4_6 = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(4), - // make_generic_binary_sp_leaf(6)); - // size_t series_4_6_hash = get_std_hash(series_4_6); - // - // CHECK(series_5_6_hash != series_4_6_hash); - // } - // - // SUBCASE("different right child leads to different hash") { - // GenericBinarySPDecompositionTree series_5_7 = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(7)); - // size_t series_5_7_hash = get_std_hash(series_5_7); - // - // CHECK(series_5_6_hash != series_5_7_hash); - // } - // } - // - // SUBCASE("parallel split") { - // GenericBinarySPDecompositionTree parallel_5_6 = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // size_t parallel_5_6_hash = get_std_hash(parallel_5_6); - // - // SUBCASE("same children lead to the same hash") { - // GenericBinarySPDecompositionTree also_parallel_5_6 = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); - // - // CHECK(parallel_5_6_hash == also_parallel_5_6_hash); - // } - // - // SUBCASE("hash is order dependent") { - // GenericBinarySPDecompositionTree parallel_6_5 = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), - // make_generic_binary_sp_leaf(5)); - // size_t parallel_6_5_hash = get_std_hash(parallel_6_5); - // - // CHECK(parallel_5_6_hash != parallel_6_5_hash); 
- // } - // - // SUBCASE("different left child leads to different hash") { - // GenericBinarySPDecompositionTree parallel_4_6 = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - // make_generic_binary_sp_leaf(6)); - // size_t parallel_4_6_hash = get_std_hash(parallel_4_6); - // - // CHECK(parallel_5_6_hash != parallel_4_6_hash); - // } - // - // SUBCASE("different right child leads to different hash") { - // GenericBinarySPDecompositionTree parallel_5_7 = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(7)); - // size_t parallel_5_7_hash = get_std_hash(parallel_5_7); - // - // CHECK(parallel_5_6_hash != parallel_5_7_hash); - // } - // } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc deleted file mode 100644 index 481dcd85d3..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_binary_sp_tree_left_associative(" - "GenericBinarySPDecompositionTree)") { - FAIL("TODO"); - // int n1 = 1; - // int n2 = 2; - // int n3 = 3; - // int n4 = 4; - // - // SUBCASE("input is actually left associative") { - // SUBCASE("just node") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_sp_leaf(n1); - // - // bool result = 
is_binary_sp_tree_left_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // - // SUBCASE("just series") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_series_split( - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_sp_leaf(n2)), - // make_generic_binary_sp_leaf(n3)); - // - // bool result = is_binary_sp_tree_left_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // - // SUBCASE("just parallel") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split( - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_sp_leaf(n2)), - // make_generic_binary_sp_leaf(n3)); - // - // bool result = is_binary_sp_tree_left_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // - // SUBCASE("nested") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_series_split( - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_sp_leaf(n2)), - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n3), - // make_generic_binary_sp_leaf(n4))); - // - // bool result = is_binary_sp_tree_left_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // } - // - // SUBCASE("input is not left associative") { - // SUBCASE("just series") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(n2), - // make_generic_binary_sp_leaf(n3))); - // - // bool result = is_binary_sp_tree_left_associative(input); - // bool correct = false; - // - // CHECK(result == correct); - // } - // - // SUBCASE("just parallel") { - // GenericBinarySPDecompositionTree input = - // 
make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n2), - // make_generic_binary_sp_leaf(n3))); - // - // bool result = is_binary_sp_tree_left_associative(input); - // bool correct = false; - // - // CHECK(result == correct); - // } - // } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc deleted file mode 100644 index 3651eca03a..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_binary_sp_tree_right_associative(" - "LeafOnlyBinarySPDecompositionTree)") { - FAIL("TODO"); - // int n1 = 1; - // int n2 = 2; - // int n3 = 3; - // int n4 = 4; - // - // SUBCASE("input is actually right associative") { - // SUBCASE("just node") { - // LeafOnlyBinarySPDecompositionTree input = - // make_generic_binary_sp_leaf(n1); - // - // bool result = is_binary_sp_tree_right_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // - // SUBCASE("just series") { - // LeafOnlyBinarySPDecompositionTree input = - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(n2), - // 
make_generic_binary_sp_leaf(n3))); - // - // bool result = is_binary_sp_tree_right_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // - // SUBCASE("just parallel") { - // LeafOnlyBinarySPDecompositionTree input = - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n2), - // make_generic_binary_sp_leaf(n3))); - // - // bool result = is_binary_sp_tree_right_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // - // SUBCASE("nested") { - // LeafOnlyBinarySPDecompositionTree input = - // make_generic_binary_series_split( - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_sp_leaf(n2)), - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n3), - // make_generic_binary_sp_leaf(n4))); - // - // bool result = is_binary_sp_tree_right_associative(input); - // bool correct = true; - // - // CHECK(result == correct); - // } - // } - // - // SUBCASE("input is not right associative") { - // SUBCASE("just series") { - // LeafOnlyBinarySPDecompositionTree input = - // make_generic_binary_series_split( - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_sp_leaf(n2)), - // make_generic_binary_sp_leaf(n3)); - // - // bool result = is_binary_sp_tree_right_associative(input); - // bool correct = false; - // - // CHECK(result == correct); - // } - // - // SUBCASE("just parallel") { - // LeafOnlyBinarySPDecompositionTree input = - // make_generic_binary_parallel_split( - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(n1), - // make_generic_binary_sp_leaf(n2)), - // make_generic_binary_sp_leaf(n3)); - // - // bool result = is_binary_sp_tree_right_associative(input); - // bool correct = false; - // - // CHECK(result == correct); - // } - // } - } -} diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index b9021a19ef..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1,28 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { - FAIL("TODO"); - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split( - // make_generic_binary_series_split(make_generic_binary_sp_leaf(1), - // make_generic_binary_sp_leaf(4)), - // make_generic_binary_sp_leaf(2)); - // - // GenericBinarySPDecompositionTree result = - // transform(input, [](int x) { return std::to_string(x); }); - // - // GenericBinarySPDecompositionTree correct = - // make_generic_binary_parallel_split( - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(std::string{"1"}), - // make_generic_binary_sp_leaf(std::string{"4"})), - // make_generic_binary_sp_leaf(std::string{"2"})); - // - // CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..3b1e1899ca --- /dev/null +++ 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,114 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_left_associative") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + auto make_leaf = [](int n) { + return leaf_only_binary_sp_tree_make_leaf(n); + }; + + auto make_series_split = [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_series_split(l, r); + }; + + auto make_parallel_split = [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_parallel_split(l, r); + }; + + SUBCASE("input is actually left associative") { + SUBCASE("just node") { + LeafOnlyBinarySPDecompositionTree input = + make_leaf(n1); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + LeafOnlyBinarySPDecompositionTree input = + make_series_split( + make_series_split( + make_leaf(n1), + make_leaf(n2)), + make_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + LeafOnlyBinarySPDecompositionTree input = + make_parallel_split( + make_parallel_split( + make_leaf(n1), + make_leaf(n2)), + make_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + LeafOnlyBinarySPDecompositionTree input = + 
make_series_split( + make_parallel_split( + make_leaf(n1), + make_leaf(n2)), + make_parallel_split( + make_leaf(n3), + make_leaf(n4))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not left associative") { + SUBCASE("just series") { + LeafOnlyBinarySPDecompositionTree input = + make_series_split( + make_leaf(n1), + make_series_split( + make_leaf(n2), + make_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + LeafOnlyBinarySPDecompositionTree input = + make_parallel_split( + make_leaf(n1), + make_parallel_split( + make_leaf(n2), + make_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..9e34a769ba --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,114 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_right_associative") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + auto make_leaf = [](int n) { + return leaf_only_binary_sp_tree_make_leaf(n); + }; + + auto 
make_series_split = [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_series_split(l, r); + }; + + auto make_parallel_split = [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_parallel_split(l, r); + }; + + SUBCASE("input is actually right associative") { + SUBCASE("just node") { + LeafOnlyBinarySPDecompositionTree input = + make_leaf(n1); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + LeafOnlyBinarySPDecompositionTree input = + make_series_split( + make_leaf(n1), + make_series_split( + make_leaf(n2), + make_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + LeafOnlyBinarySPDecompositionTree input = + make_parallel_split( + make_leaf(n1), + make_parallel_split( + make_leaf(n2), + make_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + LeafOnlyBinarySPDecompositionTree input = + make_series_split( + make_parallel_split( + make_leaf(n1), + make_leaf(n2)), + make_parallel_split( + make_leaf(n3), + make_leaf(n4))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not right associative") { + SUBCASE("just series") { + LeafOnlyBinarySPDecompositionTree input = + make_series_split( + make_series_split( + make_leaf(n1), + make_leaf(n2)), + make_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + LeafOnlyBinarySPDecompositionTree input = + make_parallel_split( + make_parallel_split( + 
make_leaf(n1), + make_leaf(n2)), + make_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} From 3d08831e7eb477894e19d0959641107924a5bb62 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 5 Oct 2024 11:32:47 -0700 Subject: [PATCH 24/29] Fix tests and format --- .proj.toml | 8 +- .../compiler/cost_estimator/cost_estimator.h | 2 +- .../abstracted_tensor_set_movement.h | 13 +- ...tracted_tensor_set_movement_across_split.h | 7 +- .../get_machine_resource_splits.h | 2 +- .../get_optimal_machine_mapping.h | 42 +- .../get_tensor_set_movement_across_split.h | 9 +- .../machine_mapping/machine_mapping_cache.h | 8 +- .../machine_mapping_constraints.h | 30 +- .../get_machine_mapping_problem_tree.h | 5 +- .../machine_mapping_problem_tree.h | 38 +- .../mm_problem_tree_series_split.h | 3 +- .../unmapped_op_cost_estimate_key.h | 9 +- .../machine_mapping/machine_mapping_result.h | 33 +- ...lel_layer_guid_oblivious_machine_mapping.h | 17 +- .../machine_mapping/transitive_reduced_pcg.h | 26 +- ...get_pcg_balanced_binary_sp_decomposition.h | 3 +- .../get_pcg_series_parallel_decomposition.h | 3 +- .../pcg_binary_sp_decomposition.h | 17 +- .../include/compiler/unity_algorithm.h | 2 +- .../abstracted_tensor_set_movement.cc | 61 +-- ...racted_tensor_set_movement_across_split.cc | 65 +-- .../get_machine_resource_splits.cc | 3 +- .../get_optimal_machine_mapping.cc | 315 +++++++------- .../get_tensor_set_movement_across_split.cc | 18 +- .../machine_mapping/machine_mapping_cache.cc | 15 +- .../machine_mapping_constraints.cc | 100 +++-- .../get_machine_mapping_problem_tree.cc | 50 ++- .../machine_mapping_problem_tree.cc | 85 ++-- .../mm_problem_tree_parallel_split.cc | 5 +- .../mm_problem_tree_series_split.cc | 7 +- .../unmapped_op_cost_estimate_key.cc | 33 +- .../machine_mapping/machine_mapping_result.cc | 70 +-- ...el_layer_guid_oblivious_machine_mapping.cc | 16 +- 
.../machine_mapping/transitive_reduced_pcg.cc | 72 ++-- ...mputation_graph_binary_sp_decomposition.cc | 12 +- .../get_pcg_series_parallel_decomposition.cc | 3 +- .../pcg_binary_parallel_split.cc | 6 +- .../pcg_binary_series_split.cc | 14 +- .../pcg_binary_sp_decomposition.cc | 38 +- lib/compiler/src/graph_optimize_state.cc | 85 ---- lib/compiler/src/unity_algorithm.cc | 6 +- ...racted_tensor_set_movement_across_split.cc | 407 ++++++++---------- .../cost_estimator_for_test.cc | 34 +- .../machine_mapping/cost_estimator_for_test.h | 12 +- .../get_machine_resource_splits.cc | 338 ++++++++------- .../get_optimal_machine_mapping.cc | 255 +++++------ .../get_tensor_set_movement_across_split.cc | 28 +- .../get_machine_mapping_problem_tree.cc | 243 ++++++----- .../machine_mapping/machine_mapping_result.cc | 330 +++++++------- lib/kernels/include/kernels/accessor.h | 12 +- .../include/kernels/datatype_dispatch.h | 2 +- .../include/local-execution/device_specific.h | 9 +- .../include/local-execution/permissions.h | 4 +- .../src/local_task_argument_accessor.cc | 4 +- lib/local-execution/src/permissions.cc | 3 +- .../parallel_computation_graph.h | 9 +- lib/pcg/src/pcg/computation_graph.cc | 22 +- lib/pcg/src/pcg/computation_graph_builder.cc | 11 +- .../parallel_computation_graph.cc | 62 +-- .../parallel_computation_graph_builder.cc | 25 +- .../parallel_computation_graph.cc | 36 +- .../perform_shape_inference.cc | 3 + .../utils/any_value_type/any_value_type.h | 39 +- lib/utils/include/utils/containers/flatmap.h | 16 +- .../utils/containers/get_all_assignments.h | 28 +- lib/utils/include/utils/containers/get_only.h | 4 +- .../include/utils/containers/merge_maps.h | 4 +- .../include/utils/containers/transform.h | 19 +- lib/utils/include/utils/containers/try_at.h | 4 +- .../containers/unordered_map_from_pairs.h | 7 +- lib/utils/include/utils/fmt/monostate.h | 2 +- .../full_binary_tree/find_paths_to_leaf.h | 44 +- .../include/utils/full_binary_tree/fmt.h | 26 +- 
.../utils/full_binary_tree/full_binary_tree.h | 38 +- .../full_binary_tree/get_all_leaf_paths.h | 44 +- .../utils/full_binary_tree/get_child.h | 12 +- .../utils/full_binary_tree/get_label.h | 3 +- .../utils/full_binary_tree/get_leaves.h | 23 +- .../utils/full_binary_tree/get_left_child.h | 3 +- .../utils/full_binary_tree/get_node_type.h | 8 +- .../utils/full_binary_tree/get_right_child.h | 3 +- .../full_binary_tree/get_subtree_at_path.h | 24 +- .../include/utils/full_binary_tree/hash.h | 9 +- .../include/utils/full_binary_tree/json.h | 27 +- .../include/utils/full_binary_tree/make.h | 22 +- .../include/utils/full_binary_tree/require.h | 7 +- .../utils/full_binary_tree/transform.h | 51 +-- .../include/utils/full_binary_tree/visit.h | 15 +- .../get_dataflow_edges_from_node_to_node.h | 5 +- ...nsitive_reduced_boundary_nodes_for_split.h | 6 +- ...et_transitive_reduced_edges_across_split.h | 5 +- ..._transitive_reduced_outputs_across_split.h | 5 +- .../transitive_reduced_dataflow_graph.h | 3 +- .../get_edges_from_subgraph_to_subgraph.h | 7 +- ...azy_copy_of_labelled_dataflow_graph_view.h | 6 +- .../algorithms/rewrite_node_labels.h | 3 +- .../binary_parallel_split.h | 2 +- .../binary_series_split.h | 2 +- .../binary_sp_decomposition_tree.h | 7 +- .../find_paths_to_leaf.h | 7 +- .../generic_binary_sp_split_label.h | 43 +- .../get_all_leaf_paths.h | 5 +- .../get_leaves.h | 23 +- .../get_left_child.h | 12 +- .../get_node_type.h | 25 +- .../get_num_tree_nodes.h | 19 +- .../get_right_child.h | 12 +- .../get_subtree_at_path.h | 25 +- .../generic_binary_sp_decomposition_tree/is.h | 12 +- .../is_binary_sp_tree_left_associative.h | 12 +- .../is_binary_sp_tree_right_associative.h | 12 +- .../make.h | 65 ++- .../require.h | 65 ++- .../transform.h | 87 ++-- .../visit.h | 14 +- .../wrap.h | 39 +- .../find_paths_to_leaf.h | 6 +- .../get_leaves.h | 5 +- .../get_node_type.h | 3 +- .../is_binary_sp_tree_left_associative.h | 5 +- .../is_binary_sp_tree_right_associative.h | 5 +- 
.../leaf_only_binary_parallel_split.h | 6 +- .../leaf_only_binary_series_split.h | 6 +- .../make.h | 33 +- .../require.h | 38 +- .../transform.h | 58 ++- .../wrap.h | 38 +- .../utils/any_value_type/any_value_type.cc | 17 +- .../full_binary_tree/binary_tree_path.cc | 8 +- lib/utils/src/utils/full_binary_tree/fmt.cc | 6 +- .../src/utils/full_binary_tree/get_label.cc | 4 +- lib/utils/src/utils/full_binary_tree/make.cc | 11 +- .../src/utils/full_binary_tree/require.cc | 7 +- lib/utils/src/utils/full_binary_tree/visit.cc | 4 +- .../get_dataflow_edges_from_node_to_node.cc | 13 +- ...sitive_reduced_boundary_nodes_for_split.cc | 22 +- ...t_transitive_reduced_edges_across_split.cc | 30 +- ...transitive_reduced_outputs_across_split.cc | 6 +- .../transitive_reduced_dataflow_graph.cc | 7 +- .../get_edges_from_subgraph_to_subgraph.cc | 18 +- .../algorithms/get_subgraph_inputs.cc | 2 +- .../binary_parallel_split.cc | 4 +- .../binary_series_split.cc | 4 +- .../binary_sp_decomposition_tree.cc | 7 +- .../find_paths_to_leaf.cc | 6 +- .../generic_binary_sp_split_label.cc | 20 +- .../get_all_leaf_paths.cc | 4 +- .../get_leaves.cc | 13 +- .../get_left_child.cc | 10 +- .../get_node_type.cc | 5 +- .../get_num_tree_nodes.cc | 12 +- .../get_right_child.cc | 10 +- .../get_subtree_at_path.cc | 3 +- .../is.cc | 11 +- .../is_binary_sp_tree_left_associative.cc | 5 +- .../is_binary_sp_tree_right_associative.cc | 5 +- .../make.cc | 24 +- .../require.cc | 16 +- .../transform.cc | 12 +- .../wrap.cc | 6 +- .../get_leaves.cc | 4 +- .../get_node_type.cc | 4 +- .../is_binary_sp_tree_left_associative.cc | 5 +- .../is_binary_sp_tree_right_associative.cc | 4 +- .../leaf_only_binary_parallel_split.cc | 8 +- .../leaf_only_binary_series_split.cc | 8 +- .../make.cc | 18 +- .../require.cc | 10 +- .../transform.cc | 18 +- .../wrap.cc | 8 +- ...ft_associative_binary_sp_tree_from_nary.cc | 11 +- ...ht_associative_binary_sp_tree_from_nary.cc | 4 +- .../intermediate_sp_decomposition_tree.cc | 2 +- 
.../test/src/utils/containers/flatmap.cc | 62 +-- .../utils/containers/get_all_assignments.cc | 33 +- lib/utils/test/src/utils/containers/try_at.cc | 7 +- .../containers/unordered_map_from_pairs.cc | 40 +- .../get_dataflow_edges_from_node_to_node.cc | 49 ++- ...sitive_reduced_boundary_nodes_for_split.cc | 27 +- ...t_transitive_reduced_edges_across_split.cc | 118 +++-- ...transitive_reduced_outputs_across_split.cc | 23 +- .../get_edges_from_subgraph_to_subgraph.cc | 69 +-- .../is_binary_sp_tree_left_associative.cc | 68 ++- .../is_binary_sp_tree_right_associative.cc | 68 ++- 185 files changed, 2836 insertions(+), 2582 deletions(-) delete mode 100644 lib/compiler/src/graph_optimize_state.cc diff --git a/.proj.toml b/.proj.toml index 22649424f8..5592f184ad 100644 --- a/.proj.toml +++ b/.proj.toml @@ -22,11 +22,11 @@ test_targets = [ "utils-tests", "op-attrs-tests", "pcg-tests", - # "substitutions-tests", + "substitutions-tests", "compiler-tests", - # "substitution-generator-tests", - # "local-execution-tests", - # "models-tests", + "substitution-generator-tests", + "local-execution-tests", + "models-tests", ] [cmake_flags_extra] diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h index 7d3aa6bb9f..65bae0c76a 100644 --- a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H -#include #include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" #include "compiler/cost_estimator/tensor_set_movement.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/machine_view.dtg.h" +#include namespace FlexFlow { diff --git 
a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h index 7a32b7a694..5b7e2f3613 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -10,12 +10,15 @@ namespace FlexFlow { AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); -std::unordered_set get_src_layers(AbstractedTensorSetMovement const &); -std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &); +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &); -TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &, - ParallelLayerGuidObliviousMachineMapping const &pre, - ParallelLayerGuidObliviousMachineMapping const &post); +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &, + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h index 3a34e956ad..8390c5b9cb 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -1,14 +1,15 @@ #ifndef 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" -#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" namespace FlexFlow { -AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg, - PCGBinarySeriesSplit const &split); +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h index 2800c0a353..990c1c8205 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h +++ b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H #include "pcg/machine_specification.dtg.h" -#include #include +#include namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index fc33845320..62da90bfcb 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -8,25 +8,25 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" 
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" -#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" #include "pcg/machine_specification.dtg.h" namespace FlexFlow { -MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, - MachineMappingContext const &context, - MachineMappingProblemTree const &problem_tree, - MachineSpecification const &resources, - MachineMappingConstraints const &constraints); - -MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, - MachineMappingContext const &context, - MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resources, - MachineMappingConstraints const &constraints, - std::optional const &parallel_split_transformation); +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + &parallel_split_transformation); MachineMappingResult get_optimal_machine_mapping( MachineMappingCache &result_cache, @@ -35,12 +35,12 @@ MachineMappingResult get_optimal_machine_mapping( MachineSpecification const &resources, MachineMappingConstraints const &constraints); -MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, - MachineMappingContext const &, - UnmappedOpCostEstimateKey const &leaf, - MachineSpecification const &resources, - MachineMappingConstraints const &constraints); +MachineMappingResult + 
get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h index 770bfe982d..ee3d2bf159 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -8,10 +8,11 @@ namespace FlexFlow { -TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg, - PCGBinarySeriesSplit const &split, - ParallelLayerGuidObliviousMachineMapping const &pre_mapping, - ParallelLayerGuidObliviousMachineMapping const &post_mapping); +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index 20cf75e69a..3a0fcf0e15 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -6,8 +6,12 @@ namespace FlexFlow { MachineMappingCache empty_machine_mapping_cache(); -std::optional machine_mapping_cache_load(MachineMappingCache const &, MachineMappingState const &); -void machine_mapping_cache_save(MachineMappingCache &, MachineMappingState const &, MachineMappingResult const &); +std::optional + machine_mapping_cache_load(MachineMappingCache const &, + 
MachineMappingState const &); +void machine_mapping_cache_save(MachineMappingCache &, + MachineMappingState const &, + MachineMappingResult const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h index 6e46b49c69..87c556910f 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -1,30 +1,36 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H +#include "compiler/machine_mapping/include_unconstrained.dtg.h" #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" -#include "compiler/machine_mapping/include_unconstrained.dtg.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" #include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" namespace FlexFlow { -MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &); +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &); -std::unordered_set get_all_layers(MachineMappingConstraints const &, - IncludeUnconstrained const &); +std::unordered_set + get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); -std::optional get_machine_view_for_layer(MachineMappingConstraints const &, - BinaryTreePath const &); +std::optional + get_machine_view_for_layer(MachineMappingConstraints const &, + BinaryTreePath const &); -MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &, BinaryTreePathEntry const &); 
-MachineMappingConstraints restrict_to_left_child(MachineMappingConstraints const &); -MachineMappingConstraints restrict_to_right_child(MachineMappingConstraints const &); - +MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &, + BinaryTreePathEntry const &); +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &); +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &); -MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &, - ParallelLayerGuidObliviousMachineMapping const &); +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &, + ParallelLayerGuidObliviousMachineMapping const &); std::optional require_only_root(MachineMappingConstraints const &); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h index b5ab1988ad..2635f4a318 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -9,8 +9,9 @@ namespace FlexFlow { -MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, - PCGBinarySPDecomposition const &sp); +MachineMappingProblemTree + get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 13a7358a6e..4064b2f0c9 100644 --- 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -2,36 +2,43 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" namespace FlexFlow { +MachineMappingProblemTree mm_problem_tree_make_series_split( + AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &pre, + MachineMappingProblemTree const &post); MachineMappingProblemTree - mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &pre, - MachineMappingProblemTree const &post); + mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs); MachineMappingProblemTree - mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs); -MachineMappingProblemTree mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &); + mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &); SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); -MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &); -MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &); +MMProblemTreeSeriesSplit + 
require_series_split(MachineMappingProblemTree const &); +MMProblemTreeParallelSplit + require_parallel_split(MachineMappingProblemTree const &); UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &); MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &); -MachineMappingProblemTree wrap_parallel_split(MMProblemTreeParallelSplit const &); +MachineMappingProblemTree + wrap_parallel_split(MMProblemTreeParallelSplit const &); -std::unordered_multiset get_leaves(MachineMappingProblemTree const &); -std::unordered_set get_all_leaf_paths(MachineMappingProblemTree const &); +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &); +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &); -std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, - BinaryTreePath const &); +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, + BinaryTreePath const &); template Result visit(MachineMappingProblemTree const &t, F &&f) { @@ -50,7 +57,8 @@ Result visit(MachineMappingProblemTree const &t, F &&f) { return result; } default: - throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); + throw mk_runtime_error( + fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); } } diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h index 8332da66f9..a7faced4d8 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h @@ -8,7 +8,8 @@ namespace FlexFlow { MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &); MachineMappingProblemTree 
get_post_child(MMProblemTreeSeriesSplit const &); -AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &); +AbstractedTensorSetMovement const & + get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h index e90dcdd94c..9fbad4a1d0 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h @@ -8,11 +8,12 @@ namespace FlexFlow { -UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer(ParallelComputationGraph const &, - parallel_layer_guid_t const &); +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &, parallel_layer_guid_t const &); -OpCostEstimateKey map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, - MachineView const &machine_view); +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 225c8c6f5c..b21fea5f24 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -10,19 +10,26 @@ namespace FlexFlow { [[nodiscard]] bool is_infeasible(MachineMappingResult const &); FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); -[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime(std::unordered_set 
const &); - -[[nodiscard]] MachineMappingResult series_combine(float comm_cost, - MachineMappingResult const &pre_result, - MachineMappingResult const &post_result, - std::optional const &parallel_split_transformation); -[[nodiscard]] MachineMappingResult parallel_combine(MachineMappingResult const &lhs_result, - MachineMappingResult const &rhs_result); - -[[nodiscard]] MachineMappingResult minimize_runtime(MachineMappingResult const &m1, MachineMappingResult const &m2); - -[[nodiscard]] MachineMappingResult make_singleton_machine_mapping_result(float runtime, - MachineView const &machine_view); +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &); + +[[nodiscard]] MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &pre_result, + MachineMappingResult const &post_result, + std::optional const + &parallel_split_transformation); +[[nodiscard]] MachineMappingResult + parallel_combine(MachineMappingResult const &lhs_result, + MachineMappingResult const &rhs_result); + +[[nodiscard]] MachineMappingResult + minimize_runtime(MachineMappingResult const &m1, + MachineMappingResult const &m2); + +[[nodiscard]] MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h index 23c589a261..accd96af4c 100644 --- a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -5,15 +5,18 @@ namespace FlexFlow { -ParallelLayerGuidObliviousMachineMapping - binary_combine_mappings(ParallelLayerGuidObliviousMachineMapping const &pre, - ParallelLayerGuidObliviousMachineMapping const &post); 
+ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); -ParallelLayerGuidObliviousMachineMapping restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &); -ParallelLayerGuidObliviousMachineMapping restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &); +ParallelLayerGuidObliviousMachineMapping + restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &); +ParallelLayerGuidObliviousMachineMapping + restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &); -std::optional get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, - BinaryTreePath const &); +std::optional + get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, + BinaryTreePath const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h index 3545c4fa63..60c47ba049 100644 --- a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H -#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" #include "compiler/series_parallel/pcg_binary_series_split.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" @@ -11,21 +11,23 @@ namespace FlexFlow { -TransitiveReducedDataflowGraphView 
get_underlying_transitive_reduced_dataflow_graph(TransitiveReducedPCG const &); - -TransitiveReducedPCG pcg_get_transitive_reduction(ParallelComputationGraph const &); +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &); -std::unordered_set - pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, - PCGBinarySeriesSplit const &); +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &); -std::unordered_set - pcg_get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &, - PCGBinarySeriesSplit const &); +std::unordered_set + pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, + PCGBinarySeriesSplit const &); -PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split(TransitiveReducedPCG const &, - PCGBinarySeriesSplit const &); +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h index 65a7a69ef8..d43edaa79d 100644 --- a/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h @@ -4,8 +4,7 @@ namespace FlexFlow { std::optional - get_pcg_balanced_binary_sp_decomposition( - ParallelComputationGraph const &); + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h 
b/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h index 04f84d76fd..d4ae77541a 100644 --- a/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h @@ -7,8 +7,7 @@ namespace FlexFlow { std::optional - get_pcg_series_parallel_decomposition( - ParallelComputationGraph const &); + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h index 2744393fc2..e2e170b4d5 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h @@ -12,14 +12,18 @@ namespace FlexFlow { std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); std::unordered_multiset - get_parallel_layers(PCGBinarySPDecomposition const &); + get_parallel_layers(PCGBinarySPDecomposition const &); SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); -PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); -PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &, PCGBinarySPDecomposition const &); +PCGBinarySPDecomposition + make_pcg_series_split(PCGBinarySPDecomposition const &, + PCGBinarySPDecomposition const &); +PCGBinarySPDecomposition + make_pcg_parallel_split(PCGBinarySPDecomposition const &, + PCGBinarySPDecomposition const &); PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &); PCGBinarySPDecomposition wrap_series_split(PCGBinarySeriesSplit const &); @@ -29,8 +33,9 @@ PCGBinarySeriesSplit 
require_series(PCGBinarySPDecomposition const &); PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &); parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &); -std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &, - parallel_layer_guid_t const &); +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &, + parallel_layer_guid_t const &); template ReturnType visit(PCGBinarySPDecomposition const &d, F &&f) { diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 7ac9759650..232f2b9563 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H -#include "compiler/graph_optimize_result.dtg.h" #include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/graph_optimize_result.dtg.h" #include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc index 8fc7239a45..6f3deca138 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -1,8 +1,8 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "utils/containers/flatmap.h" -#include "utils/containers/unordered_set_of.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" namespace FlexFlow { @@ -10,43 
+10,52 @@ AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { return AbstractedTensorSetMovement{{}}; } -std::unordered_set get_src_layers(AbstractedTensorSetMovement const &m) { +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &m) { return flatmap(unordered_set_of(m.single_tensor_movements), - [](AbstractedSingleTensorMovement const &s) { + [](AbstractedSingleTensorMovement const &s) { return s.src_machine_views; }); } -std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &m) { +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &m) { return flatmap(unordered_set_of(m.single_tensor_movements), - [](AbstractedSingleTensorMovement const &s) { + [](AbstractedSingleTensorMovement const &s) { return s.dst_machine_views; }); } -TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &abstracted, - ParallelLayerGuidObliviousMachineMapping const &pre_mapping, - ParallelLayerGuidObliviousMachineMapping const &post_mapping) { - ParallelLayerGuidObliviousMachineMapping mapping = - binary_combine_mappings(/*lhs=*/pre_mapping, - /*rhs=*/post_mapping); - - auto concretize_tensor_movement = [&](AbstractedSingleTensorMovement const &a) { - return SingleTensorMovement{ - /*parallel_tensor_shape=*/a.parallel_tensor_shape, - /*src_machine_views=*/transform(a.src_machine_views, - [&](BinaryTreePath const &path) { - return get_machine_view_for_path(pre_mapping, path).value(); - }), - /*dst_machine_views=*/transform(a.dst_machine_views, - [&](BinaryTreePath const &path) { - return get_machine_view_for_path(post_mapping, path).value(); - }), - }; - }; +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &abstracted, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping mapping = + 
binary_combine_mappings(/*lhs=*/pre_mapping, + /*rhs=*/post_mapping); + + auto concretize_tensor_movement = + [&](AbstractedSingleTensorMovement const &a) { + return SingleTensorMovement{ + /*parallel_tensor_shape=*/a.parallel_tensor_shape, + /*src_machine_views=*/ + transform( + a.src_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(pre_mapping, path).value(); + }), + /*dst_machine_views=*/ + transform( + a.dst_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(post_mapping, path).value(); + }), + }; + }; return TensorSetMovement{ - /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, concretize_tensor_movement), + /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, + concretize_tensor_movement), }; } diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 6ec7a545b5..56fafbc5e3 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,7 +1,7 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "compiler/series_parallel/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include 
"pcg/parallel_computation_graph/parallel_computation_graph_edge.h" @@ -11,47 +11,52 @@ namespace FlexFlow { -AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { - std::unordered_set - edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); + std::unordered_set edges_across_split = + pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { - std::unordered_set tensor_edges = filter(edges_across_split, - [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); + std::unordered_set tensor_edges = + filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { + return get_parallel_tensor(e) == t; + }); - std::unordered_set src_layers = - transform(tensor_edges, - [&](ParallelComputationGraphEdge const &e) { - return get_src_layer(e); - }); + std::unordered_set src_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_src_layer(e); + }); - std::unordered_set dst_layers = - transform(tensor_edges, - [&](ParallelComputationGraphEdge const &e) { - return get_dst_layer(e); - }); + std::unordered_set dst_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_dst_layer(e); + }); return AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), - /*src_machine_views=*/transform(src_layers, - [&](parallel_layer_guid_t const &l) { - return get_only(find_paths_to_leaf(get_left_child(split), l)); - }), - /*dst_machine_views=*/transform(dst_layers, - [&](parallel_layer_guid_t const &l) { - return get_only(find_paths_to_leaf(get_right_child(split), l)); - }), + 
/*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + /*src_machine_views=*/ + transform(src_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(get_left_child(split), l)); + }), + /*dst_machine_views=*/ + transform(dst_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(get_right_child(split), l)); + }), }; }; - std::unordered_map single_tensor_movements = - generate_map(pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), - get_movement_for_tensor); + std::unordered_map + single_tensor_movements = generate_map( + pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); return AbstractedTensorSetMovement{ - values(single_tensor_movements), + values(single_tensor_movements), }; } diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc index 51ed1f7ff4..5126d9687e 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -5,7 +5,8 @@ namespace FlexFlow { std::unordered_set> get_machine_resource_splits(MachineSpecification const &resource) { - std::unordered_set> result; + std::unordered_set> + result; for (int i = 1; i < resource.num_nodes; i *= 2) { MachineSpecification sub_resource1 = resource; diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 0adf43681e..20da56eb55 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -17,148 +17,162 @@ #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" +#include 
"utils/containers/flatmap.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" #include "utils/overload.h" -#include "utils/containers/flatmap.h" - namespace FlexFlow { -MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, - MachineMappingContext const &context, - MachineMappingProblemTree const &problem_tree, - MachineSpecification const &resources, - MachineMappingConstraints const &constraints) { +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { MachineMappingState state = MachineMappingState{ - problem_tree, resources, constraints, + problem_tree, + resources, + constraints, }; { std::optional cached_result = - machine_mapping_cache_load(result_cache, state); + machine_mapping_cache_load(result_cache, state); if (cached_result) { return cached_result.value(); } } MachineMappingResult result = visit( - problem_tree, - overload { - [&](MMProblemTreeSeriesSplit const &series_split) { - return get_optimal_machine_mapping - (result_cache, - context, - series_split, - resources, - constraints, - /*parallel_split_transformation=*/std::nullopt); - }, - [&](auto const &decomp_tree_node) { - return get_optimal_machine_mapping - (result_cache, - context, - decomp_tree_node, - resources, - constraints); - }, - }); + problem_tree, + overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + return get_optimal_machine_mapping( + result_cache, + context, + series_split, + resources, + constraints, + /*parallel_split_transformation=*/std::nullopt); + }, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping(result_cache, + context, + decomp_tree_node, + resources, + constraints); + }, + 
}); machine_mapping_cache_save(result_cache, state, result); return result; } -MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, - MachineMappingContext const &context, - MMProblemTreeSeriesSplit const &series_split, - MachineSpecification const &resources, - MachineMappingConstraints const &constraints, - std::optional const ¶llel_split_transformation) { - - auto get_boundary_machine_view_assignments = [&](std::unordered_set const &boundary_layers) - -> std::unordered_set - { +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation) { + + auto get_boundary_machine_view_assignments = + [&](std::unordered_set const &boundary_layers) + -> std::unordered_set { std::unordered_map> - allowed = generate_map(boundary_layers, - [&](BinaryTreePath const &l) -> std::unordered_set { - UnmappedOpCostEstimateKey leaf = require_leaf - (mm_problem_tree_get_subtree_at_path - (wrap_series_split(series_split), l).value()); - return context.allowed_machine_views(leaf, resources); - }); - return transform(get_all_assignments(allowed), - [](std::unordered_map const &m) { - return ParallelLayerGuidObliviousMachineMapping{m}; - }); + allowed = generate_map( + boundary_layers, + [&](BinaryTreePath const &l) -> std::unordered_set { + UnmappedOpCostEstimateKey leaf = + require_leaf(mm_problem_tree_get_subtree_at_path( + wrap_series_split(series_split), l) + .value()); + return context.allowed_machine_views(leaf, resources); + }); + return transform( + get_all_assignments(allowed), + [](std::unordered_map const &m) { + return ParallelLayerGuidObliviousMachineMapping{m}; + }); }; - auto eval_pre_boundary_mapping = [&](ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views) { - 
MachineMappingConstraints pre_candidate = - with_additional_constraints( - restrict_to_left_child(constraints), - assigned_pre_machine_views); - - MachineMappingResult pre_result = - get_optimal_machine_mapping(result_cache, - context, - get_pre_child(series_split), - resources, - pre_candidate); - - return pre_result; - }; - - auto eval_post_boundary_mapping = [&](ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views) { - MachineMappingConstraints post_candidate = - with_additional_constraints( - restrict_to_right_child(constraints), - assigned_post_machine_views); - - MachineMappingResult post_result - = get_optimal_machine_mapping(result_cache, - context, - get_post_child(series_split), - resources, - post_candidate); - - return post_result; - }; + auto eval_pre_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views) { + MachineMappingConstraints pre_candidate = with_additional_constraints( + restrict_to_left_child(constraints), assigned_pre_machine_views); + + MachineMappingResult pre_result = + get_optimal_machine_mapping(result_cache, + context, + get_pre_child(series_split), + resources, + pre_candidate); + + return pre_result; + }; + + auto eval_post_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views) { + MachineMappingConstraints post_candidate = with_additional_constraints( + restrict_to_right_child(constraints), assigned_post_machine_views); + + MachineMappingResult post_result = + get_optimal_machine_mapping(result_cache, + context, + get_post_child(series_split), + resources, + post_candidate); + + return post_result; + }; MachineMappingResult result = infeasible_machine_mapping_result(); - AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); - - for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views - : 
get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { - - MachineMappingResult pre_result = eval_pre_boundary_mapping(assigned_pre_machine_views); - - for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views - : get_boundary_machine_view_assignments(get_dst_layers(tensor_movement))) { - - MachineMappingResult post_result = eval_post_boundary_mapping(assigned_post_machine_views); - - TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement(tensor_movement, - /*pre_mapping=*/assigned_pre_machine_views, - /*post_mapping=*/assigned_post_machine_views); - float cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); - - result = minimize_runtime(result, - series_combine(cost_across_split, pre_result, post_result, parallel_split_transformation)); + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_movement(series_split); + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views : + get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + + MachineMappingResult pre_result = + eval_pre_boundary_mapping(assigned_pre_machine_views); + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views : + get_boundary_machine_view_assignments( + get_dst_layers(tensor_movement))) { + + MachineMappingResult post_result = + eval_post_boundary_mapping(assigned_post_machine_views); + + TensorSetMovement comm_across_split = + concretize_abstracted_tensor_set_movement( + tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); + float cost_across_split = + context.cost_estimator.estimate_cost(comm_across_split); + + result = minimize_runtime(result, + series_combine(cost_across_split, + pre_result, + post_result, + parallel_split_transformation)); } } return result; } - - MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, + 
MachineMappingCache &result_cache, MachineMappingContext const &context, MMProblemTreeParallelSplit const ¶llel_split, MachineSpecification const &resources, @@ -168,53 +182,53 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingProblemTree rhs = get_rhs_child(parallel_split); MachineMappingResult series_result = [&] { - MMProblemTreeSeriesSplit series_split = require_series_split(\ - mm_problem_tree_make_series_split( - /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), - /*pre=*/lhs, - /*post=*/rhs)); - + MMProblemTreeSeriesSplit series_split = + require_series_split(mm_problem_tree_make_series_split( + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*pre=*/lhs, + /*post=*/rhs)); + return get_optimal_machine_mapping(result_cache, - context, - series_split, - resources, - constraints, - ParallelSplitTransformation::LthenR); + context, + series_split, + resources, + constraints, + ParallelSplitTransformation::LthenR); }(); - MachineMappingConstraints left_constraints = restrict_to_left_child(constraints); - MachineMappingConstraints right_constraints = restrict_to_right_child(constraints); - - auto evaluate_resource_split = [&](std::pair const &resource_split) { - MachineMappingResult left_result = - get_optimal_machine_mapping(result_cache, - context, - lhs, - resource_split.first, - left_constraints); - MachineMappingResult right_result = - get_optimal_machine_mapping(result_cache, - context, - rhs, - resource_split.second, - right_constraints); - - return parallel_combine(left_result, right_result); - }; - - std::unordered_set parallel_results - = transform(get_machine_resource_splits(resources), - evaluate_resource_split); - - return minimize_runtime(series_result, get_mapping_with_minimal_runtime(parallel_results)); + MachineMappingConstraints left_constraints = + restrict_to_left_child(constraints); + MachineMappingConstraints right_constraints = + restrict_to_right_child(constraints); + + auto 
evaluate_resource_split = + [&](std::pair const + &resource_split) { + MachineMappingResult left_result = get_optimal_machine_mapping( + result_cache, context, lhs, resource_split.first, left_constraints); + MachineMappingResult right_result = + get_optimal_machine_mapping(result_cache, + context, + rhs, + resource_split.second, + right_constraints); + + return parallel_combine(left_result, right_result); + }; + + std::unordered_set parallel_results = transform( + get_machine_resource_splits(resources), evaluate_resource_split); + + return minimize_runtime(series_result, + get_mapping_with_minimal_runtime(parallel_results)); } -MachineMappingResult get_optimal_machine_mapping( - MachineMappingCache &result_cache, - MachineMappingContext const &context, - UnmappedOpCostEstimateKey const &leaf, - MachineSpecification const &resource, - MachineMappingConstraints const &constraints) { +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resource, + MachineMappingConstraints const &constraints) { std::unordered_set candidates = [&] { std::optional machine_view = require_only_root(constraints); @@ -226,14 +240,15 @@ MachineMappingResult get_optimal_machine_mapping( }(); auto get_mapping_result = [&](MachineView const &machine_view) { - OpCostEstimateKey mapped = map_unmapped_op_cost_estimate_key(leaf, machine_view); + OpCostEstimateKey mapped = + map_unmapped_op_cost_estimate_key(leaf, machine_view); float cost = context.cost_estimator.estimate_cost(mapped); return make_singleton_machine_mapping_result(cost, machine_view); }; - std::unordered_set candidate_results - = transform(candidates, get_mapping_result); + std::unordered_set candidate_results = + transform(candidates, get_mapping_result); return get_mapping_with_minimal_runtime(candidate_results); } diff --git 
a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 2979947c7c..6cc3f4329c 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -12,15 +12,15 @@ namespace FlexFlow { -TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split, - ParallelLayerGuidObliviousMachineMapping const &pre_mapping, - ParallelLayerGuidObliviousMachineMapping const &post_mapping) { - AbstractedTensorSetMovement abstracted = get_abstracted_tensor_set_movement_across_split(tr_pcg, split); - return concretize_abstracted_tensor_set_movement(abstracted, pre_mapping, post_mapping); +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + AbstractedTensorSetMovement abstracted = + get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + return concretize_abstracted_tensor_set_movement( + abstracted, pre_mapping, post_mapping); } - } // namespace FlexFlow - - diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc index c78f7fbf56..fbfccf737f 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -1,6 +1,6 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" -#include "utils/containers/try_at.h" #include "utils/containers/contains_key.h" +#include "utils/containers/try_at.h" namespace FlexFlow { @@ -8,13 +8,20 @@ MachineMappingCache empty_machine_mapping_cache() { return 
MachineMappingCache{{}}; } -std::optional machine_mapping_cache_load(MachineMappingCache const &cache, MachineMappingState const &k) { +std::optional + machine_mapping_cache_load(MachineMappingCache const &cache, + MachineMappingState const &k) { return try_at(cache.raw_map, k); } -void machine_mapping_cache_save(MachineMappingCache &cache, MachineMappingState const &k, MachineMappingResult const &v) { +void machine_mapping_cache_save(MachineMappingCache &cache, + MachineMappingState const &k, + MachineMappingResult const &v) { if (contains_key(cache.raw_map, k)) { - throw mk_runtime_error(fmt::format("machine_mapping_cache_save expected key to not already exist, but received existing key {}", k)); + throw mk_runtime_error( + fmt::format("machine_mapping_cache_save expected key to not already " + "exist, but received existing key {}", + k)); } cache.raw_map.emplace(k, v); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 5d04c99e72..2cee866a01 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -1,80 +1,94 @@ #include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "utils/containers/filter.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/flatmap.h" -#include "utils/containers/filter.h" #include "utils/containers/generate_map.h" #include "utils/containers/keys.h" -#include "utils/containers/restrict_keys.h" #include "utils/containers/map_values.h" +#include "utils/containers/restrict_keys.h" #include "utils/full_binary_tree/binary_tree_path.h" namespace FlexFlow { -MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &layers) { +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &layers) { return 
MachineMappingConstraints{ - generate_map(layers, - [](BinaryTreePath const &) -> std::optional { - return std::nullopt; - }), + generate_map(layers, + [](BinaryTreePath const &) -> std::optional { + return std::nullopt; + }), }; } -std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution, - IncludeUnconstrained const &include_unconstrained) { - std::unordered_set with_unconstrained = keys(partial_solution.machine_views); +std::unordered_set + get_all_layers(MachineMappingConstraints const &partial_solution, + IncludeUnconstrained const &include_unconstrained) { + std::unordered_set with_unconstrained = + keys(partial_solution.machine_views); if (include_unconstrained.raw_bool) { return with_unconstrained; } else { - return filter(with_unconstrained, - [&](BinaryTreePath const &l) { return partial_solution.machine_views.at(l).has_value(); }); + return filter(with_unconstrained, [&](BinaryTreePath const &l) { + return partial_solution.machine_views.at(l).has_value(); + }); } } -std::optional get_machine_view_for_layer(MachineMappingConstraints const &partial_solution, - BinaryTreePath const &layer) { +std::optional get_machine_view_for_layer( + MachineMappingConstraints const &partial_solution, + BinaryTreePath const &layer) { return partial_solution.machine_views.at(layer); } -MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &constraints, - BinaryTreePathEntry const &prefix) { - return MachineMappingConstraints{ - filtermap_keys(constraints.machine_views, - [&](BinaryTreePath const &path) -> std::optional { - BinaryTreePathEntry head = binary_tree_path_get_top_level(path); - - if (head == prefix) { - BinaryTreePath rest = binary_tree_path_get_non_top_level(path); - return rest; - } else { - return std::nullopt; - } - }) - }; +MachineMappingConstraints + restrict_to_child(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { + return MachineMappingConstraints{filtermap_keys( + 
constraints.machine_views, + [&](BinaryTreePath const &path) -> std::optional { + BinaryTreePathEntry head = binary_tree_path_get_top_level(path); + + if (head == prefix) { + BinaryTreePath rest = binary_tree_path_get_non_top_level(path); + return rest; + } else { + return std::nullopt; + } + })}; } -MachineMappingConstraints restrict_to_left_child(MachineMappingConstraints const &c) { +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &c) { return restrict_to_child(c, BinaryTreePathEntry::LEFT_CHILD); } -MachineMappingConstraints restrict_to_right_child(MachineMappingConstraints const &c) { +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &c) { return restrict_to_child(c, BinaryTreePathEntry::RIGHT_CHILD); } -MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &constraints, - ParallelLayerGuidObliviousMachineMapping const &additional) { +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &constraints, + ParallelLayerGuidObliviousMachineMapping const &additional) { MachineMappingConstraints result = constraints; for (auto const &[layer, machine_view] : additional.raw_mapping) { - std::optional current_machine_view = result.machine_views.at(layer); + std::optional current_machine_view = + result.machine_views.at(layer); if (!current_machine_view.has_value()) { result.machine_views.at(layer) = machine_view; } else { if (current_machine_view.value() != machine_view) { - throw mk_runtime_error(fmt::format("with_additional_layer_machine_views received machine view assignment for layer {} " - "to machine view {}, but that layer is already assigned to machine view {}.", - layer, machine_view, current_machine_view.value())); + throw mk_runtime_error( + fmt::format("with_additional_layer_machine_views received machine " + "view assignment for layer {} " + "to machine view {}, but that layer is already " + "assigned to machine view 
{}.", + layer, + machine_view, + current_machine_view.value())); } } } @@ -82,13 +96,17 @@ MachineMappingConstraints with_additional_constraints(MachineMappingConstraints return result; } -std::optional require_only_root(MachineMappingConstraints const &constraints) { - if (keys(constraints.machine_views) != std::unordered_set{binary_tree_root_path()}) { - throw mk_runtime_error(fmt::format("require_only_root expected constraints to have only a single key (the root path), but received {}", constraints)); +std::optional + require_only_root(MachineMappingConstraints const &constraints) { + if (keys(constraints.machine_views) != + std::unordered_set{binary_tree_root_path()}) { + throw mk_runtime_error( + fmt::format("require_only_root expected constraints to have only a " + "single key (the root path), but received {}", + constraints)); } return constraints.machine_views.at(binary_tree_root_path()); } - } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index c9864c2e25..42f1cd3809 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,6 +1,7 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg_binary_parallel_split.h" #include 
"compiler/series_parallel/pcg_binary_series_split.h" @@ -8,36 +9,39 @@ #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/overload.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" namespace FlexFlow { -MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, - PCGBinarySPDecomposition const &sp_decomposition_tree) { +MachineMappingProblemTree get_machine_mapping_problem_tree( + ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp_decomposition_tree) { TransitiveReducedPCG tr_pcg = pcg_get_transitive_reduction(pcg); - std::function to_problem_tree; + std::function + to_problem_tree; - to_problem_tree = [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { + to_problem_tree = + [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { return visit( - sp, - overload { - [&](PCGBinarySeriesSplit const &series) { - AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_set_movement_across_split(tr_pcg, series); - return mm_problem_tree_make_series_split( - /*tensor_set_movement=*/tensor_movement, - /*lhs=*/to_problem_tree(get_left_child(series)), - /*rhs=*/to_problem_tree(get_right_child(series))); - }, - [&](PCGBinaryParallelSplit const ¶llel) { - return mm_problem_tree_make_parallel_split( - to_problem_tree(get_left_child(parallel)), - to_problem_tree(get_right_child(parallel))); - }, - [&](parallel_layer_guid_t const &leaf) { - return mm_problem_tree_make_leaf(get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf)); - } - }); + sp, + overload{[&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_set_movement_across_split(tr_pcg, + series); + return mm_problem_tree_make_series_split( + /*tensor_set_movement=*/tensor_movement, + 
/*lhs=*/to_problem_tree(get_left_child(series)), + /*rhs=*/to_problem_tree(get_right_child(series))); + }, + [&](PCGBinaryParallelSplit const ¶llel) { + return mm_problem_tree_make_parallel_split( + to_problem_tree(get_left_child(parallel)), + to_problem_tree(get_right_child(parallel))); + }, + [&](parallel_layer_guid_t const &leaf) { + return mm_problem_tree_make_leaf( + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf)); + }}); }; return to_problem_tree(sp_decomposition_tree); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 6d14fbe3cf..992a73db03 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -1,58 +1,60 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" namespace FlexFlow { -MachineMappingProblemTree mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { +MachineMappingProblemTree mm_problem_tree_make_series_split( + AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { return MachineMappingProblemTree{ - make_generic_binary_series_split( - MMProblemTreeSeriesSplitLabel{tensor_set_movement}, - lhs.raw_tree, - rhs.raw_tree), + make_generic_binary_series_split( + MMProblemTreeSeriesSplitLabel{tensor_set_movement}, + lhs.raw_tree, + rhs.raw_tree), }; } -MachineMappingProblemTree mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { +MachineMappingProblemTree + mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { return MachineMappingProblemTree{ - make_generic_binary_parallel_split( - MMProblemTreeParallelSplitLabel{}, - lhs.raw_tree, - rhs.raw_tree), + make_generic_binary_parallel_split( + MMProblemTreeParallelSplitLabel{}, lhs.raw_tree, rhs.raw_tree), }; } -MachineMappingProblemTree mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &leaf_label) { +MachineMappingProblemTree + mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &leaf_label) { return MachineMappingProblemTree{ - make_generic_binary_sp_leaf< - MMProblemTreeSeriesSplitLabel, - MMProblemTreeParallelSplitLabel, - UnmappedOpCostEstimateKey>(leaf_label), + 
make_generic_binary_sp_leaf(leaf_label), }; } -SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { +SPDecompositionTreeNodeType + get_node_type(MachineMappingProblemTree const &tree) { return get_node_type(tree.raw_tree); } - -MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { +MMProblemTreeSeriesSplit + require_series_split(MachineMappingProblemTree const &t) { return MMProblemTreeSeriesSplit{ - require_generic_binary_series_split(t.raw_tree), + require_generic_binary_series_split(t.raw_tree), }; } -MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { +MMProblemTreeParallelSplit + require_parallel_split(MachineMappingProblemTree const &t) { return MMProblemTreeParallelSplit{ - require_generic_binary_parallel_split(t.raw_tree), + require_generic_binary_parallel_split(t.raw_tree), }; } @@ -60,31 +62,38 @@ UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) { return require_generic_binary_leaf(t.raw_tree); } -MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &series) { +MachineMappingProblemTree + wrap_series_split(MMProblemTreeSeriesSplit const &series) { return MachineMappingProblemTree{ - wrap_series_split(series.raw_split), + wrap_series_split(series.raw_split), }; } -MachineMappingProblemTree wrap_parallel_split(MMProblemTreeParallelSplit const ¶llel) { +MachineMappingProblemTree + wrap_parallel_split(MMProblemTreeParallelSplit const ¶llel) { return MachineMappingProblemTree{ - wrap_parallel_split(parallel.raw_split), + wrap_parallel_split(parallel.raw_split), }; } -std::unordered_multiset get_leaves(MachineMappingProblemTree const &t) { +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &t) { return get_leaves(t.raw_tree); } -std::unordered_set get_all_leaf_paths(MachineMappingProblemTree const &t) { +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &t) { return 
get_all_leaf_paths(t.raw_tree); } -std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, - BinaryTreePath const &path) { - std::optional> raw_subtree = get_subtree_at_path(tree.raw_tree, path); +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, + BinaryTreePath const &path) { + std::optional< + GenericBinarySPDecompositionTree> + raw_subtree = get_subtree_at_path(tree.raw_tree, path); if (!raw_subtree.has_value()) { return std::nullopt; diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc index 1b9cd59572..e31613ee25 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc @@ -6,15 +6,14 @@ namespace FlexFlow { MachineMappingProblemTree get_lhs_child(MMProblemTreeParallelSplit const &p) { return MachineMappingProblemTree{ - get_left_child(p.raw_split), + get_left_child(p.raw_split), }; } MachineMappingProblemTree get_rhs_child(MMProblemTreeParallelSplit const &p) { return MachineMappingProblemTree{ - get_right_child(p.raw_split), + get_right_child(p.raw_split), }; } - } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc index 545d06957a..ac67baaf47 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc @@ -4,17 +4,18 @@ namespace FlexFlow { MachineMappingProblemTree 
get_pre_child(MMProblemTreeSeriesSplit const &s) { return MachineMappingProblemTree{ - s.raw_split.pre, + s.raw_split.pre, }; } MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &s) { return MachineMappingProblemTree{ - s.raw_split.post, + s.raw_split.post, }; } -AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &s) { +AbstractedTensorSetMovement const & + get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &s) { return s.raw_split.label.tensor_set_movement; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc index 2574fb81aa..990b287f8b 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -4,29 +4,32 @@ namespace FlexFlow { -UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &layer) { - auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { + auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { return get_parallel_tensor_shape(pcg, t); }; return UnmappedOpCostEstimateKey{ - /*op_attrs=*/pcg_get_op_attrs(pcg, layer), - /*input_shapes=*/transform(get_incoming_inputs(pcg, layer), get_tensor_shape), - /*weight_shapes=*/transform(get_incoming_weights(pcg, layer), get_tensor_shape), - /*output_shapes=*/transform(get_layer_outputs(pcg, layer), get_tensor_shape), + /*op_attrs=*/pcg_get_op_attrs(pcg, layer), + /*input_shapes=*/ + transform(get_incoming_inputs(pcg, layer), get_tensor_shape), + 
/*weight_shapes=*/ + transform(get_incoming_weights(pcg, layer), get_tensor_shape), + /*output_shapes=*/ + transform(get_layer_outputs(pcg, layer), get_tensor_shape), }; } - -OpCostEstimateKey map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, - MachineView const &machine_view) { +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view) { return OpCostEstimateKey{ - /*op_attrs=*/unmapped.op_attrs, - /*input_shapes=*/unmapped.input_shapes, - /*weight_shapes=*/unmapped.weight_shapes, - /*output_shapes=*/unmapped.output_shapes, - /*machine_view=*/machine_view, + /*op_attrs=*/unmapped.op_attrs, + /*input_shapes=*/unmapped.input_shapes, + /*weight_shapes=*/unmapped.weight_shapes, + /*output_shapes=*/unmapped.output_shapes, + /*machine_view=*/machine_view, }; } diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 1e4de0a929..3409f7f871 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,9 +1,9 @@ #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" #include "utils/containers/map_keys.h" #include "utils/containers/merge_maps.h" #include "utils/full_binary_tree/binary_tree_path.h" -#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" namespace FlexFlow { @@ -15,11 +15,13 @@ bool is_infeasible(MachineMappingResult const &result) { return !result.raw_result.has_value(); } -FeasibleMachineMappingResult require_feasible(MachineMappingResult const &result) { +FeasibleMachineMappingResult + require_feasible(MachineMappingResult const &result) { return result.raw_result.value(); } -[[nodiscard]] 
MachineMappingResult get_mapping_with_minimal_runtime(std::unordered_set const &candidates) { +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &candidates) { MachineMappingResult result = infeasible_machine_mapping_result(); for (MachineMappingResult const &candidate : candidates) { @@ -29,10 +31,12 @@ FeasibleMachineMappingResult require_feasible(MachineMappingResult const &result return result; } -MachineMappingResult series_combine(float comm_cost, - MachineMappingResult const &maybe_pre_result, - MachineMappingResult const &maybe_post_result, - std::optional const ¶llel_split_transformation) { +MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &maybe_pre_result, + MachineMappingResult const &maybe_post_result, + std::optional const + ¶llel_split_transformation) { FeasibleMachineMappingResult pre_result = ({ if (is_infeasible(maybe_pre_result)) { return infeasible_machine_mapping_result(); @@ -48,26 +52,28 @@ MachineMappingResult series_combine(float comm_cost, }); ParallelLayerGuidObliviousMachineMapping mapping = [&] { - if (parallel_split_transformation.has_value() - && parallel_split_transformation.value() == ParallelSplitTransformation::RthenL) { - return binary_combine_mappings(/*lhs=*/post_result.machine_mapping, + if (parallel_split_transformation.has_value() && + parallel_split_transformation.value() == + ParallelSplitTransformation::RthenL) { + return binary_combine_mappings(/*lhs=*/post_result.machine_mapping, /*rhs=*/pre_result.machine_mapping); } else { - return binary_combine_mappings(/*lhs=*/pre_result.machine_mapping, + return binary_combine_mappings(/*lhs=*/pre_result.machine_mapping, /*rhs=*/post_result.machine_mapping); } }(); return MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/pre_result.runtime + comm_cost + post_result.runtime, - /*machine_mapping=*/mapping, - }, + FeasibleMachineMappingResult{ + /*runtime=*/pre_result.runtime + 
comm_cost + post_result.runtime, + /*machine_mapping=*/mapping, + }, }; } -MachineMappingResult parallel_combine(MachineMappingResult const &maybe_lhs_result, - MachineMappingResult const &maybe_rhs_result) { +MachineMappingResult + parallel_combine(MachineMappingResult const &maybe_lhs_result, + MachineMappingResult const &maybe_rhs_result) { FeasibleMachineMappingResult lhs_result = ({ if (is_infeasible(maybe_lhs_result)) { return infeasible_machine_mapping_result(); @@ -83,12 +89,12 @@ MachineMappingResult parallel_combine(MachineMappingResult const &maybe_lhs_resu }); return MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), - /*machine_mapping=*/binary_combine_mappings - (/*lhs=*/lhs_result.machine_mapping, - /*rhs=*/rhs_result.machine_mapping), - }, + FeasibleMachineMappingResult{ + /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), + /*machine_mapping=*/ + binary_combine_mappings(/*lhs=*/lhs_result.machine_mapping, + /*rhs=*/rhs_result.machine_mapping), + }, }; } @@ -115,15 +121,17 @@ MachineMappingResult minimize_runtime(MachineMappingResult const &maybe_m1, } } -MachineMappingResult make_singleton_machine_mapping_result(float runtime, - MachineView const &machine_view) { +MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view) { return MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/runtime, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), machine_view}, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/runtime, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), machine_view}, + }}, + }, }; } diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc index 
63035f5801..715a4c2e3d 100644 --- a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -6,18 +6,18 @@ namespace FlexFlow { -ParallelLayerGuidObliviousMachineMapping - binary_combine_mappings(ParallelLayerGuidObliviousMachineMapping const &lhs, - ParallelLayerGuidObliviousMachineMapping const &rhs) { +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &lhs, + ParallelLayerGuidObliviousMachineMapping const &rhs) { return ParallelLayerGuidObliviousMachineMapping{ - merge_maps( - map_keys(lhs.raw_mapping, nest_inside_left_child), - map_keys(rhs.raw_mapping, nest_inside_right_child)), + merge_maps(map_keys(lhs.raw_mapping, nest_inside_left_child), + map_keys(rhs.raw_mapping, nest_inside_right_child)), }; } -std::optional get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &mapping, - BinaryTreePath const &path) { +std::optional get_machine_view_for_path( + ParallelLayerGuidObliviousMachineMapping const &mapping, + BinaryTreePath const &path) { return try_at(mapping.raw_mapping, path); } diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index ccb6ae2eed..618d93e9f2 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -4,73 +4,87 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include 
"utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" #include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" #include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" namespace FlexFlow { -TransitiveReducedDataflowGraphView get_underlying_transitive_reduced_dataflow_graph(TransitiveReducedPCG const &tr_pcg) { +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &tr_pcg) { return TransitiveReducedDataflowGraphView{ - /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, - /*transitive_reduction=*/tr_pcg.transitive_reduction, + /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, + /*transitive_reduction=*/tr_pcg.transitive_reduction, }; } -TransitiveReducedPCG pcg_get_transitive_reduction(ParallelComputationGraph const &pcg) { +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &pcg) { DiGraphView raw_digraph = pcg.raw_graph; DiGraphView transitive_reduced = transitive_reduction(raw_digraph); return TransitiveReducedPCG{ - /*pcg=*/pcg, - /*transitive_reduction=*/transitive_reduced, + /*pcg=*/pcg, + /*transitive_reduction=*/transitive_reduced, }; } -std::unordered_set - pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { +std::unordered_set + pcg_get_transitive_reduced_edges_across_split( + TransitiveReducedPCG 
const &tr_pcg, PCGBinarySeriesSplit const &split) { - TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); BinarySeriesSplit raw_split = get_raw_graph_series_split(split); - std::unordered_set raw_edges = get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); + std::unordered_set raw_edges = + get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); - return transform(raw_edges, - [](DataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); } -std::unordered_set - pcg_get_transitive_reduced_tensors_across_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { - TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); BinarySeriesSplit raw_split = get_raw_graph_series_split(split); - std::unordered_set raw_outputs = get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); + std::unordered_set raw_outputs = + get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); - return transform(raw_outputs, - [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); + return transform(raw_outputs, [](DataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); } -PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split(TransitiveReducedPCG const &tr_pcg, - PCGBinarySeriesSplit const &split) { - TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); 
+PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); BinarySeriesSplit raw_split = get_raw_graph_series_split(split); - SplitBoundaryNodes raw_boundary = get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); + SplitBoundaryNodes raw_boundary = + get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); return PCGSplitBoundaryLayers{ - /*pre_split_boundary=*/transform(raw_boundary.pre_split_boundary, [](Node const &n) { return parallel_layer_guid_t{n}; }), - /*post_split_boundary=*/transform(raw_boundary.post_split_boundary, [](Node const &n) { return parallel_layer_guid_t{n}; }), + /*pre_split_boundary=*/transform( + raw_boundary.pre_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + /*post_split_boundary=*/ + transform(raw_boundary.post_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), }; } - } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc index e1c118f891..25fda37c1e 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -36,10 +36,10 @@ std::optional left_associative_binary_sp_tree_from_nary(sp_decomposition); auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ - [](Node const &n) { return layer_guid_t{n}; }, + [](Node const &n) { return layer_guid_t{n}; }, }; - return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, visitor)}; + return ComputationGraphBinarySPDecomposition{ + transform(raw_binary_tree.raw_tree, visitor)}; } std::optional @@ -58,10 
+58,10 @@ std::optional right_associative_binary_sp_tree_from_nary(sp_decomposition); auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ - [](Node const &n) { return layer_guid_t{n}; }, + [](Node const &n) { return layer_guid_t{n}; }, }; - return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, visitor)}; + return ComputationGraphBinarySPDecomposition{ + transform(raw_binary_tree.raw_tree, visitor)}; } bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { diff --git a/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc index 5559465fa3..95e810fe8f 100644 --- a/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc @@ -3,8 +3,7 @@ namespace FlexFlow { std::optional - get_pcg_series_parallel_decomposition( - ParallelComputationGraph const &) { + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &) { NOT_IMPLEMENTED(); } diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc index dad21c6c8c..0888b5c02d 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc @@ -1,18 +1,18 @@ #include "compiler/series_parallel/pcg_binary_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" namespace FlexFlow { 
PCGBinarySPDecomposition get_left_child(PCGBinaryParallelSplit const &s) { return PCGBinarySPDecomposition{ - get_left_child(s.raw_split), + get_left_child(s.raw_split), }; } PCGBinarySPDecomposition get_right_child(PCGBinaryParallelSplit const &s) { return PCGBinarySPDecomposition{ - get_right_child(s.raw_split), + get_right_child(s.raw_split), }; } diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc index 0b972706d1..1d1ac2b9e4 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc @@ -1,28 +1,28 @@ #include "compiler/series_parallel/pcg_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" namespace FlexFlow { BinarySeriesSplit get_raw_graph_series_split(PCGBinarySeriesSplit const &s) { - auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ - [](parallel_layer_guid_t const &l) { return l.raw_graph_node; } - }; + auto visitor = + LeafOnlyBinarySPDecompositionTreeVisitor{ + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }}; return BinarySeriesSplit{ - transform(s.raw_split, visitor), + transform(s.raw_split, visitor), }; } PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &s) { return PCGBinarySPDecomposition{ - get_left_child(s.raw_split), + get_left_child(s.raw_split), }; } PCGBinarySPDecomposition get_right_child(PCGBinarySeriesSplit const &s) { return PCGBinarySPDecomposition{ - get_right_child(s.raw_split), + get_right_child(s.raw_split), }; } diff --git 
a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index 9a2fc43a37..1b53a3c047 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -1,20 +1,20 @@ #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" namespace FlexFlow { std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { NOT_IMPLEMENTED(); } std::unordered_multiset - get_parallel_layers(PCGBinarySPDecomposition const &d) { + get_parallel_layers(PCGBinarySPDecomposition const &d) { return get_leaves(d.raw_tree); } @@ -22,45 +22,49 @@ SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const 
&d) { return get_node_type(d.raw_tree); } -PCGBinarySPDecomposition make_pcg_series_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { +PCGBinarySPDecomposition + make_pcg_series_split(PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - leaf_only_binary_sp_tree_make_series_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_binary_sp_tree_make_series_split(lhs.raw_tree, rhs.raw_tree), }; } -PCGBinarySPDecomposition make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { +PCGBinarySPDecomposition + make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - leaf_only_binary_sp_tree_make_parallel_split(lhs.raw_tree, rhs.raw_tree), + leaf_only_binary_sp_tree_make_parallel_split(lhs.raw_tree, rhs.raw_tree), }; } PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &l) { return PCGBinarySPDecomposition{ - leaf_only_binary_sp_tree_make_leaf(l), + leaf_only_binary_sp_tree_make_leaf(l), }; } PCGBinarySPDecomposition wrap_series_split(PCGBinarySeriesSplit const &s) { return PCGBinarySPDecomposition{ - wrap_series_split(s.raw_split), + wrap_series_split(s.raw_split), }; } PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &p) { return PCGBinarySPDecomposition{ - wrap_parallel_split(p.raw_split), + wrap_parallel_split(p.raw_split), }; } PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &d) { return PCGBinarySeriesSplit{ - require_leaf_only_binary_series_split(d.raw_tree), + require_leaf_only_binary_series_split(d.raw_tree), }; } PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &d) { return PCGBinaryParallelSplit{ - require_leaf_only_binary_parallel_split(d.raw_tree), + require_leaf_only_binary_parallel_split(d.raw_tree), }; } @@ -68,10 +72,10 @@ parallel_layer_guid_t 
require_leaf(PCGBinarySPDecomposition const &d) { return require_leaf_only_binary_leaf(d.raw_tree); } -std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &spd, - parallel_layer_guid_t const &l) { +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &spd, + parallel_layer_guid_t const &l) { return find_paths_to_leaf(spd.raw_tree, l); } - } // namespace FlexFlow diff --git a/lib/compiler/src/graph_optimize_state.cc b/lib/compiler/src/graph_optimize_state.cc deleted file mode 100644 index 4b4f323ea4..0000000000 --- a/lib/compiler/src/graph_optimize_state.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "compiler/graph_optimize_state.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" - -namespace FlexFlow { - -GraphOptimizeState::GraphOptimizeState( - GraphOptimizeResult const &graph_optimize_result, float runtime) - : graph_optimize_result(graph_optimize_result), runtime(runtime) {} - -bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { - // Note(@wmdi): This is a hack to implement a partially correct homomorphism - // check. Switch to the homomorphism check used in substitutions right after - // https://github.com/flexflow/FlexFlow/pull/1471 is merged. 
- auto layers1 = topological_ordering(graph_optimize_result.pcg); - auto layers2 = topological_ordering(other.graph_optimize_result.pcg); - if (layers1.size() != layers2.size()) { - return false; - } - std::unordered_map mapping; - for (size_t i = 0; i < layers1.size(); ++i) { - if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != - get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { - return false; - } - auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); - auto inputs2 = - get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); - if (inputs1.size() != inputs2.size()) { - return false; - } - for (size_t j = 0; j < inputs1.size(); ++j) { - if (inputs1[j] != mapping.at(inputs2[j])) { - return false; - } - } - auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); - auto outputs2 = - get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); - if (outputs1.size() != outputs2.size()) { - return false; - } - for (size_t j = 0; j < outputs1.size(); ++j) { - mapping.emplace(outputs2[j], outputs1[j]); - } - } - return true; -} - -bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { - return !(*this == other); -} - -bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { - return runtime < other.runtime; -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::GraphOptimizeState>::operator()( - ::FlexFlow::GraphOptimizeState const &state) const { - // TODO(@wmdi): Eventually it might be good to use a proper graph hash like - // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash - size_t seed = 0; - auto layers = topological_ordering(state.graph_optimize_result.pcg); - ::FlexFlow::hash_combine(seed, layers.size()); - for (auto layer : layers) { - 
::FlexFlow::hash_combine( - seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); - auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); - ::FlexFlow::hash_combine(seed, inputs.size()); - for (auto input : inputs) { - for (size_t i = 0; i < layers.size(); ++i) { - if (get_source_layer(input) == layers[i]) { - ::FlexFlow::hash_combine(seed, i); - break; - } - } - } - } - return seed; -} - -} // namespace std diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 263450117f..86a211c535 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -54,14 +54,16 @@ GraphOptimizeResult graph_optimize( // GraphOptimizeState best_state = initial_state; // candidates.push(initial_state); // - // for (int iteration = 0; !candidates.empty() && iteration < opt_config.budget; + // for (int iteration = 0; !candidates.empty() && iteration < + // opt_config.budget; // ++iteration) { // GraphOptimizeState current_state = candidates.top(); // candidates.pop(); // // if (current_state.runtime < best_state.runtime) { // best_state = current_state; - // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) { + // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) + // { // continue; // } // diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index c320900414..3587316e4b 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -12,66 +12,62 @@ TEST_SUITE(FF_TEST_SUITE) { 
TEST_CASE("get_abstracted_tensor_set_movement_across_split") { ParallelComputationGraph pcg = empty_parallel_computation_graph(); - ParallelTensorShape input_shape = - ParallelTensorShape{ + ParallelTensorShape input_shape = ParallelTensorShape{ ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 2}, - ShardParallelDim{12, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, DataType::FLOAT, - }; - ParallelLayerAttrs relu_attrs - = ParallelLayerAttrs{ - /*op_attrs=*/PCGOperatorAttrs{ + }; + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ ElementUnaryAttrs{ - /*op_type=*/OperatorType::RELU, - /*scalar=*/std::nullopt, + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, }, - }, - /*name=*/std::nullopt, - }; + }, + /*name=*/std::nullopt, + }; - ParallelLayerAttrs ew_add_attrs - = ParallelLayerAttrs{ + ParallelLayerAttrs ew_add_attrs = ParallelLayerAttrs{ /*op_attrs=*/PCGOperatorAttrs{ - ElementBinaryAttrs{ - /*type=*/OperatorType::EW_ADD, - /*compute_type=*/DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, - }, + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, }, /*name=*/std::nullopt, - }; + }; ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ - /*shape=*/input_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::YES, + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, }; SUBCASE("no edges across split") { ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); - 
PCGBinarySeriesSplit split = require_series(\ - make_pcg_series_split( - make_pcg_leaf_node(input1.parallel_layer), - make_pcg_leaf_node(input2.parallel_layer))); + PCGBinarySeriesSplit split = require_series( + make_pcg_series_split(make_pcg_leaf_node(input1.parallel_layer), + make_pcg_leaf_node(input2.parallel_layer))); - AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split - (pcg_get_transitive_reduction(pcg), - split); + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ - /*single_tensor_movements=*/{}, + /*single_tensor_movements=*/{}, }; CHECK(result == correct); @@ -80,42 +76,36 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("single edge across split") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - ParallelLayerAddedResult layer_1 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - ParallelLayerAddedResult layer_2 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(layer_1.outputs)}, - {relu_output_attrs}); - - PCGBinarySeriesSplit split = require_series(\ - make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(layer_1.parallel_layer)), + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(layer_1.parallel_layer)), make_pcg_leaf_node(layer_2.parallel_layer))); - AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split - (pcg_get_transitive_reduction(pcg), - split); + 
AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ - /*single_tensor_movements=*/{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{}}, - }, + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, }, - }, }; CHECK(result == correct); @@ -124,52 +114,47 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("does not include edges removed by transitive reduction") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - ParallelLayerAddedResult layer_1 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - - ParallelLayerAddedResult layer_2 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(layer_1.outputs)}, - {relu_output_attrs}); - - ParallelLayerAddedResult layer_3 - = add_parallel_layer(pcg, - ew_add_attrs, - {get_only(layer_1.outputs), get_only(layer_2.outputs)}, - {relu_output_attrs}); - - PCGBinarySeriesSplit split = require_series(\ - make_pcg_series_split( + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), 
- make_pcg_series_split( - make_pcg_leaf_node(layer_1.parallel_layer), - make_pcg_leaf_node(layer_2.parallel_layer))), + make_pcg_leaf_node(input.parallel_layer), + make_pcg_series_split( + make_pcg_leaf_node(layer_1.parallel_layer), + make_pcg_leaf_node(layer_2.parallel_layer))), make_pcg_leaf_node(layer_3.parallel_layer))); - AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split - (pcg_get_transitive_reduction(pcg), - split); + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ - /*single_tensor_movements=*/{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{}}, - }, + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, }, - }, }; CHECK(result == correct); @@ -178,55 +163,46 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("single tensor, multiple consumers across split") { ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - ParallelLayerAddedResult layer_1 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); - ParallelLayerAddedResult layer_2 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(layer_1.outputs)}, - {relu_output_attrs}); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); - 
ParallelLayerAddedResult layer_3 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(layer_1.outputs)}, - {relu_output_attrs}); + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(layer_1.parallel_layer)), - make_pcg_parallel_split( - make_pcg_leaf_node(layer_2.parallel_layer), - make_pcg_leaf_node(layer_3.parallel_layer)))); + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(layer_1.parallel_layer)), + make_pcg_parallel_split(make_pcg_leaf_node(layer_2.parallel_layer), + make_pcg_leaf_node(layer_3.parallel_layer)))); - AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split - (pcg_get_transitive_reduction(pcg), - split); + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ - /*single_tensor_movements=*/{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, }, - }, }; CHECK(result == correct); @@ -235,79 +211,72 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("multiple tensors, multiple consumers across split") { 
ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); - ParallelLayerAddedResult layer_1 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - - ParallelLayerAddedResult layer_2 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(input.outputs)}, - {relu_output_attrs}); - - ParallelLayerAddedResult layer_3 - = add_parallel_layer(pcg, - relu_attrs, - {get_only(layer_1.outputs)}, - {relu_output_attrs}); - - ParallelLayerAddedResult layer_4 - = add_parallel_layer(pcg, - ew_add_attrs, - {get_only(layer_1.outputs), get_only(layer_2.outputs)}, - {relu_output_attrs}); + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_parallel_split( - make_pcg_leaf_node(layer_1.parallel_layer), - make_pcg_leaf_node(layer_2.parallel_layer))), - make_pcg_parallel_split( - make_pcg_leaf_node(layer_3.parallel_layer), - make_pcg_leaf_node(layer_4.parallel_layer)))); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); - AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split - (pcg_get_transitive_reduction(pcg), - split); + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult layer_4 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = require_series(make_pcg_series_split( + make_pcg_series_split( + make_pcg_leaf_node(input.parallel_layer), + make_pcg_parallel_split( + make_pcg_leaf_node(layer_1.parallel_layer), + make_pcg_leaf_node(layer_2.parallel_layer))), + 
make_pcg_parallel_split(make_pcg_leaf_node(layer_3.parallel_layer), + make_pcg_leaf_node(layer_4.parallel_layer)))); + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ - /*single_tensor_movements=*/{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, }, - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - }, - }, }; CHECK(result == correct); diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc index a660bf1db4..9ee596af3e 100644 --- 
a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -5,11 +5,11 @@ namespace FlexFlow { TestCostEstimator::TestCostEstimator( - std::function const &get_operator_cost, - std::function const & get_communication_cost) - : get_operator_cost(get_operator_cost), - get_communication_cost(get_communication_cost) -{ } + std::function const &get_operator_cost, + std::function const + &get_communication_cost) + : get_operator_cost(get_operator_cost), + get_communication_cost(get_communication_cost) {} float TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { return this->get_operator_cost(k); @@ -20,22 +20,22 @@ float TestCostEstimator::estimate_cost(TensorSetMovement const &m) const { } CostEstimator make_fake_cost_estimator( - std::function const &get_operator_cost, - std::function const &get_communication_cost) { + std::function const &get_operator_cost, + std::function const + &get_communication_cost) { - return CostEstimator::create(get_operator_cost, get_communication_cost); + return CostEstimator::create(get_operator_cost, + get_communication_cost); } CostEstimator make_fake_cost_estimator( - std::unordered_map const &op_cost_map, - std::unordered_map const &comm_cost_map) { + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map) { return make_fake_cost_estimator( - [op_cost_map](OpCostEstimateKey const &k) { - return op_cost_map.at(k); - }, - [comm_cost_map](TensorSetMovement const &m) { - return comm_cost_map.at(m); - }); + [op_cost_map](OpCostEstimateKey const &k) { return op_cost_map.at(k); }, + [comm_cost_map](TensorSetMovement const &m) { + return comm_cost_map.at(m); + }); } -} // namespace FlexFlop +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index d3cc2e0f03..7c1d06207a 
100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -16,7 +16,8 @@ struct TestCostEstimator : public ICostEstimator { TestCostEstimator() = delete; TestCostEstimator(decltype(get_operator_cost) const &get_operator_cost, - decltype(get_communication_cost) const &get_communication_cost); + decltype(get_communication_cost) + const &get_communication_cost); float estimate_cost(OpCostEstimateKey const &) const override; @@ -24,12 +25,13 @@ struct TestCostEstimator : public ICostEstimator { }; CostEstimator make_fake_cost_estimator( - std::function const &get_operator_cost, - std::function const &get_communication_cost); + std::function const &get_operator_cost, + std::function const + &get_communication_cost); CostEstimator make_fake_cost_estimator( - std::unordered_map const &op_cost_map, - std::unordered_map const &comm_cost_map); + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map); } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc index 1c4aee109a..499b111f8f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -1,8 +1,8 @@ #include "compiler/machine_mapping/get_machine_resource_splits.h" -#include -#include "utils/hash/pair.h" -#include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/hash/pair.h" +#include using namespace ::FlexFlow; @@ -10,205 +10,225 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_machine_resource_splits") { auto make_machine_spec = [](int num_nodes, int num_gpus_per_node) { return MachineSpecification{ - /*num_nodes=*/num_nodes, - 
/*num_cpus_per_node=*/1, - /*num_gpus_per_node=*/num_gpus_per_node, - /*inter_node_bandwidth=*/1.0, - /*intra_node_bandwidth=*/1.0, + /*num_nodes=*/num_nodes, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/num_gpus_per_node, + /*inter_node_bandwidth=*/1.0, + /*intra_node_bandwidth=*/1.0, }; }; SUBCASE("returns no splits if no splits are possible") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, /*num_gpus_per_node=*/1); - std::unordered_set> result = get_machine_resource_splits(input); - std::unordered_set> correct = {}; + std::unordered_set> + result = get_machine_resource_splits(input); + std::unordered_set> + correct = {}; CHECK(result == correct); } - SUBCASE("returns splits in gpu and node dimensions, but not at the same time") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/2, + SUBCASE( + "returns splits in gpu and node dimensions, but not at the same time") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/2, /*num_gpus_per_node=*/2); - std::unordered_set> result = get_machine_resource_splits(input); - - std::unordered_set> correct = { - { - make_machine_spec(/*num_nodes=*/2, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/2, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/2), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/2), - }, - - }; + std::unordered_set> + result = get_machine_resource_splits(input); + + std::unordered_set> + correct = { + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + + }; CHECK(result == correct); } SUBCASE("returns splits in node dimension in powers of two") { SUBCASE("num_nodes is a power of 2") { - 
MachineSpecification input = make_machine_spec(/*num_nodes=*/8, + MachineSpecification input = make_machine_spec(/*num_nodes=*/8, /*num_gpus_per_node=*/1); - std::unordered_set> result = get_machine_resource_splits(input); - - std::unordered_set> correct = { - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/7, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/2, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/6, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/4, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/4, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/6, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/2, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/7, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - }, - }; + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; CHECK(result == correct); } SUBCASE("num_nodes is not a power of 2") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/6, + MachineSpecification input = make_machine_spec(/*num_nodes=*/6, 
/*num_gpus_per_node=*/1); - std::unordered_set> result = get_machine_resource_splits(input); - - std::unordered_set> correct = { - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/5, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/2, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/4, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/4, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/2, - /*num_gpus_per_node=*/1), - }, - { - make_machine_spec(/*num_nodes=*/5, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - }, - }; + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; CHECK(result == correct); - } } SUBCASE("returns splits in gpu dimension in powers of two") { SUBCASE("num_gpus_per_node is a power of 2") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, /*num_gpus_per_node=*/8); - std::unordered_set> result = get_machine_resource_splits(input); - - std::unordered_set> correct = { - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/7), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/2), - 
make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/6), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/4), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/4), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/6), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/2), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/7), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - }, - }; + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; CHECK(result == correct); } SUBCASE("num_gpus_per_node is not a power of 2") { - MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, /*num_gpus_per_node=*/6); - std::unordered_set> result = get_machine_resource_splits(input); - - std::unordered_set> correct = { - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/5), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/2), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/4), - }, - { - make_machine_spec(/*num_nodes=*/1, - 
/*num_gpus_per_node=*/4), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/2), - }, - { - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/5), - make_machine_spec(/*num_nodes=*/1, - /*num_gpus_per_node=*/1), - }, - }; + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; } } } diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index b33e0e344d..de26a5f2ad 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,6 +1,5 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "./cost_estimator_for_test.h" -#include #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" @@ -10,29 +9,29 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "utils/containers/get_only.h" #include "utils/full_binary_tree/binary_tree_path.h" +#include using namespace FlexFlow; - TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_optimal_machine_mapping_internal") { + 
TEST_CASE("get_optimal_machine_mapping") { MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); MachineView mv2 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(3)); MachineSpecification full_machine_spec = MachineSpecification{ - /*num_nodes=*/2, - /*num_cpus_per_node=*/1, - /*num_gpus_per_node=*/1, - /*inter_node_bandwidth=*/1, - /*intra_node_bandwidth=*/1, + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, }; MachineSpecification split_machine_spec = MachineSpecification{ - /*num_nodes=*/1, - /*num_cpus_per_node=*/1, - /*num_gpus_per_node=*/1, - /*inter_node_bandwidth=*/1, - /*intra_node_bandwidth=*/1, + /*num_nodes=*/1, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, }; auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, @@ -44,92 +43,96 @@ TEST_SUITE(FF_TEST_SUITE) { } }; - UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ - /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, - /*input_shapes=*/{}, - /*weight_shapes=*/{}, - /*output_shapes=*/{}, + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, }; UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ - /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ - /*type=*/OperatorType::EW_ADD, - /*compute_type=*/DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, - }}, - /*input_shapes=*/{}, - /*weight_shapes=*/{}, - /*output_shapes=*/{}, + /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, }; ParallelTensorShape tensor_shape1 = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{}, - ReplicaParallelDimSet{ - SumDegree{1}, - 
DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/tensor_shape1, - /*src_machine_views=*/{}, - /*dst_machine_views=*/{}, - }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/tensor_shape1, + /*src_machine_views=*/{}, + /*dst_machine_views=*/{}, + }, }}; - ParallelLayerGuidObliviousMachineMapping mm1 = ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv1}, - }}; - ParallelLayerGuidObliviousMachineMapping mm2 = ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv2}, - }}; + ParallelLayerGuidObliviousMachineMapping mm1 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}; CostEstimator cost_estimator = make_fake_cost_estimator( - std::unordered_map{{ - {map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, - {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, - {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, - {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, - }}, - std::unordered_map{{ - {TensorSetMovement{{}}, 0.0}, - {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), 0.1}, - {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), 0.2}, - {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), 0.3}, - {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), 0.4}, - }}); + std::unordered_map{{ + {map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, + {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, + {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, + {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, + }}, + std::unordered_map{{ + 
{TensorSetMovement{{}}, 0.0}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), + 0.1}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), + 0.2}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), + 0.3}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), + 0.4}, + }}); MachineMappingContext context = MachineMappingContext{ - cost_estimator, - allowed_machine_views1, + cost_estimator, + allowed_machine_views1, }; MachineMappingCache cache = empty_machine_mapping_cache(); SUBCASE("single layer") { - MachineMappingProblemTree problem_tree = - mm_problem_tree_make_leaf(k1); + MachineMappingProblemTree problem_tree = mm_problem_tree_make_leaf(k1); - MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); - MachineMappingResult result = get_optimal_machine_mapping(cache, - context, - problem_tree, - full_machine_spec, - constraints); + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/1.0, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - {binary_tree_root_path(), mv1}, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/1.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}, + }, }; CHECK(result == correct); @@ -137,36 +140,35 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("pair of layers in sequence") { MachineMappingProblemTree problem_tree = - mm_problem_tree_make_series_split( - movement1, - mm_problem_tree_make_leaf(k1), - mm_problem_tree_make_leaf(k2)); - - MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); - - 
MachineMappingResult result = get_optimal_machine_mapping(cache, - context, - problem_tree, - full_machine_spec, - constraints); + mm_problem_tree_make_series_split(movement1, + mm_problem_tree_make_leaf(k1), + mm_problem_tree_make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/1.0 + 2.0 + 0.1, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, + FeasibleMachineMappingResult{ + /*runtime=*/1.0 + 2.0 + 0.1, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv1, + }, }}, - mv1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - mv1, - }, - }}, - }, + }, }; CHECK(result == correct); @@ -174,35 +176,34 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("pair of layers in parallel") { MachineMappingProblemTree problem_tree = - mm_problem_tree_make_parallel_split( - mm_problem_tree_make_leaf(k1), - mm_problem_tree_make_leaf(k2)); + mm_problem_tree_make_parallel_split(mm_problem_tree_make_leaf(k1), + mm_problem_tree_make_leaf(k2)); - MachineMappingConstraints constraints = get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); - MachineMappingResult result = get_optimal_machine_mapping(cache, - context, - problem_tree, - full_machine_spec, - constraints); + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); 
MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/2.5, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - mv2, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, + FeasibleMachineMappingResult{ + /*runtime=*/2.5, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, }}, - mv2, - }, - }}, - }, + }, }; CHECK(result == correct); diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index e75f6626bb..82210a138b 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -3,7 +3,8 @@ // #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" // #include "pcg/machine_view.h" // #include "pcg/parallel_computation_graph/parallel_computation_graph.h" -// #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +// #include +// "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" // #include "utils/containers/get_only.h" // #include // #include "./cost_estimator_for_test.h" @@ -14,7 +15,7 @@ // TEST_CASE("get_tensor_set_movement_across_split") { // ParallelComputationGraph pcg = empty_parallel_computation_graph(); // -// ParallelTensorShape input_shape = +// ParallelTensorShape input_shape = // ParallelTensorShape{ // ParallelTensorDims{ // FFOrdered{ @@ -30,7 +31,7 @@ // }; // ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); // -// ParallelLayerAttrs relu_attrs +// ParallelLayerAttrs relu_attrs // = ParallelLayerAttrs{ // 
/*op_attrs=*/PCGOperatorAttrs{ // ElementUnaryAttrs{ @@ -48,7 +49,7 @@ // /*create_gradients=*/CreateGrad::YES, // }; // -// ParallelLayerAddedResult relu_1 +// ParallelLayerAddedResult relu_1 // = add_parallel_layer(pcg, // relu_attrs, // {get_only(input.outputs)}, @@ -79,7 +80,8 @@ // {relu_2.parallel_layer, post_mv1}, // }}; // -// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), // split, // pre_mapping, // post_mapping); @@ -125,7 +127,8 @@ // {relu_3.parallel_layer, post_mv1}, // }}; // -// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), // split, // pre_mapping, // post_mapping); @@ -153,7 +156,8 @@ // {relu_3.parallel_layer, post_mv2}, // }}; // -// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), // split, // pre_mapping, // post_mapping); @@ -182,9 +186,10 @@ // ParallelLayerAddedResult relu_4 // = add_parallel_layer(pcg, // relu_attrs, -// // relu's don't have two inputs, but for the purposes of this test it's fine. -// {get_only(relu_1.outputs), get_only(relu_3.outputs)}, -// {relu_output_attrs}); +// // relu's don't have two inputs, but for the +// purposes of this test it's fine. 
+// {get_only(relu_1.outputs), +// get_only(relu_3.outputs)}, {relu_output_attrs}); // // PartialMachineMapping pre_mapping = PartialMachineMapping{{ // {relu_1.parallel_layer, pre_mv1}, @@ -206,7 +211,8 @@ // make_pcg_leaf_node(relu_2.parallel_layer), // make_pcg_leaf_node(relu_4.parallel_layer)))); // -// TensorSetMovement result = get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), // split, // pre_mapping, // post_mapping); diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index c828a9c164..1940e6d8a3 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -12,113 +12,118 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph pcg = empty_parallel_computation_graph(); ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; auto make_output_attrs = [](ParallelTensorShape const &shape) { return ParallelTensorAttrs{ - /*shape=*/shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::YES, + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, }; }; auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { return ParallelLayerAttrs{ - 
/*op_attrs=*/op_attrs, - /*name=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, }; }; PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; - auto make_input_key = [&](ParallelTensorShape const ¶llel_tensor_shape) { - return UnmappedOpCostEstimateKey{ - /*op_attrs=*/input_attrs, - /*input_shapes=*/{}, - /*weight_shapes=*/{}, - /*output_shapes=*/{parallel_tensor_shape}, - }; - }; + auto make_input_key = + [&](ParallelTensorShape const ¶llel_tensor_shape) { + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/input_attrs, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{parallel_tensor_shape}, + }; + }; SUBCASE("single layer") { - ParallelLayerAddedResult input_added = add_parallel_layer(pcg, - /*layer_attrs=*/make_layer_attrs(input_attrs), - /*inputs=*/{}, - /*output_labels=*/{make_output_attrs(input_shape)}); + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); parallel_layer_guid_t input_layer = input_added.parallel_layer; UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); - PCGBinarySPDecomposition sp_decomposition = \ + PCGBinarySPDecomposition sp_decomposition = make_pcg_leaf_node(input_layer); - MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); MachineMappingProblemTree correct = mm_problem_tree_make_leaf(input_key); CHECK(result == correct); } SUBCASE("two layers in series") { - ParallelLayerAddedResult input_added = add_parallel_layer(pcg, - /*layer_attrs=*/make_layer_attrs(input_attrs), - /*inputs=*/{}, - /*output_labels=*/{make_output_attrs(input_shape)}); + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + 
/*output_labels=*/{make_output_attrs(input_shape)}); parallel_layer_guid_t input_layer = input_added.parallel_layer; parallel_tensor_guid_t input = get_only(input_added.outputs); UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ - ElementUnaryAttrs{ - /*op_type=*/OperatorType::RELU, - /*scalar=*/std::nullopt, - }, + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, }; ParallelTensorShape relu_output_shape = input_shape; - ParallelLayerAddedResult relu_added = add_parallel_layer(pcg, - make_layer_attrs(relu_attrs), - {input}, - {make_output_attrs(relu_output_shape)}); + ParallelLayerAddedResult relu_added = + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {input}, + {make_output_attrs(relu_output_shape)}); parallel_layer_guid_t relu_layer = relu_added.parallel_layer; parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); UnmappedOpCostEstimateKey relu_key = UnmappedOpCostEstimateKey{ - /*op_attrs=*/relu_attrs, - /*input_shapes=*/{input_shape}, - /*weight_shapes=*/{}, - /*output_shapes=*/{relu_output_shape}, + /*op_attrs=*/relu_attrs, + /*input_shapes=*/{input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{relu_output_shape}, }; - PCGBinarySPDecomposition sp_decomposition = \ - make_pcg_series_split( - make_pcg_leaf_node(input_layer), - make_pcg_leaf_node(relu_layer)); + PCGBinarySPDecomposition sp_decomposition = make_pcg_series_split( + make_pcg_leaf_node(input_layer), make_pcg_leaf_node(relu_layer)); - MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = \ - mm_problem_tree_make_series_split( + MachineMappingProblemTree correct = mm_problem_tree_make_series_split( AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, 
- /*src_machine_views=*/{ - BinaryTreePath{{}}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{}}, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{}}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, }, - }, }}, mm_problem_tree_make_leaf(input_key), mm_problem_tree_make_leaf(relu_key)); @@ -127,23 +132,23 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("two layers in parallel") { - ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input1_layer = input1_added.parallel_layer; UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); - ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input2_layer = input2_added.parallel_layer; UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); - PCGBinarySPDecomposition sp_decomposition = \ - make_pcg_parallel_split( - make_pcg_leaf_node(input1_layer), - make_pcg_leaf_node(input2_layer)); + PCGBinarySPDecomposition sp_decomposition = make_pcg_parallel_split( + make_pcg_leaf_node(input1_layer), make_pcg_leaf_node(input2_layer)); - MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = \ - mm_problem_tree_make_parallel_split( + MachineMappingProblemTree correct = mm_problem_tree_make_parallel_split( mm_problem_tree_make_leaf(input1_key), mm_problem_tree_make_leaf(input2_key)); @@ -151,75 +156,81 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("multiple tensors across split") { - ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input1_added = + 
pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input1_layer = input1_added.parallel_layer; parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); - ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); parallel_layer_guid_t input2_layer = input2_added.parallel_layer; parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ - ElementBinaryAttrs{ - /*type=*/OperatorType::EW_ADD, - /*compute_type=*/DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, - }, + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, }; ParallelTensorShape ew_op_output_shape = input_shape; - ParallelLayerAddedResult ew_op_added = add_parallel_layer(pcg, - make_layer_attrs(ew_op_attrs), - {input1_tensor, input2_tensor}, - {make_output_attrs(ew_op_output_shape)}); + ParallelLayerAddedResult ew_op_added = + add_parallel_layer(pcg, + make_layer_attrs(ew_op_attrs), + {input1_tensor, input2_tensor}, + {make_output_attrs(ew_op_output_shape)}); parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; UnmappedOpCostEstimateKey ew_op_key = UnmappedOpCostEstimateKey{ - /*op_attrs=*/ew_op_attrs, - /*input_shapes=*/{input_shape, input_shape}, - /*weight_shapes=*/{}, - /*output_shapes=*/{ew_op_output_shape}, + /*op_attrs=*/ew_op_attrs, + /*input_shapes=*/{input_shape, input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{ew_op_output_shape}, }; - PCGBinarySPDecomposition sp_decomposition = \ - make_pcg_series_split( - make_pcg_parallel_split( - make_pcg_leaf_node(input1_layer), - make_pcg_leaf_node(input2_layer)), + 
PCGBinarySPDecomposition sp_decomposition = make_pcg_series_split( + make_pcg_parallel_split(make_pcg_leaf_node(input1_layer), + make_pcg_leaf_node(input2_layer)), make_pcg_leaf_node(ew_op_layer)); - MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = \ - mm_problem_tree_make_series_split( + MachineMappingProblemTree correct = mm_problem_tree_make_series_split( AbstractedTensorSetMovement{{ - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{}}, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, }, - }, - AbstractedSingleTensorMovement{ - /*parallel_tensor_shape=*/input_shape, - /*src_machine_views=*/{ - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - }, - /*dst_machine_views=*/{ - BinaryTreePath{{}}, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, }, - }, }}, - /*pre=*/mm_problem_tree_make_parallel_split( - mm_problem_tree_make_leaf(input1_key), - mm_problem_tree_make_leaf(input2_key)), + /*pre=*/ + mm_problem_tree_make_parallel_split( + mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), /*post=*/mm_problem_tree_make_leaf(ew_op_key)); CHECK(result == correct); diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index 3717f164ac..254d6b2784 
100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -1,6 +1,6 @@ #include "compiler/machine_mapping/machine_mapping_result.h" -#include #include "pcg/machine_view.h" +#include using namespace FlexFlow; @@ -11,36 +11,38 @@ TEST_SUITE(FF_TEST_SUITE) { float pre_cost = 2.0; MachineMappingResult pre = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/pre_cost, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, }}, - machine_view_1, - }, - }}, - }, + }, }; float post_cost = 4.0; MachineMappingResult post = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/post_cost, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{}}, - machine_view_1, - }, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, }; MachineMappingResult infeasible = infeasible_machine_mapping_result(); @@ -48,21 +50,27 @@ TEST_SUITE(FF_TEST_SUITE) { float comm_cost = 3.0; SUBCASE("pre is infeasbile") { - MachineMappingResult result = series_combine(comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); + MachineMappingResult result = series_combine( + comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); MachineMappingResult correct = infeasible; CHECK(result == correct); } SUBCASE("post is infeasbile") { 
- MachineMappingResult result = series_combine(comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult result = series_combine( + comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); MachineMappingResult correct = infeasible; CHECK(result == correct); } SUBCASE("both are infeasible") { - MachineMappingResult result = series_combine(comm_cost, infeasible, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult result = + series_combine(comm_cost, + infeasible, + infeasible, + ParallelSplitTransformation::LthenR); MachineMappingResult correct = infeasible; CHECK(result == correct); @@ -70,75 +78,80 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("both are feasible") { MachineMappingResult no_parallel_split_transform = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/pre_cost + comm_cost + post_cost, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, }; SUBCASE("parallel_split_transformation = std::nullopt") { - MachineMappingResult result = series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingResult result = + series_combine(comm_cost, pre, post, 
std::nullopt); MachineMappingResult correct = no_parallel_split_transform; CHECK(result == correct); } SUBCASE("parallel_split_transformation = LthenR") { - MachineMappingResult result = series_combine(comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::LthenR); MachineMappingResult correct = no_parallel_split_transform; CHECK(result == correct); } SUBCASE("parallel_split_transformation = RthenL") { - MachineMappingResult result = series_combine(comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::RthenL); MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/pre_cost + comm_cost + post_cost, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_1, - }, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_1, + }, + }}, + }, }; CHECK(result == correct); @@ -151,35 +164,37 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); MachineMappingResult lhs = MachineMappingResult{ - 
FeasibleMachineMappingResult{ - /*runtime=*/2.0, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, }}, - machine_view_1, - }, - }}, - }, + }, }; MachineMappingResult rhs = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/4.0, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{}}, - machine_view_1, - }, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, }; MachineMappingResult infeasible = infeasible_machine_mapping_result(); @@ -208,31 +223,32 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("both are feasible") { MachineMappingResult result = parallel_combine(lhs, rhs); MachineMappingResult correct = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/4.0, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - BinaryTreePathEntry::RIGHT_CHILD, - }}, - machine_view_1, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + 
BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, }}, - machine_view_1, - }, - }}, - }, + }, }; CHECK(result == correct); @@ -244,35 +260,37 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); MachineMappingResult faster = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/2.0, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{ - BinaryTreePathEntry::LEFT_CHILD, - }}, - machine_view_0, - }, - { - BinaryTreePath{{ - BinaryTreePathEntry::RIGHT_CHILD, + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, }}, - machine_view_1, - }, - }}, - }, + }, }; MachineMappingResult slower = MachineMappingResult{ - FeasibleMachineMappingResult{ - /*runtime=*/4.0, - /*machine_mapping=*/ParallelLayerGuidObliviousMachineMapping{{ - { - BinaryTreePath{{}}, - machine_view_1, - }, - }}, - }, + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, }; MachineMappingResult infeasible = infeasible_machine_mapping_result(); diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index f523520f9f..5fbcd91a06 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -18,8 +18,8 @@ class GenericTensorAccessorW { if (this->data_type == DT) { return static_cast *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", 
this->data_type, DT)); } } @@ -49,8 +49,8 @@ class GenericTensorAccessorR { if (this->data_type == DT) { return static_cast const *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -97,7 +97,7 @@ typename data_type_enum_to_class

::type * return static_cast *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } @@ -118,7 +118,7 @@ typename data_type_enum_to_class
::type const * return static_cast const *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } diff --git a/lib/kernels/include/kernels/datatype_dispatch.h b/lib/kernels/include/kernels/datatype_dispatch.h index e6ab9fa8cc..e83fc3325d 100644 --- a/lib/kernels/include/kernels/datatype_dispatch.h +++ b/lib/kernels/include/kernels/datatype_dispatch.h @@ -22,7 +22,7 @@ Out dispatch(DataType dt, Args &&...args) { case DataType::BOOL: return F{}(std::forward(args)...); default: - throw mk_runtime_error("Unknown datatype {}", dt); + throw mk_runtime_error(fmt::format("Unknown datatype {}", dt)); } } diff --git a/lib/local-execution/include/local-execution/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h index 3a36e02327..4035aaf7cf 100644 --- a/lib/local-execution/include/local-execution/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -28,10 +28,11 @@ struct DeviceSpecific { T const *get(size_t curr_device_idx) const { if (curr_device_idx != this->device_idx) { - throw mk_runtime_error("Invalid access to DeviceSpecific: attempted " - "device_idx {} != correct device_idx {})", - curr_device_idx, - this->device_idx); + throw mk_runtime_error( + fmt::format("Invalid access to DeviceSpecific: attempted " + "device_idx {} != correct device_idx {})", + curr_device_idx, + this->device_idx)); } return (T const *)this->ptr.get(); } diff --git a/lib/local-execution/include/local-execution/permissions.h b/lib/local-execution/include/local-execution/permissions.h index ce19e38e7e..f34969f233 100644 --- a/lib/local-execution/include/local-execution/permissions.h +++ b/lib/local-execution/include/local-execution/permissions.h @@ -42,8 +42,8 @@ struct formatter<::FlexFlow::Permissions> : formatter { name = "READ_WRITE"; break; default: - throw ::FlexFlow::mk_runtime_error("Unknown permission 
{}", - static_cast(p)); + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } return formatter::format(name, ctx); } diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 5d0156201e..54eca7e514 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -30,7 +30,7 @@ GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( @@ -49,7 +49,7 @@ VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return variadic_tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } diff --git a/lib/local-execution/src/permissions.cc b/lib/local-execution/src/permissions.cc index e5c46b42f8..2286215987 100644 --- a/lib/local-execution/src/permissions.cc +++ b/lib/local-execution/src/permissions.cc @@ -33,7 +33,8 @@ static int as_int(Permissions p) { case Permissions::RW: return 2; default: - throw mk_runtime_error("Unknown permission {}", static_cast(p)); + throw mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } } diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index a799e01dbc..c740e1ffd2 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ 
b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -24,11 +24,12 @@ ParallelLayerAddedResult ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, - ParallelTensorShape const &tensor_shape); + ParallelTensorShape const &tensor_shape); -std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, - parallel_layer_guid_t const &, - parallel_layer_guid_t const &); +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, + parallel_layer_guid_t const &, + parallel_layer_guid_t const &); std::vector get_incoming_tensors(ParallelComputationGraph const &, diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 32f2335605..3d1bc629e4 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -12,11 +12,11 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" #include "utils/record_formatter.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" namespace FlexFlow { @@ -179,17 +179,14 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, ComputationGraph without_layer_names(ComputationGraph const &cg) { return ComputationGraph{ - LabelledDataflowGraph:: - create_copy_of< - UnorderedSetLabelledOpenDataflowGraph>( - rewrite_node_labels( - cg.raw_graph, - [](Node const &n, LayerAttrs const &old_attrs) { - LayerAttrs new_attrs = old_attrs; - new_attrs.name = std::nullopt; - return new_attrs; - })), + 
LabelledDataflowGraph::create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels(cg.raw_graph, + [](Node const &n, LayerAttrs const &old_attrs) { + LayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), }; } @@ -200,7 +197,6 @@ bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, .has_value(); } - std::string as_dot(ComputationGraph const &cg) { std::function get_node_label = [](LayerAttrs const &a) -> std::string { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index fa610ff9c2..dff647f5a1 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -489,11 +489,12 @@ tensor_guid_t ComputationGraphBuilder::gather( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { - throw mk_runtime_error(fmt::format("Invalid data type for input tensor 2 for Gather: " - "{} (should be {} or {})", - this->get_shape(input).data_type, - DataType::INT32, - DataType::INT64)); + throw mk_runtime_error( + fmt::format("Invalid data type for input tensor 2 for Gather: " + "{} (should be {} or {})", + this->get_shape(input).data_type, + DataType::INT32, + DataType::INT64)); } TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 1562425a80..781c44640c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -8,8 +8,8 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include 
"utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/graph/node/algorithms.h" #include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -45,18 +45,19 @@ ParallelLayerAddedResult }; } -ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, - ParallelTensorShape const &tensor_shape) { +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape) { ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ - /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, - /*name=*/std::nullopt, + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, }; ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ - /*shape=*/tensor_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::NO, + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, }; return add_parallel_layer(/*pcg=*/pcg, @@ -65,16 +66,18 @@ ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, /*output_labels=*/{tensor_attrs}); } -std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &src, - parallel_layer_guid_t const &dst) { - std::unordered_set raw_edges = get_dataflow_edges_from_node_to_node(pcg.raw_graph, - src.raw_graph_node, - dst.raw_graph_node); - return transform(raw_edges, [](DataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &src, + parallel_layer_guid_t const &dst) { + std::unordered_set raw_edges = + get_dataflow_edges_from_node_to_node( + pcg.raw_graph, src.raw_graph_node, 
dst.raw_graph_node); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); } - std::vector get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -176,19 +179,20 @@ parallel_layer_guid_t return get_only(found); } -ParallelComputationGraph without_layer_names(ParallelComputationGraph const &pcg) { +ParallelComputationGraph + without_layer_names(ParallelComputationGraph const &pcg) { return ParallelComputationGraph{ - LabelledDataflowGraph:: - create_copy_of< - UnorderedSetLabelledOpenDataflowGraph>( - rewrite_node_labels( - pcg.raw_graph, - [](Node const &n, ParallelLayerAttrs const &old_attrs) { - ParallelLayerAttrs new_attrs = old_attrs; - new_attrs.name = std::nullopt; - return new_attrs; - })), + LabelledDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels( + pcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), }; } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 2d425f5c6c..f33b4dcd17 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,27 +1,27 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "op-attrs/get_incoming_tensor_roles.h" -#include "op-attrs/ops/weight_attrs.dtg.h" -#include "op-attrs/parallel_op_attrs.h" -#include "op-attrs/pcg_operator_attrs.h" -#include "pcg/parallel_computation_graph/generate_weight_transform.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers/concat_vectors.h" -#include "utils/containers/enumerate_vector.h" -#include 
"utils/containers/get_only.h" -#include "utils/containers/transform.h" #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_matmul.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/reduction.h" #include "op-attrs/ops/repartition.h" -#include "op-attrs/ops/combine.h" #include "op-attrs/ops/replicate.h" -#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/parallel_op_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph/generate_weight_transform.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/enumerate_vector.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" namespace FlexFlow { @@ -220,7 +220,8 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( { ParallelTensorShape projection_shape = throw_if_unexpected(get_projection_shape(attrs, input_shape)); - weights.push_back(make_weight_attrs(projection_shape, projection_initializer)); + weights.push_back( + make_weight_attrs(projection_shape, projection_initializer)); } if (use_bias) { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index f0e58191ef..fc07edf5b3 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -268,17 +268,17 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("pcg_add_input_layer") { ParallelTensorShape tensor_shape = ParallelTensorShape{ - ParallelTensorDims{ - 
FFOrdered{ - ShardParallelDim{12, 2}, - ShardParallelDim{10, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{2}, - DiscardCopyDegree{2}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{2}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; ParallelComputationGraph result = [&] { @@ -286,23 +286,23 @@ TEST_SUITE(FF_TEST_SUITE) { pcg_add_input_layer(pcg, tensor_shape); return pcg; }(); - + ParallelComputationGraph correct = [&] { ParallelComputationGraph pcg = empty_parallel_computation_graph(); ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ - /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, - /*name=*/std::nullopt, + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, }; ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ - /*shape=*/tensor_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::NO, + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, }; - add_parallel_layer(/*pcg=*/pcg, + add_parallel_layer(/*pcg=*/pcg, /*layer_attrs=*/layer_attrs, /*inputs=*/{}, /*output_labels=*/{tensor_attrs}); diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc index 6c3d53d3b9..4d4e557fb8 100644 --- a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,7 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/containers/get_only.h" #include 
"utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" diff --git a/lib/utils/include/utils/any_value_type/any_value_type.h b/lib/utils/include/utils/any_value_type/any_value_type.h index eb211b1a1b..a99ce5c8f0 100644 --- a/lib/utils/include/utils/any_value_type/any_value_type.h +++ b/lib/utils/include/utils/any_value_type/any_value_type.h @@ -2,19 +2,20 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H #include +#include #include #include -#include namespace FlexFlow { struct any_value_type { public: - any_value_type(std::any const &value, - std::function const &eq, - std::function const &neq, - std::function const &hash, - std::function const &to_string); + any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string); bool operator==(any_value_type const &other) const; bool operator!=(any_value_type const &other) const; @@ -25,6 +26,7 @@ struct any_value_type { } friend std::string format_as(any_value_type const &); + private: std::any value; std::function eq; @@ -35,23 +37,18 @@ struct any_value_type { friend std::hash; }; - template any_value_type make_any_value_type(T const &t) { return any_value_type{ - std::make_any(t), - [](std::any const &l, std::any const &r) { - return std::any_cast(l) == std::any_cast(r); - }, - [](std::any const &l, std::any const &r) { - return std::any_cast(l) != std::any_cast(r); - }, - [](std::any const &v) { - return std::hash{}(std::any_cast(v)); - }, - [](std::any const &v) { - return fmt::to_string(std::any_cast(v)); - }, + std::make_any(t), + [](std::any const &l, std::any const &r) { + return std::any_cast(l) == std::any_cast(r); + }, + [](std::any const &l, std::any const &r) { + return std::any_cast(l) != std::any_cast(r); + }, + [](std::any const &v) { return std::hash{}(std::any_cast(v)); }, + [](std::any 
const &v) { return fmt::to_string(std::any_cast(v)); }, }; } @@ -64,6 +61,6 @@ struct hash<::FlexFlow::any_value_type> { size_t operator()(::FlexFlow::any_value_type const &) const; }; -} // namespace FlexFlow +} // namespace std #endif diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 537bb2d177..b016a1e03d 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -3,9 +3,9 @@ #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/merge_maps.h" #include #include -#include "utils/containers/merge_maps.h" namespace FlexFlow { @@ -41,12 +41,14 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } -template ::key_type, - typename OutV = typename std::invoke_result_t::mapped_type> -std::unordered_map flatmap(std::unordered_map const &m, F &&f) { +template < + typename InK, + typename InV, + typename F, + typename OutK = typename std::invoke_result_t::key_type, + typename OutV = typename std::invoke_result_t::mapped_type> +std::unordered_map flatmap(std::unordered_map const &m, + F &&f) { std::unordered_map result; for (auto const &[k, v] : m) { diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h index 73ac61fcf7..b7b30cbae4 100644 --- a/lib/utils/include/utils/containers/get_all_assignments.h +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -1,37 +1,39 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H -#include -#include -#include "utils/containers/vector_of.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_map_from_pairs.h" -#include "utils/containers/keys.h" +#include 
"utils/containers/vector_of.h" #include "utils/containers/zip.h" -#include "utils/containers/cartesian_product.h" #include "utils/hash/unordered_map.h" +#include +#include #include namespace FlexFlow { /** - * @note If \p options_per_key is empty, an set containing a single empty assignment is returned + * @note If \p options_per_key is empty, an set containing a single empty + * assignment is returned */ template -std::unordered_set> get_all_assignments(std::unordered_map> const &options_per_key) { +std::unordered_set> get_all_assignments( + std::unordered_map> const &options_per_key) { if (options_per_key.empty()) { return {{}}; } std::vector ordered_keys = vector_of(keys(options_per_key)); - std::vector> ordered_value_option_sets = transform(ordered_keys, - [&](K const &k) { return options_per_key.at(k); }); + std::vector> ordered_value_option_sets = transform( + ordered_keys, [&](K const &k) { return options_per_key.at(k); }); std::unordered_set> result = transform( - cartesian_product(ordered_value_option_sets), - [&](std::vector const &chosen_values) { - return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); - }); + cartesian_product(ordered_value_option_sets), + [&](std::vector const &chosen_values) { + return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); + }); return result; } diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index 88f33b52b6..201095c47d 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -10,8 +10,8 @@ namespace FlexFlow { template typename C::value_type get_only(C const &c) { return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error(fmt::format("Encountered container with size {} in get_only", - c.size())); + throw mk_runtime_error(fmt::format( + "Encountered container with size {} in get_only", c.size())); }); } diff --git a/lib/utils/include/utils/containers/merge_maps.h 
b/lib/utils/include/utils/containers/merge_maps.h index 6a3f230d08..dd886ab8aa 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -3,9 +3,9 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" -#include -#include "utils/fmt/unordered_map.h" #include "utils/exception.h" +#include "utils/fmt/unordered_map.h" +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index 05a955c485..ef6a26c79a 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -22,9 +22,7 @@ auto transform(req const &c, F const &f) return transform(static_cast(c), f); } -template > +template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; for (auto const &e : v) { @@ -33,10 +31,9 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } -template > -std::unordered_multiset transform(std::unordered_multiset const &v, F const &f) { +template > +std::unordered_multiset transform(std::unordered_multiset const &v, + F const &f) { std::unordered_multiset result; for (auto const &e : v) { result.insert(f(e)); @@ -44,9 +41,7 @@ std::unordered_multiset transform(std::unordered_multiset const &v, F c return result; } -template > +template > std::set transform(std::set const &v, F const &f) { std::set result; for (auto const &e : v) { @@ -55,9 +50,7 @@ std::set transform(std::set const &v, F const &f) { return result; } -template > +template > std::multiset transform(std::multiset const &v, F const &f) { std::multiset result; for (auto const &e : v) { diff --git a/lib/utils/include/utils/containers/try_at.h b/lib/utils/include/utils/containers/try_at.h index a274c134f7..45e50fca27 100644 --- a/lib/utils/include/utils/containers/try_at.h +++ b/lib/utils/include/utils/containers/try_at.h @@ 
-1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H -#include -#include #include "utils/containers/contains_key.h" +#include #include +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/unordered_map_from_pairs.h b/lib/utils/include/utils/containers/unordered_map_from_pairs.h index 34a1d91e86..660c57c5e7 100644 --- a/lib/utils/include/utils/containers/unordered_map_from_pairs.h +++ b/lib/utils/include/utils/containers/unordered_map_from_pairs.h @@ -5,9 +5,10 @@ namespace FlexFlow { -template -std::unordered_map - unordered_map_from_pairs(C const &c) { +template +std::unordered_map unordered_map_from_pairs(C const &c) { return std::unordered_map(c.cbegin(), c.cend()); } diff --git a/lib/utils/include/utils/fmt/monostate.h b/lib/utils/include/utils/fmt/monostate.h index b03609171f..884f4d389e 100644 --- a/lib/utils/include/utils/fmt/monostate.h +++ b/lib/utils/include/utils/fmt/monostate.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H -#include #include +#include namespace fmt { diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 4410f06e67..11c9e1db81 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -1,38 +1,38 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/binary_tree_path.h" #include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/visit.h" -#include #include 
"utils/overload.h" -#include "utils/containers/transform.h" -#include "utils/containers/set_union.h" +#include namespace FlexFlow { template -std::unordered_set find_paths_to_leaf(FullBinaryTree const &tree, - LeafLabel const &leaf) { +std::unordered_set + find_paths_to_leaf(FullBinaryTree const &tree, + LeafLabel const &leaf) { return visit>( - tree, - overload { - [&](LeafLabel const &l) -> std::unordered_set { - if (l == leaf) { - return {binary_tree_root_path()}; - } else { - return {}; - } - }, - [&](FullBinaryTreeParentNode const &parent) { - return set_union( - transform(find_paths_to_leaf(get_left_child(parent), leaf), - nest_inside_left_child), - transform(find_paths_to_leaf(get_right_child(parent), leaf), - nest_inside_right_child)); - } - }); + tree, + overload{ + [&](LeafLabel const &l) -> std::unordered_set { + if (l == leaf) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + [&](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(find_paths_to_leaf(get_left_child(parent), leaf), + nest_inside_left_child), + transform(find_paths_to_leaf(get_right_child(parent), leaf), + nest_inside_right_child)); + }}); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h index 96d384c3ae..4450b70596 100644 --- a/lib/utils/include/utils/full_binary_tree/fmt.h +++ b/lib/utils/include/utils/full_binary_tree/fmt.h @@ -11,34 +11,34 @@ namespace FlexFlow { template -std::string format_as(FullBinaryTreeParentNode const &t) { - return fmt::format("<{} ({} {})>", - t.label, - get_left_child(t), - get_right_child(t)); +std::string + format_as(FullBinaryTreeParentNode const &t) { + return fmt::format( + "<{} ({} {})>", t.label, get_left_child(t), get_right_child(t)); } template std::string format_as(FullBinaryTree const &t) { auto visitor = FullBinaryTreeVisitor{ - [](FullBinaryTreeParentNode const &parent) { - return fmt::to_string(parent); - }, - 
[](LeafLabel const &leaf) { - return fmt::format("{}", leaf); - }, + [](FullBinaryTreeParentNode const &parent) { + return fmt::to_string(parent); + }, + [](LeafLabel const &leaf) { return fmt::format("{}", leaf); }, }; return visit(t, visitor); } template -std::ostream &operator<<(std::ostream &s, FullBinaryTreeParentNode const &t) { +std::ostream & + operator<<(std::ostream &s, + FullBinaryTreeParentNode const &t) { return (s << fmt::to_string(t)); } template -std::ostream &operator<<(std::ostream &s, FullBinaryTree const &t) { +std::ostream &operator<<(std::ostream &s, + FullBinaryTree const &t) { return (s << fmt::to_string(t)); } diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h index 45d0c5f151..562edf52c1 100644 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H #include -#include #include +#include namespace FlexFlow { @@ -13,15 +13,14 @@ struct FullBinaryTree; template struct FullBinaryTreeParentNode { explicit FullBinaryTreeParentNode( - ParentLabel const &label, - FullBinaryTree const &lhs, - FullBinaryTree const &rhs) - : label(label), - left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) - { } + ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) + : label(label), + left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) {} FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; @@ -44,22 +43,24 @@ struct FullBinaryTreeParentNode { bool operator<(FullBinaryTreeParentNode const &other) const { return this->tie() < other.tie(); } + public: ParentLabel label; std::shared_ptr> left_child_ptr; std::shared_ptr> right_child_ptr; + private: std::tuple> const &, std::shared_ptr> const &> - 
tie_ptr() const { + tie_ptr() const { return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); } std::tuple const &, FullBinaryTree const &> - tie() const { + tie() const { return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); } @@ -70,11 +71,11 @@ template struct FullBinaryTree { public: FullBinaryTree() = delete; - explicit FullBinaryTree(FullBinaryTreeParentNode const &t) - : root{t} {} + explicit FullBinaryTree( + FullBinaryTreeParentNode const &t) + : root{t} {} - explicit FullBinaryTree(LeafLabel const &t) - : root{t} {} + explicit FullBinaryTree(LeafLabel const &t) : root{t} {} bool operator==(FullBinaryTree const &other) const { return this->tie() == other.tie(); @@ -87,8 +88,11 @@ struct FullBinaryTree { bool operator<(FullBinaryTree const &other) const { return this->tie() < other.tie(); } + public: - std::variant, LeafLabel> root; + std::variant, LeafLabel> + root; + private: std::tuple tie() const { return std::tie(this->root); diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 926cc0ea9c..23008b4cc0 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -1,37 +1,37 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/binary_tree_path.h" #include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/visit.h" -#include #include "utils/overload.h" -#include "utils/containers/set_union.h" -#include "utils/containers/transform.h" +#include namespace FlexFlow { template -std::unordered_set get_all_leaf_paths(FullBinaryTree const &tree) { 
- return visit> - (tree, - overload { - [](LeafLabel const &) { - return std::unordered_set{binary_tree_root_path()}; - }, - [](FullBinaryTreeParentNode const &parent) { - return set_union( - transform(get_all_leaf_paths(get_left_child(parent)), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(get_all_leaf_paths(get_right_child(parent)), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - } - }); +std::unordered_set + get_all_leaf_paths(FullBinaryTree const &tree) { + return visit>( + tree, + overload{ + [](LeafLabel const &) { + return std::unordered_set{binary_tree_root_path()}; + }, + [](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(get_all_leaf_paths(get_left_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(get_right_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }}); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index e9ceddff6d..db7ea95a04 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -1,25 +1,27 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H +#include "utils/exception.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" #include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/get_left_child.h" #include "utils/full_binary_tree/get_right_child.h" -#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" -#include "utils/exception.h" #include namespace FlexFlow { template -FullBinaryTree get_child(FullBinaryTreeParentNode const &t, - BinaryTreePathEntry const &e) { +FullBinaryTree + 
get_child(FullBinaryTreeParentNode const &t, + BinaryTreePathEntry const &e) { switch (e) { case BinaryTreePathEntry::LEFT_CHILD: return get_left_child(t); case BinaryTreePathEntry::RIGHT_CHILD: return get_right_child(t); default: - throw mk_runtime_error(fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + throw mk_runtime_error( + fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); } } diff --git a/lib/utils/include/utils/full_binary_tree/get_label.h b/lib/utils/include/utils/full_binary_tree/get_label.h index 1b48965b01..e89fdab98e 100644 --- a/lib/utils/include/utils/full_binary_tree/get_label.h +++ b/lib/utils/include/utils/full_binary_tree/get_label.h @@ -6,7 +6,8 @@ namespace FlexFlow { template -ParentLabel get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &p) { +ParentLabel get_full_binary_tree_parent_label( + FullBinaryTreeParentNode const &p) { return p.label; } diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index c58a850a6d..8ebc945db7 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -1,28 +1,27 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H +#include "utils/containers/multiset_union.h" #include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/visit.h" #include "utils/overload.h" #include -#include "utils/containers/multiset_union.h" namespace FlexFlow { template std::unordered_multiset - get_leaves(FullBinaryTree const &t) { + get_leaves(FullBinaryTree const &t) { return visit>( - t, - overload { - [](FullBinaryTreeParentNode const &parent) { - return multiset_union(get_leaves(get_left_child(parent)), - get_leaves(get_right_child(parent))); - }, - [](ChildLabel const &leaf) { - return std::unordered_multiset{leaf}; - } - }); + t, + overload{ 
+ [](FullBinaryTreeParentNode const &parent) { + return multiset_union(get_leaves(get_left_child(parent)), + get_leaves(get_right_child(parent))); + }, + [](ChildLabel const &leaf) { + return std::unordered_multiset{leaf}; + }}); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h index 163503abfd..5d5148d594 100644 --- a/lib/utils/include/utils/full_binary_tree/get_left_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_left_child.h @@ -6,7 +6,8 @@ namespace FlexFlow { template -FullBinaryTree const &get_left_child(FullBinaryTreeParentNode const &t) { +FullBinaryTree const & + get_left_child(FullBinaryTreeParentNode const &t) { return *t.left_child_ptr; } diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h index 0ee8eea6d8..1a73ce8743 100644 --- a/lib/utils/include/utils/full_binary_tree/get_node_type.h +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -7,12 +7,14 @@ namespace FlexFlow { template -FullBinaryTreeNodeType get_node_type(FullBinaryTree const &t) { +FullBinaryTreeNodeType + get_node_type(FullBinaryTree const &t) { if (std::holds_alternative(t.root)) { return FullBinaryTreeNodeType::LEAF; } else { - bool is_parent = std::holds_alternative>(t.root); - assert (is_parent); + bool is_parent = std::holds_alternative< + FullBinaryTreeParentNode>(t.root); + assert(is_parent); return FullBinaryTreeNodeType::PARENT; } diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h index e40f2024a1..937e803395 100644 --- a/lib/utils/include/utils/full_binary_tree/get_right_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_right_child.h @@ -6,7 +6,8 @@ namespace FlexFlow { template -FullBinaryTree const &get_right_child(FullBinaryTreeParentNode const &t) { +FullBinaryTree const & + 
get_right_child(FullBinaryTreeParentNode const &t) { return *t.right_child_ptr; } diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index 6909d9e1ef..0a6fba4a77 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -12,25 +12,23 @@ namespace FlexFlow { template -std::optional> get_subtree_at_path(FullBinaryTree const &t, - BinaryTreePath const &p) { +std::optional> + get_subtree_at_path(FullBinaryTree const &t, + BinaryTreePath const &p) { if (p == binary_tree_root_path()) { return t; } return visit>>( - t, - overload { - [&](FullBinaryTreeParentNode const &parent) { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + t, + overload{ + [&](FullBinaryTreeParentNode const &parent) { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); - return get_subtree_at_path(get_child(parent, curr), rest); - }, - [&](LeafLabel const &leaf) { - return std::nullopt; - } - }); + return get_subtree_at_path(get_child(parent, curr), rest); + }, + [&](LeafLabel const &leaf) { return std::nullopt; }}); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h index a29836f972..6893b990c7 100644 --- a/lib/utils/include/utils/full_binary_tree/hash.h +++ b/lib/utils/include/utils/full_binary_tree/hash.h @@ -9,18 +9,21 @@ namespace std { template struct hash<::FlexFlow::FullBinaryTreeParentNode> { - size_t operator()(::FlexFlow::FullBinaryTreeParentNode const &t) const { + size_t operator()( + ::FlexFlow::FullBinaryTreeParentNode const &t) + const { return get_std_hash(t.tie()); } }; template struct hash<::FlexFlow::FullBinaryTree> { - size_t operator()(::FlexFlow::FullBinaryTree 
const &t) const { + size_t operator()( + ::FlexFlow::FullBinaryTree const &t) const { return get_std_hash(t.tie()); } }; -} // namespace FlexFlow +} // namespace std #endif diff --git a/lib/utils/include/utils/full_binary_tree/json.h b/lib/utils/include/utils/full_binary_tree/json.h index 0d830890dc..a589c541da 100644 --- a/lib/utils/include/utils/full_binary_tree/json.h +++ b/lib/utils/include/utils/full_binary_tree/json.h @@ -12,18 +12,23 @@ namespace nlohmann { template -struct adl_serializer<::FlexFlow::FullBinaryTreeParentNode> { - static ::FlexFlow::FullBinaryTreeParentNode from_json(json const &j) { +struct adl_serializer< + ::FlexFlow::FullBinaryTreeParentNode> { + static ::FlexFlow::FullBinaryTreeParentNode + from_json(json const &j) { return ::FlexFlow::FullBinaryTreeParentNode{ j.at("left_child") - .template get<::FlexFlow::FullBinaryTreeParentNode>(), + .template get< + ::FlexFlow::FullBinaryTreeParentNode>(), j.at("right_child") - .template get<::FlexFlow::FullBinaryTreeParentNode>(), + .template get< + ::FlexFlow::FullBinaryTreeParentNode>(), }; } - static void to_json(json &j, - ::FlexFlow::FullBinaryTreeParentNode const &v) { + static void to_json( + json &j, + ::FlexFlow::FullBinaryTreeParentNode const &v) { j["__type"] = "FullBinaryTreeParentNode"; j["left_child"] = get_left_child(v); j["right_child"] = get_right_child(v); @@ -32,12 +37,15 @@ struct adl_serializer<::FlexFlow::FullBinaryTreeParentNode struct adl_serializer<::FlexFlow::FullBinaryTree> { - static ::FlexFlow::FullBinaryTree from_json(json const &j) { + static ::FlexFlow::FullBinaryTree + from_json(json const &j) { std::string key = j.at("type").get(); if (key == "parent") { return ::FlexFlow::FullBinaryTree{ - j.at("value").get<::FlexFlow::FullBinaryTreeParentNode>(), + j.at("value") + .get<::FlexFlow::FullBinaryTreeParentNode>(), }; } else if (key == "leaf") { return ::FlexFlow::FullBinaryTree{ @@ -56,7 +64,8 @@ struct adl_serializer<::FlexFlow::FullBinaryTree> { ::FlexFlow::visit( 
v, ::FlexFlow::overload{ - [&](::FlexFlow::FullBinaryTreeParentNode const &s) { + [&](::FlexFlow::FullBinaryTreeParentNode const &s) { j["type"] = "parent"; j["value"] = s; return std::monostate{}; diff --git a/lib/utils/include/utils/full_binary_tree/make.h b/lib/utils/include/utils/full_binary_tree/make.h index a4ef47c7df..488f7f83fd 100644 --- a/lib/utils/include/utils/full_binary_tree/make.h +++ b/lib/utils/include/utils/full_binary_tree/make.h @@ -6,22 +6,24 @@ namespace FlexFlow { template -FullBinaryTree make_full_binary_tree_parent(ParentLabel const &label, - FullBinaryTree const &lhs, - FullBinaryTree const &rhs) { +FullBinaryTree make_full_binary_tree_parent( + ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) { return FullBinaryTree{ - FullBinaryTreeParentNode{ - label, - lhs, - rhs, - }, + FullBinaryTreeParentNode{ + label, + lhs, + rhs, + }, }; } template -FullBinaryTree make_full_binary_tree_leaf(LeafLabel const &label) { +FullBinaryTree + make_full_binary_tree_leaf(LeafLabel const &label) { return FullBinaryTree{ - label, + label, }; } diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h index f7be417945..65bcb9b3bd 100644 --- a/lib/utils/include/utils/full_binary_tree/require.h +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -6,12 +6,15 @@ namespace FlexFlow { template -FullBinaryTreeParentNode const &require_full_binary_tree_parent_node(FullBinaryTree const &t) { +FullBinaryTreeParentNode const & + require_full_binary_tree_parent_node( + FullBinaryTree const &t) { return std::get>(t.root); } template -LeafLabel const &require_full_binary_tree_leaf(FullBinaryTree const &t) { +LeafLabel const &require_full_binary_tree_leaf( + FullBinaryTree const &t) { return std::get(t.root); } diff --git a/lib/utils/include/utils/full_binary_tree/transform.h b/lib/utils/include/utils/full_binary_tree/transform.h index 52ed07f7ba..6e33064025 100644 --- 
a/lib/utils/include/utils/full_binary_tree/transform.h +++ b/lib/utils/include/utils/full_binary_tree/transform.h @@ -4,43 +4,44 @@ #include "utils/full_binary_tree/full_binary_tree.dtg.h" #include "utils/full_binary_tree/get_left_child.h" #include "utils/full_binary_tree/get_right_child.h" -#include "utils/overload.h" #include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" namespace FlexFlow { -template , typename LeafLabel2 = std::invoke_result_t> -FullBinaryTreeParentNode transform(FullBinaryTreeParentNode const &t, F f) { +FullBinaryTreeParentNode + transform(FullBinaryTreeParentNode const &t, F f) { return FullBinaryTreeParentNode{ - transform(get_left_child(t), f), - transform(get_right_child(t), f), + transform(get_left_child(t), f), + transform(get_right_child(t), f), }; } -template , typename LeafLabel2 = std::invoke_result_t> -FullBinaryTree transform(FullBinaryTree const &t, F f) { - return visit> - ( t, - overload { - [&](FullBinaryTreeParentNode const &parent) { - return FullBinaryTree{ - transform(parent, f), - }; - }, - [&](LeafLabel const &leaf) { - return FullBinaryTree{ - f(leaf), - }; - } - }); +FullBinaryTree + transform(FullBinaryTree const &t, F f) { + return visit>( + t, + overload{ + [&](FullBinaryTreeParentNode const &parent) { + return FullBinaryTree{ + transform(parent, f), + }; + }, + [&](LeafLabel const &leaf) { + return FullBinaryTree{ + f(leaf), + }; + }}); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 860e60fcca..502165f2ab 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -1,24 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H -#include "utils/full_binary_tree/get_node_type.h" -#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include 
"utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/get_node_type.h" #include "utils/full_binary_tree/require.h" namespace FlexFlow { template Result visit(FullBinaryTree const &tt, F f) { - auto visitor = FullBinaryTreeVisitor{ - f, f - }; + auto visitor = FullBinaryTreeVisitor{f, f}; return visit(tt, visitor); } template -Result visit(FullBinaryTree const &t, FullBinaryTreeVisitor const &v) { +Result visit(FullBinaryTree const &t, + FullBinaryTreeVisitor const &v) { FullBinaryTreeNodeType node_type = get_node_type(t); switch (node_type) { case FullBinaryTreeNodeType::PARENT: @@ -26,11 +25,11 @@ Result visit(FullBinaryTree const &t, FullBinaryTreeVisi case FullBinaryTreeNodeType::LEAF: return v.leaf_func(require_full_binary_tree_leaf(t)); default: - throw mk_runtime_error(fmt::format("Unhandled FullBinaryTreeNodeType value: {}", node_type)); + throw mk_runtime_error( + fmt::format("Unhandled FullBinaryTreeNodeType value: {}", node_type)); } } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h index 5c4632ca2a..de7ead8fb6 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h @@ -5,9 +5,8 @@ namespace FlexFlow { -std::unordered_set get_dataflow_edges_from_node_to_node(DataflowGraphView const &g, - Node const &src, - Node const &dst); +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h index 79cb6059b3..be0e57435a 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" #include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" namespace FlexFlow { -SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split(TransitiveReducedDataflowGraphView const &, - BinarySeriesSplit const &); +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h index c3a71b0f63..e53bb876a1 100644 --- 
a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h @@ -6,9 +6,8 @@ namespace FlexFlow { -std::unordered_set - get_transitive_reduced_edges_across_split(TransitiveReducedDataflowGraphView const &, - BinarySeriesSplit const &); +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h index 5fab1fa0b3..ad8eadda0e 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h @@ -6,9 +6,8 @@ namespace FlexFlow { -std::unordered_set - get_transitive_reduced_outputs_across_split(TransitiveReducedDataflowGraphView const &, - BinarySeriesSplit const &); +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h index 6b711c8382..916e8f7896 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h +++ 
b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h @@ -5,7 +5,8 @@ namespace FlexFlow { -TransitiveReducedDataflowGraphView get_dataflow_graph_transitive_reduction(DataflowGraphView const &); +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h index ddb3ca1c68..240fc66426 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h @@ -4,9 +4,10 @@ #include "utils/graph/digraph/digraph_view.h" namespace FlexFlow { -std::unordered_set get_edges_from_subgraph_to_subgraph(DiGraphView const &, - std::unordered_set const &, - std::unordered_set const &); +std::unordered_set + get_edges_from_subgraph_to_subgraph(DiGraphView const &, + std::unordered_set const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h index 2685306bd5..b9894fbac3 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h @@ -8,9 +8,9 @@ namespace FlexFlow { - // NOTE(@lockshaw) This code is not tested and I don't necessarily trust it. 
- // Figuring out what to do with it is tracked in - // https://github.com/flexflow/FlexFlow/issues/1513 +// NOTE(@lockshaw) This code is not tested and I don't necessarily trust it. +// Figuring out what to do with it is tracked in +// https://github.com/flexflow/FlexFlow/issues/1513 template struct LazyLabelledDataflowGraph final diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h index 1aa5b6b37f..07aa64aa62 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h @@ -13,10 +13,9 @@ template rewrite_node_labels( LabelledDataflowGraphView const &g, F f) { return rewrite_node_labels( - view_as_labelled_open_dataflow_graph(g), f); + view_as_labelled_open_dataflow_graph(g), f); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h index db2fbceaed..0d66f80f35 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_PARALLEL_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_PARALLEL_SPLIT_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h index f8ef91a5d8..efd77a89bd 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SERIES_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SERIES_SPLIT_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index 281d64c6f6..b87516e88a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include @@ -44,7 +44,8 @@ Return visit(BinarySPDecompositionTree const &tree, F &&f) { return result; } default: - throw mk_runtime_error(fmt::format("Unhandled SPDecompositionTreeNodeType value: {}", node_type)); + throw mk_runtime_error(fmt::format( + "Unhandled SPDecompositionTreeNodeType value: {}", node_type)); } } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h index e89a35bab8..b2d50676b9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -7,8 +7,11 @@ namespace FlexFlow { template -std::unordered_set find_paths_to_leaf(GenericBinarySPDecompositionTree const &tree, - LeafLabel const &leaf) { +std::unordered_set + find_paths_to_leaf(GenericBinarySPDecompositionTree const &tree, + LeafLabel const &leaf) { return find_paths_to_leaf(tree.raw_tree, leaf); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h index 0c08a0462b..2cafb4e5b9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h @@ -8,39 +8,50 @@ namespace FlexFlow { template -SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &label) { - return label.template visit(overload { - [](GenericBinarySeriesSplitLabel const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](GenericBinaryParallelSplitLabel const &) { return SPDecompositionTreeNodeType::PARALLEL; }, +SPDecompositionTreeNodeType get_node_type( + GenericBinarySPSplitLabel const &label) { + return label.template visit(overload{ + [](GenericBinarySeriesSplitLabel const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](GenericBinaryParallelSplitLabel const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, }); } template -GenericBinarySPSplitLabel make_generic_binary_series_split_label(SeriesLabel const &label) { +GenericBinarySPSplitLabel + make_generic_binary_series_split_label(SeriesLabel const &label) { return GenericBinarySPSplitLabel{ - GenericBinarySeriesSplitLabel{ - label, - }, + GenericBinarySeriesSplitLabel{ + label, + }, }; } template -GenericBinarySPSplitLabel make_generic_binary_parallel_split_label(ParallelLabel const &label) { +GenericBinarySPSplitLabel + make_generic_binary_parallel_split_label(ParallelLabel const &label) { return GenericBinarySPSplitLabel{ - GenericBinaryParallelSplitLabel{ - label, - }, + GenericBinaryParallelSplitLabel{ + label, + }, }; } template -SeriesLabel require_generic_binary_series_split_label(GenericBinarySPSplitLabel const &label) { - return label.template 
get>().raw_label; +SeriesLabel require_generic_binary_series_split_label( + GenericBinarySPSplitLabel const &label) { + return label.template get>() + .raw_label; } template -ParallelLabel require_generic_binary_parallel_split_label(GenericBinarySPSplitLabel const &label) { - return label.template get>().raw_label; +ParallelLabel require_generic_binary_parallel_split_label( + GenericBinarySPSplitLabel const &label) { + return label.template get>() + .raw_label; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h index 0bb0e08eae..6eb9166df0 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -7,7 +7,10 @@ namespace FlexFlow { template -std::unordered_set get_all_leaf_paths(GenericBinarySPDecompositionTree const &tree) { +std::unordered_set get_all_leaf_paths( + GenericBinarySPDecompositionTree const &tree) { return get_all_leaf_paths(tree.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index cad88d25b2..c5d0e1bd30 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -13,24 +13,37 @@ namespace FlexFlow { template std::unordered_multiset - 
get_leaves(GenericBinarySPDecompositionTree const &tt) { + get_leaves(GenericBinarySPDecompositionTree const &tt) { return visit>( tt, overload{ [](LeafLabel const &t) { return std::unordered_multiset{t}; }, - [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, - [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, + [](GenericBinarySeriesSplit const &s) { + return get_leaves(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_leaves(p); + }, }); } template -std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { +std::unordered_multiset get_leaves( + GenericBinarySeriesSplit const &s) { return multiset_union(get_leaves(get_left_child(s)), get_leaves(get_right_child(s))); } template -std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { +std::unordered_multiset get_leaves( + GenericBinaryParallelSplit const + &p) { return multiset_union(get_leaves(get_left_child(p)), get_leaves(get_right_child(p))); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h index 9e857341c6..95a75a835c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h @@ -1,21 +1,25 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { template GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &s) { + get_left_child( + GenericBinarySeriesSplit const + &s) { return s.pre; } template GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &p) { + get_left_child( + GenericBinaryParallelSplit const + &p) { return p.lhs; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index 1dedf581fe..8f80c32dbf 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -1,25 +1,30 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H +#include 
"utils/full_binary_tree/get_label.h" +#include "utils/full_binary_tree/visit.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" -#include "utils/full_binary_tree/get_label.h" -#include "utils/full_binary_tree/visit.h" #include "utils/overload.h" namespace FlexFlow { template SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &tt) { - auto visitor = FullBinaryTreeVisitor, LeafLabel>{ - [](FullBinaryTreeParentNode, LeafLabel> const &parent) { - return get_node_type(get_full_binary_tree_parent_label(parent)); - }, - [](LeafLabel const &) { - return SPDecompositionTreeNodeType::NODE; - }, + get_node_type(GenericBinarySPDecompositionTree const &tt) { + auto visitor = FullBinaryTreeVisitor< + SPDecompositionTreeNodeType, + GenericBinarySPSplitLabel, + LeafLabel>{ + [](FullBinaryTreeParentNode< + GenericBinarySPSplitLabel, + LeafLabel> const &parent) { + return get_node_type(get_full_binary_tree_parent_label(parent)); + }, + [](LeafLabel const &) { return SPDecompositionTreeNodeType::NODE; }, }; return visit(tt.raw_tree, visitor); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index cfb2ea1cb2..f9619df862 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -10,27 +10,36 
@@ namespace FlexFlow { template -int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { +int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { return visit(tt, overload{ [](LeafLabel const &t) { return 1; }, - [](GenericBinarySeriesSplit const &s) { + [](GenericBinarySeriesSplit const &s) { return get_num_tree_nodes(s); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return get_num_tree_nodes(p); }, }); } template -int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { +int get_num_tree_nodes( + GenericBinarySeriesSplit const &s) { return 1 + get_num_tree_nodes(get_left_child(s)) + get_num_tree_nodes(get_right_child(s)); } template -int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { +int get_num_tree_nodes( + GenericBinaryParallelSplit const + &p) { return 1 + get_num_tree_nodes(get_left_child(p)) + get_num_tree_nodes(get_right_child(p)); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h index 766995b8a9..4820bfdc7a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h @@ -1,21 +1,25 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { template GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &s) { + get_right_child( + GenericBinarySeriesSplit const + &s) { return s.post; } template GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &p) { + get_right_child( + GenericBinaryParallelSplit const + &p) { return p.rhs; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h index fe308ec762..5ec0c03c3a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -8,17 +8,28 @@ namespace FlexFlow { -template -std::optional> - get_subtree_at_path(GenericBinarySPDecompositionTree const &tree, - BinaryTreePath const &path) { - std::optional, LeafLabel>> raw_subtree = get_subtree_at_path(tree.raw_tree, path); +template +std::optional> + get_subtree_at_path(GenericBinarySPDecompositionTree const &tree, + BinaryTreePath const &path) { + std::optional, + LeafLabel>> + raw_subtree = get_subtree_at_path(tree.raw_tree, path); if 
(!raw_subtree.has_value()) { return std::nullopt; } else { - return GenericBinarySPDecompositionTree{ - raw_subtree.value(), + return GenericBinarySPDecompositionTree{ + raw_subtree.value(), }; } } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h index bdaf8bcc2b..a7046bedbe 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h @@ -6,17 +6,23 @@ namespace FlexFlow { template -bool is_series_split(GenericBinarySPDecompositionTree const &t) { +bool is_series_split(GenericBinarySPDecompositionTree const &t) { return get_node_type(t) == SPDecompositionTreeNodeType::SERIES; } template -bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { +bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { return get_node_type(t) == SPDecompositionTreeNodeType::PARALLEL; } template -bool is_leaf(GenericBinarySPDecompositionTree const &t) { +bool is_leaf(GenericBinarySPDecompositionTree const &t) { return get_node_type(t) == SPDecompositionTreeNodeType::NODE; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 1ec84f194f..5331a10c86 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -11,17 +11,23 @@ namespace FlexFlow { template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &tt) { + GenericBinarySPDecompositionTree const &tt) { return visit( tt, overload{ [](LeafLabel const &) { return true; }, - [](GenericBinarySeriesSplit const &s) { + [](GenericBinarySeriesSplit const &s) { return !is_series_split(get_right_child(s)) && is_binary_sp_tree_left_associative(get_left_child(s)) && is_binary_sp_tree_left_associative(get_right_child(s)); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return !is_parallel_split(get_right_child(p)) && is_binary_sp_tree_left_associative(get_left_child(p)) && is_binary_sp_tree_left_associative(get_right_child(p)); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index a3ff9d4012..e7a03b1e0e 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -11,17 +11,23 @@ namespace FlexFlow { template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &tt) { + GenericBinarySPDecompositionTree const &tt) { return visit( tt, overload{ [](LeafLabel const &t) { return true; }, - [](GenericBinarySeriesSplit const &s) { + [](GenericBinarySeriesSplit const &s) { return !is_series_split(get_left_child(s)) && 
is_binary_sp_tree_right_associative(get_left_child(s)) && is_binary_sp_tree_right_associative(get_right_child(s)); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return !is_parallel_split(get_left_child(p)) && is_binary_sp_tree_right_associative(get_left_child(p)) && is_binary_sp_tree_right_associative(get_right_child(p)); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index 98382c78c8..b1f635389c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -1,43 +1,62 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/full_binary_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" namespace FlexFlow { template -GenericBinarySPDecompositionTree make_generic_binary_series_split( - SeriesLabel const &label, - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - 
make_generic_binary_series_split_label(label), - lhs.raw_tree, - rhs.raw_tree), +GenericBinarySPDecompositionTree + make_generic_binary_series_split( + SeriesLabel const &label, + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + make_full_binary_tree_parent( + make_generic_binary_series_split_label( + label), + lhs.raw_tree, + rhs.raw_tree), }; } template -GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - ParallelLabel const &label, - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - make_generic_binary_parallel_split_label(label), - lhs.raw_tree, - rhs.raw_tree), +GenericBinarySPDecompositionTree + make_generic_binary_parallel_split( + ParallelLabel const &label, + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + make_full_binary_tree_parent( + make_generic_binary_parallel_split_label( + label), + lhs.raw_tree, + rhs.raw_tree), }; } template -GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(LeafLabel const &leaf) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_leaf>( - leaf), +GenericBinarySPDecompositionTree + make_generic_binary_sp_leaf(LeafLabel const &leaf) { + return GenericBinarySPDecompositionTree{ + make_full_binary_tree_leaf< + GenericBinarySPSplitLabel>(leaf), }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index 4961dc7b61..4dae420449 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -1,48 +1,69 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/full_binary_tree/get_label.h" +#include "utils/full_binary_tree/require.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" -#include "utils/full_binary_tree/require.h" -#include "utils/full_binary_tree/get_label.h" namespace FlexFlow { template GenericBinarySeriesSplit - require_generic_binary_series_split(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_full_binary_tree_parent_node(t.raw_tree); + require_generic_binary_series_split( + GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode< + GenericBinarySPSplitLabel, + LeafLabel> + parent = require_full_binary_tree_parent_node(t.raw_tree); return GenericBinarySeriesSplit{ - /*label=*/require_generic_binary_series_split_label(get_full_binary_tree_parent_label(parent)), - /*pre=*/GenericBinarySPDecompositionTree{ - get_left_child(parent), - }, - /*post=*/GenericBinarySPDecompositionTree{ - get_right_child(parent), - }, + /*label=*/require_generic_binary_series_split_label( + 
get_full_binary_tree_parent_label(parent)), + /*pre=*/ + GenericBinarySPDecompositionTree{ + get_left_child(parent), + }, + /*post=*/ + GenericBinarySPDecompositionTree{ + get_right_child(parent), + }, }; } template GenericBinaryParallelSplit - require_generic_binary_parallel_split(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_full_binary_tree_parent_node(t.raw_tree); + require_generic_binary_parallel_split( + GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode< + GenericBinarySPSplitLabel, + LeafLabel> + parent = require_full_binary_tree_parent_node(t.raw_tree); return GenericBinaryParallelSplit{ - /*label=*/require_generic_binary_parallel_split_label(get_full_binary_tree_parent_label(parent)), - /*lhs=*/GenericBinarySPDecompositionTree{ - get_left_child(parent), - }, - /*rhs=*/GenericBinarySPDecompositionTree{ - get_right_child(parent), - }, + /*label=*/require_generic_binary_parallel_split_label( + get_full_binary_tree_parent_label(parent)), + /*lhs=*/ + GenericBinarySPDecompositionTree{ + get_left_child(parent), + }, + /*rhs=*/ + GenericBinarySPDecompositionTree{ + get_right_child(parent), + }, }; } template -LeafLabel require_generic_binary_leaf(GenericBinarySPDecompositionTree const &t) { +LeafLabel require_generic_binary_leaf( + GenericBinarySPDecompositionTree const &t) { return require_full_binary_tree_leaf(t.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h index 96c3cd5de8..045bd41652 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -5,75 
+5,108 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" #include "utils/overload.h" namespace FlexFlow { -template GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, - GenericBinarySPDecompositionTreeVisitor const &visitor); + transform( + GenericBinarySPDecompositionTree const &tt, + GenericBinarySPDecompositionTreeVisitor const &visitor); -template -GenericBinarySeriesSplit - transform(GenericBinarySeriesSplit const &s, - GenericBinarySPDecompositionTreeVisitor const &visitor) { +GenericBinarySeriesSplit transform( + GenericBinarySeriesSplit const &s, + GenericBinarySPDecompositionTreeVisitor const &visitor) { return GenericBinarySeriesSplit{ - visitor.series_split_func(s.label), - transform(get_left_child(s), visitor), - transform(get_right_child(s), visitor), - }; + visitor.series_split_func(s.label), + transform(get_left_child(s), visitor), + transform(get_right_child(s), visitor), + }; }; -template -GenericBinaryParallelSplit - transform(GenericBinaryParallelSplit const &s, - GenericBinarySPDecompositionTreeVisitor const &visitor) { +GenericBinaryParallelSplit transform( + GenericBinaryParallelSplit const &s, + 
GenericBinarySPDecompositionTreeVisitor const &visitor) { return GenericBinaryParallelSplit{ - visitor.parallel_split_func(s.label), - transform(get_left_child(s), visitor), - transform(get_right_child(s), visitor), - }; + visitor.parallel_split_func(s.label), + transform(get_left_child(s), visitor), + transform(get_right_child(s), visitor), + }; }; -template GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, - GenericBinarySPDecompositionTreeVisitor const &visitor) { - return visit>( + transform( + GenericBinarySPDecompositionTree const &tt, + GenericBinarySPDecompositionTreeVisitor const &visitor) { + return visit>( tt, overload{ - [&](GenericBinarySeriesSplit const &s) { + [&](GenericBinarySeriesSplit const &s) { return wrap_series_split(transform(s, visitor)); }, - [&](GenericBinaryParallelSplit const &s) { + [&](GenericBinaryParallelSplit const &s) { return wrap_parallel_split(transform(s, visitor)); }, [&](LeafLabel const &t) { - return make_generic_binary_sp_leaf(visitor.leaf_func(t)); + return make_generic_binary_sp_leaf( + visitor.leaf_func(t)); }, }); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index a1ac10a6a0..2688c1dd55 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -8,8 +8,15 @@ namespace FlexFlow { -template -Result visit(GenericBinarySPDecompositionTree const &tt, F f) { +template +Result visit(GenericBinarySPDecompositionTree const &tt, + F f) { SPDecompositionTreeNodeType node_type = get_node_type(tt); switch (node_type) { case SPDecompositionTreeNodeType::SERIES: { @@ -25,7 +32,8 @@ Result 
visit(GenericBinarySPDecompositionTree GenericBinarySPDecompositionTree - wrap_series_split(GenericBinarySeriesSplit const &series_split) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - /*label=*/make_generic_binary_series_split_label(series_split.label), - /*lhs=*/series_split.pre.raw_tree, - /*rhs=*/series_split.post.raw_tree), + wrap_series_split( + GenericBinarySeriesSplit const + &series_split) { + return GenericBinarySPDecompositionTree{ + make_full_binary_tree_parent( + /*label=*/make_generic_binary_series_split_label( + series_split.label), + /*lhs=*/series_split.pre.raw_tree, + /*rhs=*/series_split.post.raw_tree), }; } template GenericBinarySPDecompositionTree - wrap_parallel_split(GenericBinaryParallelSplit const ¶llel_split) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - /*label=*/make_generic_binary_parallel_split_label(parallel_split.label), - /*lhs=*/parallel_split.lhs.raw_tree, - /*rhs=*/parallel_split.rhs.raw_tree), + wrap_parallel_split( + GenericBinaryParallelSplit const + ¶llel_split) { + return GenericBinarySPDecompositionTree{ + make_full_binary_tree_parent( + /*label=*/make_generic_binary_parallel_split_label( + parallel_split.label), + /*lhs=*/parallel_split.lhs.raw_tree, + /*rhs=*/parallel_split.rhs.raw_tree), }; } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h index 77b44adc01..1d7f9ae88c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -7,12 +7,12 @@ namespace FlexFlow { 
template -std::unordered_set find_paths_to_leaf(LeafOnlyBinarySPDecompositionTree const &tree, - LeafLabel const &leaf) { +std::unordered_set + find_paths_to_leaf(LeafOnlyBinarySPDecompositionTree const &tree, + LeafLabel const &leaf) { return find_paths_to_leaf(tree.raw_tree, leaf); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h index 628cf89a44..9a8a744771 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { template -std::unordered_multiset get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { +std::unordered_multiset + get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { return get_leaves(t.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h index 74dc6cd839..f83103b4de 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h @@ -7,7 +7,8 @@ namespace FlexFlow { template -SPDecompositionTreeNodeType get_node_type(LeafOnlyBinarySPDecompositionTree const &tree) { +SPDecompositionTreeNodeType + get_node_type(LeafOnlyBinarySPDecompositionTree const &tree) { return get_node_type(tree.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 7ba5d2998c..7d6242030a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { template -bool is_binary_sp_tree_left_associative(LeafOnlyBinarySPDecompositionTree const &t) { +bool is_binary_sp_tree_left_associative( + LeafOnlyBinarySPDecompositionTree const &t) { return is_binary_sp_tree_left_associative(t.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index 84f6b21602..8fbc6d38a0 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { template -bool is_binary_sp_tree_right_associative(LeafOnlyBinarySPDecompositionTree const &t) { +bool is_binary_sp_tree_right_associative( + LeafOnlyBinarySPDecompositionTree const &t) { return is_binary_sp_tree_right_associative(t.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h index 9d4ce10cb4..81fbe0c1fa 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h @@ -7,12 +7,14 @@ namespace FlexFlow { template -LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinaryParallelSplit const &s) { +LeafOnlyBinarySPDecompositionTree + get_left_child(LeafOnlyBinaryParallelSplit const &s) { return s.lhs; } template -LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinaryParallelSplit const &s) { +LeafOnlyBinarySPDecompositionTree + get_right_child(LeafOnlyBinaryParallelSplit const &s) { return s.rhs; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h index 853def2c60..d95e741516 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h @@ -7,12 +7,14 @@ namespace FlexFlow { template -LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinarySeriesSplit const &s) { +LeafOnlyBinarySPDecompositionTree + get_left_child(LeafOnlyBinarySeriesSplit const &s) { return s.pre; } template -LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinarySeriesSplit const &s) { +LeafOnlyBinarySPDecompositionTree + get_right_child(LeafOnlyBinarySeriesSplit const &s) { return s.post; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h index 0eb05dc867..c82a4560ae 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -7,37 +7,36 @@ namespace FlexFlow { template -LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, - LeafOnlyBinarySPDecompositionTree const &post) { +LeafOnlyBinarySPDecompositionTree + leaf_only_binary_sp_tree_make_series_split( + LeafOnlyBinarySPDecompositionTree const &pre, + LeafOnlyBinarySPDecompositionTree const &post) { return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_series_split( - std::monostate{}, - pre.raw_tree, - post.raw_tree), + make_generic_binary_series_split( + std::monostate{}, pre.raw_tree, post.raw_tree), }; } template -LeafOnlyBinarySPDecompositionTree 
leaf_only_binary_sp_tree_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, - LeafOnlyBinarySPDecompositionTree const &rhs) { +LeafOnlyBinarySPDecompositionTree + leaf_only_binary_sp_tree_make_parallel_split( + LeafOnlyBinarySPDecompositionTree const &lhs, + LeafOnlyBinarySPDecompositionTree const &rhs) { return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_parallel_split( - std::monostate{}, - lhs.raw_tree, - rhs.raw_tree), + make_generic_binary_parallel_split( + std::monostate{}, lhs.raw_tree, rhs.raw_tree), }; } template -LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_leaf(LeafLabel const &label) { +LeafOnlyBinarySPDecompositionTree + leaf_only_binary_sp_tree_make_leaf(LeafLabel const &label) { return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_sp_leaf< - std::monostate, - std::monostate, LeafLabel>(label), + make_generic_binary_sp_leaf( + label), }; } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h index 400b6be1de..77d7c2fd8d 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -1,51 +1,45 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { template -LeafOnlyBinarySeriesSplit - require_leaf_only_binary_series_split(LeafOnlyBinarySPDecompositionTree const &t) { - GenericBinarySeriesSplit< - std::monostate, - std::monostate, - LeafLabel> raw = +LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( + LeafOnlyBinarySPDecompositionTree const &t) { + GenericBinarySeriesSplit raw = require_generic_binary_series_split(t.raw_tree); return LeafOnlyBinarySeriesSplit{ - LeafOnlyBinarySPDecompositionTree{raw.pre}, - LeafOnlyBinarySPDecompositionTree{raw.post}, + LeafOnlyBinarySPDecompositionTree{raw.pre}, + LeafOnlyBinarySPDecompositionTree{raw.post}, }; } template -LeafOnlyBinaryParallelSplit - 
require_leaf_only_binary_parallel_split(LeafOnlyBinarySPDecompositionTree const &t) { - GenericBinaryParallelSplit< - std::monostate, - std::monostate, - LeafLabel> raw = +LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split( + LeafOnlyBinarySPDecompositionTree const &t) { + GenericBinaryParallelSplit raw = require_generic_binary_parallel_split(t.raw_tree); return LeafOnlyBinaryParallelSplit{ - LeafOnlyBinarySPDecompositionTree{raw.lhs}, - LeafOnlyBinarySPDecompositionTree{raw.rhs}, + LeafOnlyBinarySPDecompositionTree{raw.lhs}, + LeafOnlyBinarySPDecompositionTree{raw.rhs}, }; } template -LeafLabel require_leaf_only_binary_leaf(LeafOnlyBinarySPDecompositionTree const &t) { +LeafLabel require_leaf_only_binary_leaf( + LeafOnlyBinarySPDecompositionTree const &t) { return require_generic_binary_leaf(t.raw_tree); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h index 4cbd2b26bd..a18ce37899 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -1,60 +1,58 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" namespace FlexFlow { template -LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &t, - LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { +LeafOnlyBinarySeriesSplit transform( + LeafOnlyBinarySeriesSplit const &t, + LeafOnlyBinarySPDecompositionTreeVisitor const + &visitor) { return LeafOnlyBinarySeriesSplit{ - transform(t.pre, visitor), - transform(t.post, visitor), + transform(t.pre, visitor), + transform(t.post, visitor), }; } template -LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &t, - LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { +LeafOnlyBinaryParallelSplit transform( + LeafOnlyBinaryParallelSplit const &t, + LeafOnlyBinarySPDecompositionTreeVisitor const + &visitor) { return LeafOnlyBinaryParallelSplit{ - transform(t.lhs, visitor), - transform(t.rhs, visitor), + transform(t.lhs, visitor), + transform(t.rhs, visitor), }; } template -LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, - LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { - using GenericVisitor = GenericBinarySPDecompositionTreeVisitor - ; +LeafOnlyBinarySPDecompositionTree transform( + LeafOnlyBinarySPDecompositionTree const &t, + LeafOnlyBinarySPDecompositionTreeVisitor const + &visitor) { + using GenericVisitor = 
GenericBinarySPDecompositionTreeVisitor; GenericVisitor generic_visitor = GenericVisitor{ - [&](std::monostate const &x) { - return x; - }, - [&](std::monostate const &x) { - return x; - }, - [&](LeafLabel const &t) { - return visitor.leaf_func(t); - }, + [&](std::monostate const &x) { return x; }, + [&](std::monostate const &x) { return x; }, + [&](LeafLabel const &t) { return visitor.leaf_func(t); }, }; return LeafOnlyBinarySPDecompositionTree{ - transform(t.raw_tree, generic_visitor), + transform(t.raw_tree, generic_visitor), }; } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h index 21fae97633..e13cea0fdb 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h @@ -10,34 +10,28 @@ namespace FlexFlow { template -LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySeriesSplit const &split) { +LeafOnlyBinarySPDecompositionTree + wrap_series_split(LeafOnlyBinarySeriesSplit const &split) { return LeafOnlyBinarySPDecompositionTree{ - wrap_series_split( - GenericBinarySeriesSplit< - std::monostate, - std::monostate, - LeafLabel>{ - std::monostate{}, - split.pre.raw_tree, - split.post.raw_tree, - } - ), + wrap_series_split( + GenericBinarySeriesSplit{ + std::monostate{}, + split.pre.raw_tree, + split.post.raw_tree, + }), }; } template -LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryParallelSplit const &split) { +LeafOnlyBinarySPDecompositionTree + wrap_parallel_split(LeafOnlyBinaryParallelSplit const &split) { return LeafOnlyBinarySPDecompositionTree{ - wrap_parallel_split( - 
GenericBinaryParallelSplit< - std::monostate, - std::monostate, - LeafLabel>{ - std::monostate{}, - split.lhs.raw_tree, - split.rhs.raw_tree, - } - ), + wrap_parallel_split( + GenericBinaryParallelSplit{ + std::monostate{}, + split.lhs.raw_tree, + split.rhs.raw_tree, + }), }; } diff --git a/lib/utils/src/utils/any_value_type/any_value_type.cc b/lib/utils/src/utils/any_value_type/any_value_type.cc index b3a72dafa9..d4c605c441 100644 --- a/lib/utils/src/utils/any_value_type/any_value_type.cc +++ b/lib/utils/src/utils/any_value_type/any_value_type.cc @@ -2,13 +2,13 @@ namespace FlexFlow { -any_value_type::any_value_type(std::any const &value, - std::function const &eq, - std::function const &neq, - std::function const &hash, - std::function const &to_string) - : value(value), eq(eq), neq(neq), hash(hash), to_string(to_string) -{} +any_value_type::any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string) + : value(value), eq(eq), neq(neq), hash(hash), to_string(to_string) {} bool any_value_type::operator==(any_value_type const &other) const { return this->eq(this->value, other.value); @@ -26,7 +26,8 @@ std::string format_as(any_value_type const &v) { namespace std { -size_t hash<::FlexFlow::any_value_type>::operator()(::FlexFlow::any_value_type const &v) const { +size_t hash<::FlexFlow::any_value_type>::operator()( + ::FlexFlow::any_value_type const &v) const { return v.hash(v); } diff --git a/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc index 63d083fc5b..8445a2721a 100644 --- a/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc +++ b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc @@ -9,13 +9,15 @@ BinaryTreePath binary_tree_root_path() { BinaryTreePath nest_inside_left_child(BinaryTreePath const &p) { BinaryTreePath result = p; - result.entries.insert(result.entries.begin(), 
BinaryTreePathEntry::LEFT_CHILD); + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::LEFT_CHILD); return result; } BinaryTreePath nest_inside_right_child(BinaryTreePath const &p) { BinaryTreePath result = p; - result.entries.insert(result.entries.begin(), BinaryTreePathEntry::RIGHT_CHILD); + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::RIGHT_CHILD); return result; } @@ -25,7 +27,7 @@ BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &p) { BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &p) { return BinaryTreePath{ - subvec(p.entries, 1, std::nullopt), + subvec(p.entries, 1, std::nullopt), }; } diff --git a/lib/utils/src/utils/full_binary_tree/fmt.cc b/lib/utils/src/utils/full_binary_tree/fmt.cc index 82bf382821..9e4d328be3 100644 --- a/lib/utils/src/utils/full_binary_tree/fmt.cc +++ b/lib/utils/src/utils/full_binary_tree/fmt.cc @@ -4,7 +4,9 @@ namespace FlexFlow { template std::string format_as(FullBinaryTreeParentNode const &); template std::string format_as(FullBinaryTree const &); -template std::ostream &operator<<(std::ostream &, FullBinaryTreeParentNode const &); -template std::ostream &operator<<(std::ostream &, FullBinaryTree const &); +template std::ostream &operator<<(std::ostream &, + FullBinaryTreeParentNode const &); +template std::ostream &operator<<(std::ostream &, + FullBinaryTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_label.cc b/lib/utils/src/utils/full_binary_tree/get_label.cc index 25ed6cf3f6..1270dcbc9d 100644 --- a/lib/utils/src/utils/full_binary_tree/get_label.cc +++ b/lib/utils/src/utils/full_binary_tree/get_label.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template - int get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &); +template int get_full_binary_tree_parent_label( + FullBinaryTreeParentNode const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/make.cc 
b/lib/utils/src/utils/full_binary_tree/make.cc index da48d2a2c4..8de1e60eb7 100644 --- a/lib/utils/src/utils/full_binary_tree/make.cc +++ b/lib/utils/src/utils/full_binary_tree/make.cc @@ -2,11 +2,10 @@ namespace FlexFlow { -template - FullBinaryTree make_full_binary_tree_parent(int const &, - FullBinaryTree const &, - FullBinaryTree const &); -template - FullBinaryTree make_full_binary_tree_leaf(int const &); +template FullBinaryTree + make_full_binary_tree_parent(int const &, + FullBinaryTree const &, + FullBinaryTree const &); +template FullBinaryTree make_full_binary_tree_leaf(int const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/require.cc b/lib/utils/src/utils/full_binary_tree/require.cc index d6b0bbeb68..e4454927a4 100644 --- a/lib/utils/src/utils/full_binary_tree/require.cc +++ b/lib/utils/src/utils/full_binary_tree/require.cc @@ -2,10 +2,9 @@ namespace FlexFlow { -template - FullBinaryTreeParentNode const & +template FullBinaryTreeParentNode const & require_full_binary_tree_parent_node(FullBinaryTree const &); -template - int const &require_full_binary_tree_leaf(FullBinaryTree const &); +template int const & + require_full_binary_tree_leaf(FullBinaryTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc index b43eb5bce6..1b75630269 100644 --- a/lib/utils/src/utils/full_binary_tree/visit.cc +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template - int visit(FullBinaryTree const &, FullBinaryTreeVisitor const &); +template int visit(FullBinaryTree const &, + FullBinaryTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc index 53e17f3917..c07d344d05 100644 --- 
a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -2,14 +2,13 @@ namespace FlexFlow { -std::unordered_set get_dataflow_edges_from_node_to_node(DataflowGraphView const &g, - Node const &src, - Node const &dst) { +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst) { return g.query_edges(DataflowEdgeQuery{ - /*src_nodes=*/query_set{src}, - /*src_idxs=*/query_set::matchall(), - /*dst_nodes=*/query_set{dst}, - /*dst_idxs=*/query_set::matchall(), + /*src_nodes=*/query_set{src}, + /*src_idxs=*/query_set::matchall(), + /*dst_nodes=*/query_set{dst}, + /*dst_idxs=*/query_set::matchall(), }); } diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc index 66152b9b13..70a66c9a21 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -3,21 +3,21 @@ namespace FlexFlow { -SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split(TransitiveReducedDataflowGraphView const &tr_g, - BinarySeriesSplit const &split) { - std::unordered_set edges = get_transitive_reduced_edges_across_split(tr_g, split); +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set edges = + get_transitive_reduced_edges_across_split(tr_g, split); - std::unordered_set src_boundary_nodes - = transform(edges, - 
[](DataflowEdge const &e) { return e.src.node; }); + std::unordered_set src_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.src.node; }); - std::unordered_set dst_boundary_nodes - = transform(edges, - [](DataflowEdge const &e) { return e.dst.node; }); + std::unordered_set dst_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.dst.node; }); return SplitBoundaryNodes{ - /*pre_split_boundary=*/src_boundary_nodes, - /*post_split_boundary=*/dst_boundary_nodes, + /*pre_split_boundary=*/src_boundary_nodes, + /*post_split_boundary=*/dst_boundary_nodes, }; } diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc index 49783ee0d5..fd2154d6c0 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -7,24 +7,22 @@ namespace FlexFlow { -std::unordered_set - get_transitive_reduced_edges_across_split(TransitiveReducedDataflowGraphView const &tr_g, - BinarySeriesSplit const &split) { - std::unordered_set src_subgraph = unordered_set_of(get_leaves(get_left_child(split))); - std::unordered_set dst_subgraph = unordered_set_of(get_leaves(get_right_child(split))); +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set src_subgraph = + unordered_set_of(get_leaves(get_left_child(split))); + std::unordered_set dst_subgraph = + unordered_set_of(get_leaves(get_right_child(split))); - std::unordered_set raw_edges = get_edges_from_subgraph_to_subgraph(tr_g.transitive_reduction, - src_subgraph, 
- dst_subgraph); - - return flatmap(raw_edges, - [&](DirectedEdge const &e) { - return get_dataflow_edges_from_node_to_node(tr_g.full_dataflow_graph, - e.src, - e.dst); - }); + std::unordered_set raw_edges = + get_edges_from_subgraph_to_subgraph( + tr_g.transitive_reduction, src_subgraph, dst_subgraph); + return flatmap(raw_edges, [&](DirectedEdge const &e) { + return get_dataflow_edges_from_node_to_node( + tr_g.full_dataflow_graph, e.src, e.dst); + }); } - } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc index d4e285e5c3..0bb94c87f4 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -4,9 +4,9 @@ namespace FlexFlow { -std::unordered_set - get_transitive_reduced_outputs_across_split(TransitiveReducedDataflowGraphView const &tr_g, - BinarySeriesSplit const &split) { +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { return transform(get_transitive_reduced_edges_across_split(tr_g, split), [](DataflowEdge const &e) { return e.src; }); } diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc index a068679be4..81751702a2 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc +++ 
b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc @@ -3,13 +3,14 @@ namespace FlexFlow { -TransitiveReducedDataflowGraphView get_dataflow_graph_transitive_reduction(DataflowGraphView const &g) { +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &g) { DiGraphView as_digraph = g; DiGraphView transitive_reduced = transitive_reduction(as_digraph); return TransitiveReducedDataflowGraphView{ - /*full_dataflow_graph=*/g, - /*transitive_reduction=*/transitive_reduced, + /*full_dataflow_graph=*/g, + /*transitive_reduction=*/transitive_reduced, }; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc index 72200ec483..2c6606a06b 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -3,16 +3,22 @@ namespace FlexFlow { -std::unordered_set get_edges_from_subgraph_to_subgraph(DiGraphView const &g, - std::unordered_set const &src_subgraph, - std::unordered_set const &dst_subgraph) { +std::unordered_set get_edges_from_subgraph_to_subgraph( + DiGraphView const &g, + std::unordered_set const &src_subgraph, + std::unordered_set const &dst_subgraph) { if (!are_disjoint(src_subgraph, dst_subgraph)) { - throw mk_runtime_error(fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) expected src_subgraph and dst_subgraph to be disjoint, but found src_subgraph={}, dst_subgraph={}", src_subgraph, dst_subgraph)); + throw mk_runtime_error( + fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) 
" + "expected src_subgraph and dst_subgraph to be disjoint, " + "but found src_subgraph={}, dst_subgraph={}", + src_subgraph, + dst_subgraph)); } return g.query_edges(DirectedEdgeQuery{ - /*srcs=*/query_set{src_subgraph}, - /*dsts=*/query_set{dst_subgraph}, + /*srcs=*/query_set{src_subgraph}, + /*dsts=*/query_set{dst_subgraph}, }); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index b07423a21a..08dda09698 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -5,8 +5,8 @@ #include "utils/containers/values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" -#include "utils/overload.h" #include "utils/hash/vector.h" +#include "utils/overload.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc index 6763d9442b..248522a3a3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc @@ -5,13 +5,13 @@ namespace FlexFlow { BinarySPDecompositionTree get_left_child(BinaryParallelSplit const &s) { return BinarySPDecompositionTree{ - get_left_child(s.raw_split), + get_left_child(s.raw_split), }; } BinarySPDecompositionTree get_right_child(BinaryParallelSplit const &s) { return BinarySPDecompositionTree{ - get_right_child(s.raw_split), + get_right_child(s.raw_split), }; } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc index 79bde7899a..1e80bd68b6 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc @@ -5,13 +5,13 @@ namespace FlexFlow { BinarySPDecompositionTree get_left_child(BinarySeriesSplit const &split) { return BinarySPDecompositionTree{ - get_left_child(split.raw_split), + get_left_child(split.raw_split), }; } BinarySPDecompositionTree get_right_child(BinarySeriesSplit const &split) { return BinarySPDecompositionTree{ - get_right_child(split.raw_split), + get_right_child(split.raw_split), }; } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index a4bd8b1ba7..79042fd061 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,10 +1,10 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" namespace FlexFlow { @@ -44,13 +44,13 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { BinarySeriesSplit require_series(BinarySPDecompositionTree const &tt) { return BinarySeriesSplit{ - require_leaf_only_binary_series_split(tt.raw_tree), + require_leaf_only_binary_series_split(tt.raw_tree), }; } BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &tt) { return BinaryParallelSplit{ - require_leaf_only_binary_parallel_split(tt.raw_tree), + require_leaf_only_binary_parallel_split(tt.raw_tree), }; } @@ -62,5 +62,4 @@ SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tt) { return get_node_type(tt.raw_tree); } - } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc index 2ecd4c94d2..d14dd7641c 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -2,8 +2,8 @@ namespace FlexFlow { -template - std::unordered_set find_paths_to_leaf(GenericBinarySPDecompositionTree const &, - int const &); +template std::unordered_set + find_paths_to_leaf(GenericBinarySPDecompositionTree const &, + int const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc index 10bbc60c6d..16ca73e01d 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc @@ -2,15 +2,15 @@ namespace FlexFlow { -template - SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &); -template - GenericBinarySPSplitLabel make_generic_binary_series_split_label(int const &); -template - GenericBinarySPSplitLabel make_generic_binary_parallel_split_label(int const &); -template - int require_generic_binary_series_split_label(GenericBinarySPSplitLabel const &); -template - int require_generic_binary_parallel_split_label(GenericBinarySPSplitLabel const &); +template SPDecompositionTreeNodeType + get_node_type(GenericBinarySPSplitLabel const &); +template GenericBinarySPSplitLabel + make_generic_binary_series_split_label(int const &); +template GenericBinarySPSplitLabel + make_generic_binary_parallel_split_label(int const &); +template int require_generic_binary_series_split_label( + GenericBinarySPSplitLabel const &); +template int require_generic_binary_parallel_split_label( + GenericBinarySPSplitLabel const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc index 31e664b726..970f401584 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template - std::unordered_set get_all_leaf_paths(GenericBinarySPDecompositionTree const &); +template std::unordered_set + get_all_leaf_paths(GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 20ba3fa5d7..ccbd4a8c10 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -2,12 +2,11 @@ namespace FlexFlow { -template - std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &); -template - std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &); -template - std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &); +template std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &); +template std::unordered_multiset + get_leaves(GenericBinarySeriesSplit const &); +template std::unordered_multiset + get_leaves(GenericBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc index 783a7a974b..697fb417d4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -2,11 +2,9 @@ namespace FlexFlow { -template - GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &); -template - GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &); +template GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &); +template GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc index 9d652d44da..e66b996721 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc @@ -2,8 +2,7 @@ namespace FlexFlow { -template - SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &); +template SPDecompositionTreeNodeType + get_node_type(GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 6c67fdc244..694d981733 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -2,11 +2,11 @@ 
namespace FlexFlow { -template - int get_num_tree_nodes(GenericBinarySPDecompositionTree const &); -template - int get_num_tree_nodes(GenericBinarySeriesSplit const &); -template - int get_num_tree_nodes(GenericBinaryParallelSplit const &); +template int + get_num_tree_nodes(GenericBinarySPDecompositionTree const &); +template int + get_num_tree_nodes(GenericBinarySeriesSplit const &); +template int + get_num_tree_nodes(GenericBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc index 03c154fb67..ec56627455 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -2,11 +2,9 @@ namespace FlexFlow { -template - GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &); -template - GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &); +template GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &); +template GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc index 6bfb573359..ac5509045a 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -2,8 +2,7 @@ namespace FlexFlow { -template - std::optional> +template std::optional> get_subtree_at_path(GenericBinarySPDecompositionTree const &, BinaryTreePath const &); diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc index 5e5b768ed7..056435531f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc @@ -2,11 +2,10 @@ namespace FlexFlow { -template - bool is_series_split(GenericBinarySPDecompositionTree const &); -template - bool is_parallel_split(GenericBinarySPDecompositionTree const &); -template - bool is_leaf(GenericBinarySPDecompositionTree const &); +template bool + is_series_split(GenericBinarySPDecompositionTree const &); +template bool + is_parallel_split(GenericBinarySPDecompositionTree const &); +template bool is_leaf(GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 87ae55b900..18bc1f3030 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -2,8 +2,7 @@ 
namespace FlexFlow { -template - bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &); +template bool is_binary_sp_tree_left_associative( + GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 5a40a3b6bf..fc6a5ee041 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -2,8 +2,7 @@ namespace FlexFlow { -template - bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &); +template bool is_binary_sp_tree_right_associative( + GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc index a36ccce359..27219bd4d8 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc @@ -2,17 +2,17 @@ namespace FlexFlow { -template - GenericBinarySPDecompositionTree make_generic_binary_series_split( - int const &, - GenericBinarySPDecompositionTree const &, - GenericBinarySPDecompositionTree const &); -template - GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - int const &label, - 
GenericBinarySPDecompositionTree const &, - GenericBinarySPDecompositionTree const &); -template - GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(int const &); +template GenericBinarySPDecompositionTree + make_generic_binary_series_split( + int const &, + GenericBinarySPDecompositionTree const &, + GenericBinarySPDecompositionTree const &); +template GenericBinarySPDecompositionTree + make_generic_binary_parallel_split( + int const &label, + GenericBinarySPDecompositionTree const &, + GenericBinarySPDecompositionTree const &); +template GenericBinarySPDecompositionTree + make_generic_binary_sp_leaf(int const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc index 8305a1243e..10029ceedd 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -2,13 +2,13 @@ namespace FlexFlow { -template - GenericBinarySeriesSplit - require_generic_binary_series_split(GenericBinarySPDecompositionTree const &); -template - GenericBinaryParallelSplit - require_generic_binary_parallel_split(GenericBinarySPDecompositionTree const &); -template - int require_generic_binary_leaf(GenericBinarySPDecompositionTree const &); +template GenericBinarySeriesSplit + require_generic_binary_series_split( + GenericBinarySPDecompositionTree const &); +template GenericBinaryParallelSplit + require_generic_binary_parallel_split( + GenericBinarySPDecompositionTree const &); +template int require_generic_binary_leaf( + GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc index 4495a60f92..3193f8828c 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -2,9 +2,13 @@ namespace FlexFlow { -template - GenericBinarySeriesSplit - transform(GenericBinarySeriesSplit const &, - GenericBinarySPDecompositionTreeVisitor const &); +template GenericBinarySeriesSplit + transform(GenericBinarySeriesSplit const &, + GenericBinarySPDecompositionTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc index 0b3189b47b..007f1dbd52 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc @@ -2,11 +2,9 @@ namespace FlexFlow { -template - GenericBinarySPDecompositionTree +template GenericBinarySPDecompositionTree wrap_series_split(GenericBinarySeriesSplit const &); -template - GenericBinarySPDecompositionTree +template GenericBinarySPDecompositionTree wrap_parallel_split(GenericBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc index 41accc79d0..61e5c8a9fa 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template - std::unordered_multiset get_leaves(LeafOnlyBinarySPDecompositionTree const &); +template std::unordered_multiset + get_leaves(LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc index 0959a42f01..90fed4010d 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template - SPDecompositionTreeNodeType get_node_type(LeafOnlyBinarySPDecompositionTree const &); +template SPDecompositionTreeNodeType + get_node_type(LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index dd94936997..9e00926b58 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -2,8 +2,7 @@ namespace FlexFlow { -template - bool is_binary_sp_tree_left_associative(LeafOnlyBinarySPDecompositionTree const &); - +template bool is_binary_sp_tree_left_associative( + LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 46b89aa98f..bec3410841 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template - bool is_binary_sp_tree_right_associative(LeafOnlyBinarySPDecompositionTree const &); +template bool is_binary_sp_tree_right_associative( + LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc index 5690ebe8a8..62e21510a8 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc @@ -2,9 +2,9 @@ namespace FlexFlow 
{ -template - LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinaryParallelSplit const &); -template - LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinaryParallelSplit const &); +template LeafOnlyBinarySPDecompositionTree + get_left_child(LeafOnlyBinaryParallelSplit const &); +template LeafOnlyBinarySPDecompositionTree + get_right_child(LeafOnlyBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc index ed0e5892da..efb5d779a8 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc @@ -2,9 +2,9 @@ namespace FlexFlow { -template - LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinarySeriesSplit const &); -template - LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinarySeriesSplit const &); +template LeafOnlyBinarySPDecompositionTree + get_left_child(LeafOnlyBinarySeriesSplit const &); +template LeafOnlyBinarySPDecompositionTree + get_right_child(LeafOnlyBinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc index 112a14c206..07ae9e604a 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc @@ -2,13 +2,15 @@ namespace FlexFlow { -template - LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_series_split(LeafOnlyBinarySPDecompositionTree const &, - LeafOnlyBinarySPDecompositionTree const &); -template - LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &, - LeafOnlyBinarySPDecompositionTree const &); -template - LeafOnlyBinarySPDecompositionTree leaf_only_binary_sp_tree_make_leaf(int const &); +template LeafOnlyBinarySPDecompositionTree + leaf_only_binary_sp_tree_make_series_split( + LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree const &); +template LeafOnlyBinarySPDecompositionTree + leaf_only_binary_sp_tree_make_parallel_split( + LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree const &); +template LeafOnlyBinarySPDecompositionTree + leaf_only_binary_sp_tree_make_leaf(int const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc index 1a1cd9909d..75c568fa4a 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc @@ -2,8 +2,12 @@ namespace FlexFlow { -template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split(LeafOnlyBinarySPDecompositionTree const &); -template LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split(LeafOnlyBinarySPDecompositionTree const &); -template int require_leaf_only_binary_leaf(LeafOnlyBinarySPDecompositionTree 
const &); +template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( + LeafOnlyBinarySPDecompositionTree const &); +template LeafOnlyBinaryParallelSplit + require_leaf_only_binary_parallel_split( + LeafOnlyBinarySPDecompositionTree const &); +template int require_leaf_only_binary_leaf( + LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc index 22dd5e0db5..c7fb0811df 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc @@ -2,15 +2,15 @@ namespace FlexFlow { -template - LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &, - LeafOnlyBinarySPDecompositionTreeVisitor const &); -template - LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &, - LeafOnlyBinarySPDecompositionTreeVisitor const &); +template LeafOnlyBinarySeriesSplit transform( + LeafOnlyBinarySeriesSplit const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); +template LeafOnlyBinaryParallelSplit transform( + LeafOnlyBinaryParallelSplit const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); -template - LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &, - LeafOnlyBinarySPDecompositionTreeVisitor const &); +template LeafOnlyBinarySPDecompositionTree transform( + LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc index 3836124eb6..5d417271b4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc @@ -2,9 +2,9 @@ namespace FlexFlow { -template - LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySeriesSplit const &); -template - LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryParallelSplit const &); +template LeafOnlyBinarySPDecompositionTree + wrap_series_split(LeafOnlyBinarySeriesSplit const &); +template LeafOnlyBinarySPDecompositionTree + wrap_parallel_split(LeafOnlyBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index eb66ce2f68..feb4749d0c 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -21,8 +21,7 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( return make_leaf_node(n); }; - auto from_series = - [&](SeriesSplit const &s) -> BinarySPDecompositionTree { + auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { std::vector children = transform(s.children, from_series_child); return foldl1(children, @@ -63,10 +62,10 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( }; return nary.visit(overload{ - [&](Node const &n) { return from_node(n); }, - [&](SeriesSplit const &s) { return from_series(s); }, - 
[&](ParallelSplit const &p) { return from_parallel(p); }, - }); + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index bebb97defc..a4f6000900 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,9 +16,7 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( std::variant const &)> from_parallel_child; - auto from_node = [](Node const &n) { - return make_leaf_node(n); - }; + auto from_node = [](Node const &n) { return make_leaf_node(n); }; auto from_series = [&](SeriesSplit const &s) { std::vector children = diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 996803d1ac..0ad586d499 100644 --- a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,8 +1,8 @@ #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/containers/extend.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index dc8d8437a9..c10cc5ae75 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -1,11 +1,11 @@ #include "utils/containers/flatmap.h" -#include -#include +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "utils/containers/map_keys.h" #include "utils/hash/pair.h" -#include "test/utils/doctest/fmt/unordered_map.h" -#include "test/utils/doctest/fmt/pair.h" +#include +#include using namespace ::FlexFlow; @@ -21,16 +21,17 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("type changing") { std::unordered_set input = {"hello", " ", "", "world", "!"}; - + std::unordered_set result = flatmap(input, get_chars); - std::unordered_set correct = {'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; + std::unordered_set correct = { + 'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; CHECK(result == correct); } SUBCASE("input is empty") { std::unordered_set input = {}; - + std::unordered_set result = flatmap(input, get_chars); std::unordered_set correct = {}; @@ -39,14 +40,16 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("flatmap(std::unordered_map, F)") { - auto de_nest_keys = [](int k1, std::unordered_map const &v) { + auto de_nest_keys = [](int k1, + std::unordered_map const &v) { return map_keys(v, [&](int k2) { return std::pair{k1, k2}; }); }; SUBCASE("input is empty") { std::unordered_map> input = {}; - std::unordered_map, std::string> result = flatmap(input, de_nest_keys); + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); std::unordered_map, std::string> correct = {}; CHECK(result == correct); @@ -54,30 +57,31 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not 
empty") { std::unordered_map> input = { - { - 1, { - {2, "a"}, - {3, "b"}, + 1, + { + {2, "a"}, + {3, "b"}, + }, + }, + { + 2, + {}, }, - }, - { - 2, - {}, - }, - { - 3, { - {3, "a"}, + 3, + { + {3, "a"}, + }, }, - }, }; - std::unordered_map, std::string> result = flatmap(input, de_nest_keys); + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); std::unordered_map, std::string> correct = { - {{1, 2}, "a"}, - {{1, 3}, "b"}, - {{3, 3}, "a"}, + {{1, 2}, "a"}, + {{1, 3}, "b"}, + {{3, 3}, "a"}, }; CHECK(result == correct); @@ -86,13 +90,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("duplicate result keys") { auto always_return_same_map = [](int, std::string const &) { return std::unordered_map{ - {"mykey", 10000}, + {"mykey", 10000}, }; }; std::unordered_map input = { - {1, "a"}, - {2, "b"}, + {1, "a"}, + {2, "b"}, }; CHECK_THROWS(flatmap(input, always_return_same_map)); diff --git a/lib/utils/test/src/utils/containers/get_all_assignments.cc b/lib/utils/test/src/utils/containers/get_all_assignments.cc index 17a4e6e749..d5f989318f 100644 --- a/lib/utils/test/src/utils/containers/get_all_assignments.cc +++ b/lib/utils/test/src/utils/containers/get_all_assignments.cc @@ -1,7 +1,7 @@ #include "utils/containers/get_all_assignments.h" -#include #include "test/utils/doctest/fmt/unordered_map.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include using namespace ::FlexFlow; @@ -10,26 +10,28 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("empty input") { std::unordered_map> input = {}; - std::unordered_set> result = get_all_assignments(input); + std::unordered_set> result = + get_all_assignments(input); std::unordered_set> correct = {{}}; - CHECK(result == correct); + CHECK(result == correct); } SUBCASE("non-empty input") { std::unordered_map> input = { - {"a", {1, 2, 3}}, - {"b", {2, 3}}, + {"a", {1, 2, 3}}, + {"b", {2, 3}}, }; - std::unordered_set> result = get_all_assignments(input); + std::unordered_set> result = + get_all_assignments(input); 
std::unordered_set> correct = { - {{"a", 1}, {"b", 2}}, - {{"a", 1}, {"b", 3}}, - {{"a", 2}, {"b", 2}}, - {{"a", 2}, {"b", 3}}, - {{"a", 3}, {"b", 2}}, - {{"a", 3}, {"b", 3}}, + {{"a", 1}, {"b", 2}}, + {{"a", 1}, {"b", 3}}, + {{"a", 2}, {"b", 2}}, + {{"a", 2}, {"b", 3}}, + {{"a", 3}, {"b", 2}}, + {{"a", 3}, {"b", 3}}, }; CHECK(result == correct); @@ -37,11 +39,12 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("one possible-values set is empty") { std::unordered_map> input = { - {"a", {}}, - {"b", {2, 3}}, + {"a", {}}, + {"b", {2, 3}}, }; - std::unordered_set> result = get_all_assignments(input); + std::unordered_set> result = + get_all_assignments(input); std::unordered_set> correct = {}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/try_at.cc b/lib/utils/test/src/utils/containers/try_at.cc index 818456f65a..548c9b0c79 100644 --- a/lib/utils/test/src/utils/containers/try_at.cc +++ b/lib/utils/test/src/utils/containers/try_at.cc @@ -1,12 +1,15 @@ #include "utils/containers/try_at.h" +#include "test/utils/doctest/fmt/optional.h" #include #include -#include "test/utils/doctest/fmt/optional.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("try_at(T, K)", T, std::unordered_map, std::map) { + TEST_CASE_TEMPLATE("try_at(T, K)", + T, + std::unordered_map, + std::map) { T m = {{1, "one"}, {2, "two"}}; SUBCASE("map contains key") { diff --git a/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc index a87e54ed8e..f0cdb19611 100644 --- a/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc +++ b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc @@ -1,9 +1,9 @@ #include "utils/containers/unordered_map_from_pairs.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "utils/containers/contains.h" #include -#include #include -#include "utils/containers/contains.h" -#include "test/utils/doctest/fmt/unordered_map.h" 
+#include using namespace ::FlexFlow; @@ -11,23 +11,25 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("unordered_map_from_pairs") { SUBCASE("nonempty input") { std::vector> input = { - {1, "hello"}, - {3, "world"}, + {1, "hello"}, + {3, "world"}, }; - std::unordered_map result = unordered_map_from_pairs(input); + std::unordered_map result = + unordered_map_from_pairs(input); std::unordered_map correct = { - {1, "hello"}, - {3, "world"}, + {1, "hello"}, + {3, "world"}, }; - + CHECK(result == correct); } SUBCASE("empty input") { std::vector> input = {}; - std::unordered_map result = unordered_map_from_pairs(input); + std::unordered_map result = + unordered_map_from_pairs(input); std::unordered_map correct = {}; CHECK(result == correct); @@ -35,17 +37,19 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input with duplicate keys") { std::vector> input = { - {1, "a"}, - {2, "c"}, - {1, "b"}, + {1, "a"}, + {2, "c"}, + {1, "b"}, }; - std::unordered_map result = unordered_map_from_pairs(input); + std::unordered_map result = + unordered_map_from_pairs(input); - std::vector> possible_correct_values = { - {{1, "a"}, {2, "c"}}, - {{1, "b"}, {2, "c"}}, - }; + std::vector> + possible_correct_values = { + {{1, "a"}, {2, "c"}}, + {{1, "b"}, {2, "c"}}, + }; CHECK(contains(possible_correct_values, result)); } diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc index a93f22802c..fec5d3401e 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -1,8 +1,8 @@ #include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" -#include #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include 
"utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include using namespace ::FlexFlow; @@ -19,20 +19,21 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({n1_o0, n1_o0, n1_o1}, 0); Node n2 = n2_added.node; - std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); std::unordered_set correct = { - DataflowEdge{ - n1_o0, - DataflowInput{n2, 0}, - }, - DataflowEdge{ - n1_o0, - DataflowInput{n2, 1}, - }, - DataflowEdge{ - n1_o1, - DataflowInput{n2, 2}, - }, + DataflowEdge{ + n1_o0, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o0, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, }; CHECK(result == correct); @@ -51,13 +52,15 @@ TEST_SUITE(FF_TEST_SUITE) { Node n3 = n3_added.node; DataflowOutput o3 = get_only(n3_added.outputs); - std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n3); + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n3); std::unordered_set correct = {}; CHECK(result == correct); } - SUBCASE("does not get flipped edges (i.e., respects from vs to direction)") { + SUBCASE( + "does not get flipped edges (i.e., respects from vs to direction)") { NodeAddedResult n1_added = g.add_node({}, 1); Node n1 = n1_added.node; DataflowOutput o1 = get_only(n1_added.outputs); @@ -65,7 +68,8 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({o1}, 0); Node n2 = n2_added.node; - std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n2, n1); + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n2, n1); std::unordered_set correct = {}; CHECK(result == correct); @@ -78,17 +82,20 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({}, 1); Node n2 = n2_added.node; - std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n2); + 
std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); std::unordered_set correct = {}; CHECK(result == correct); } - SUBCASE("returns empty set if src node == dst node (as cycles cannot exist in DataflowGraph") { + SUBCASE("returns empty set if src node == dst node (as cycles cannot exist " + "in DataflowGraph") { NodeAddedResult n1_added = g.add_node({}, 1); Node n1 = n1_added.node; - std::unordered_set result = get_dataflow_edges_from_node_to_node(g, n1, n1); + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n1); std::unordered_set correct = {}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc index 1a47dfde25..60e109faa9 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -1,8 +1,8 @@ #include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include @@ -19,7 +19,7 
@@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({o1}, 1); Node n2 = n2_added.node; DataflowOutput o2 = get_only(n2_added.outputs); - + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); Node n3 = n3_added.node; DataflowOutput o3 = get_only(n3_added.outputs); @@ -28,21 +28,18 @@ TEST_SUITE(FF_TEST_SUITE) { Node n4 = n4_added.node; DataflowOutput o4 = get_only(n4_added.outputs); - TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(\ - make_series_split( - make_series_split( - make_leaf_node(n1), - make_leaf_node(n2)), - make_series_split( - make_leaf_node(n3), - make_leaf_node(n4)))); + BinarySeriesSplit split = require_series(make_series_split( + make_series_split(make_leaf_node(n1), make_leaf_node(n2)), + make_series_split(make_leaf_node(n3), make_leaf_node(n4)))); - SplitBoundaryNodes result = get_transitive_reduced_boundary_nodes_for_split(tr_g, split); + SplitBoundaryNodes result = + get_transitive_reduced_boundary_nodes_for_split(tr_g, split); SplitBoundaryNodes correct = SplitBoundaryNodes{ - /*pre_split_boundary=*/{n2}, - /*post_split_boundary=*/{n3}, + /*pre_split_boundary=*/{n2}, + /*post_split_boundary=*/{n3}, }; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc index 915be7261e..ed66292462 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -1,8 +1,8 @@ #include 
"utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" #include "utils/containers/get_only.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include @@ -20,7 +20,7 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({}, 1); Node n2 = n2_added.node; DataflowOutput o2 = get_only(n2_added.outputs); - + NodeAddedResult n3_added = g.add_node({o2, o1}, 1); Node n3 = n3_added.node; DataflowOutput o3 = get_only(n3_added.outputs); @@ -28,32 +28,29 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n4_added = g.add_node({o1}, 1); Node n4 = n4_added.node; DataflowOutput o4 = get_only(n4_added.outputs); - - TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - - BinarySeriesSplit split = require_series(\ - make_series_split( - make_parallel_split( - make_leaf_node(n1), - make_leaf_node(n2)), - make_parallel_split( - make_leaf_node(n3), - make_leaf_node(n4)))); - - std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = require_series(make_series_split( + make_parallel_split(make_leaf_node(n1), make_leaf_node(n2)), + make_parallel_split(make_leaf_node(n3), make_leaf_node(n4)))); + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); std::unordered_set correct = { - DataflowEdge{ - o1, - DataflowInput{n3, 1}, - }, - 
DataflowEdge{ - o2, - DataflowInput{n3, 0}, - }, - DataflowEdge{ - o1, - DataflowInput{n4, 0}, - }, + DataflowEdge{ + o1, + DataflowInput{n3, 1}, + }, + DataflowEdge{ + o2, + DataflowInput{n3, 0}, + }, + DataflowEdge{ + o1, + DataflowInput{n4, 0}, + }, }; CHECK(result == correct); @@ -67,28 +64,28 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({n1_o1, n1_o2, n1_o1}, 1); Node n2 = n2_added.node; - - TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(\ - make_series_split( - make_leaf_node(n1), - make_leaf_node(n2))); + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); - std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); + BinarySeriesSplit split = require_series( + make_series_split(make_leaf_node(n1), make_leaf_node(n2))); + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); std::unordered_set correct = { - DataflowEdge{ - n1_o1, - DataflowInput{n2, 0}, - }, - DataflowEdge{ - n1_o2, - DataflowInput{n2, 1}, - }, - DataflowEdge{ - n1_o1, - DataflowInput{n2, 2}, - }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o2, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, }; CHECK(result == correct); @@ -102,7 +99,7 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({o1}, 1); Node n2 = n2_added.node; DataflowOutput o2 = get_only(n2_added.outputs); - + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); Node n3 = n3_added.node; DataflowOutput o3 = get_only(n3_added.outputs); @@ -111,23 +108,20 @@ TEST_SUITE(FF_TEST_SUITE) { Node n4 = n4_added.node; DataflowOutput o4 = get_only(n4_added.outputs); - TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit 
split = require_series(\ - make_series_split( - make_series_split( - make_leaf_node(n1), - make_leaf_node(n2)), - make_series_split( - make_leaf_node(n3), - make_leaf_node(n4)))); + BinarySeriesSplit split = require_series(make_series_split( + make_series_split(make_leaf_node(n1), make_leaf_node(n2)), + make_series_split(make_leaf_node(n3), make_leaf_node(n4)))); - std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); std::unordered_set correct = { - DataflowEdge{ - o2, - DataflowInput{n3, 1}, - }, + DataflowEdge{ + o2, + DataflowInput{n3, 1}, + }, }; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc index 2df7c91041..1bd27c5f35 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -1,8 +1,8 @@ #include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" -#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" -#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include @@ -19,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n2_added = g.add_node({o1}, 1); Node n2 = n2_added.node; DataflowOutput o2 = get_only(n2_added.outputs); - + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); Node n3 = n3_added.node; DataflowOutput o3 = get_only(n3_added.outputs); @@ -28,18 +28,15 @@ TEST_SUITE(FF_TEST_SUITE) { Node n4 = n4_added.node; DataflowOutput o4 = get_only(n4_added.outputs); - TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(\ - make_series_split( - make_series_split( - make_leaf_node(n1), - make_leaf_node(n2)), - make_series_split( - make_leaf_node(n3), - make_leaf_node(n4)))); + BinarySeriesSplit split = require_series(make_series_split( + make_series_split(make_leaf_node(n1), make_leaf_node(n2)), + make_series_split(make_leaf_node(n3), make_leaf_node(n4)))); - std::unordered_set result = get_transitive_reduced_outputs_across_split(tr_g, split); + std::unordered_set result = + get_transitive_reduced_outputs_across_split(tr_g, split); std::unordered_set correct = {o2}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc index c5e25386d5..5a1ea99671 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -1,7 +1,7 @@ #include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" -#include #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" +#include using namespace ::FlexFlow; @@ -16,15 +16,16 @@ 
TEST_SUITE(FF_TEST_SUITE) { SUBCASE("returns all edges between subgraphs") { std::vector e = { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(4), n.at(2)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(4), n.at(2)}, }; add_edges(g, e); - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); std::unordered_set correct = unordered_set_of(e); CHECK(result == correct); @@ -32,13 +33,14 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("does not return reverse edges") { std::vector e = { - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(0)}, }; add_edges(g, e); - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); std::unordered_set correct = {e.at(0)}; CHECK(result == correct); @@ -46,27 +48,30 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("does not return edges within subgraph") { std::vector e = { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, }; add_edges(g, e); - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); std::unordered_set correct = {e.at(1)}; CHECK(result == correct); } - SUBCASE("returns no edges if there are no edges from src_subgraph to dst_subgraph") { + SUBCASE("returns no edges if there are no edges from src_subgraph to " + "dst_subgraph") { std::vector e = { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(2), n.at(3)}, + 
DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(2), n.at(3)}, }; add_edges(g, e); - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); std::unordered_set correct = {}; CHECK(result == correct); @@ -75,48 +80,53 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("empty subgraphs") { std::vector e = { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, }; add_edges(g, e); SUBCASE("returns no edges if no nodes in src_subgraph") { - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, {}, unordered_set_of(n)); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, unordered_set_of(n)); std::unordered_set correct = {}; CHECK(result == correct); } SUBCASE("returns no edges if no nodes in dst_subgraph") { - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, unordered_set_of(n), {}); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, unordered_set_of(n), {}); std::unordered_set correct = {}; CHECK(result == correct); } SUBCASE("returns no edges if both subgraphs are empty") { - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, {}, {}); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, {}); std::unordered_set correct = {}; CHECK(result == correct); } } - SUBCASE("if subgraphs do not cover graph, then does not return external edges") { + SUBCASE("if subgraphs do not cover graph, then does not return external " + "edges") { std::vector e = { - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, }; add_edges(g, e); - + 
std::unordered_set src_subgraph = {n.at(0)}; std::unordered_set dst_subgraph = {n.at(3)}; - std::unordered_set result = get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); std::unordered_set correct = {e.at(1)}; CHECK(result == correct); @@ -125,7 +135,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("throws an error if subgraphs are not disjoint") { std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(2)}; std::unordered_set dst_subgraph = {n.at(1), n.at(3)}; - CHECK_THROWS(get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph)); + CHECK_THROWS( + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph)); } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 3b1e1899ca..8f1b8efaf7 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -11,24 +11,25 @@ TEST_SUITE(FF_TEST_SUITE) { int n3 = 3; int n4 = 4; - auto make_leaf = [](int n) { - return leaf_only_binary_sp_tree_make_leaf(n); + auto make_leaf = [](int n) { + return leaf_only_binary_sp_tree_make_leaf(n); }; - auto make_series_split = [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return leaf_only_binary_sp_tree_make_series_split(l, r); - }; + auto make_series_split = + [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return 
leaf_only_binary_sp_tree_make_series_split(l, r); + }; - auto make_parallel_split = [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return leaf_only_binary_sp_tree_make_parallel_split(l, r); - }; + auto make_parallel_split = + [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_parallel_split(l, r); + }; SUBCASE("input is actually left associative") { SUBCASE("just node") { - LeafOnlyBinarySPDecompositionTree input = - make_leaf(n1); + LeafOnlyBinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -37,12 +38,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = - make_series_split( - make_series_split( - make_leaf(n1), - make_leaf(n2)), - make_leaf(n3)); + LeafOnlyBinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -51,12 +48,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = - make_parallel_split( - make_parallel_split( - make_leaf(n1), - make_leaf(n2)), - make_leaf(n3)); + LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -65,14 +58,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - LeafOnlyBinarySPDecompositionTree input = - make_series_split( - make_parallel_split( - make_leaf(n1), - make_leaf(n2)), - make_parallel_split( - make_leaf(n3), - make_leaf(n4))); + LeafOnlyBinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); bool result = 
is_binary_sp_tree_left_associative(input); bool correct = true; @@ -83,12 +71,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not left associative") { SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = - make_series_split( - make_leaf(n1), - make_series_split( - make_leaf(n2), - make_leaf(n3))); + LeafOnlyBinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); bool correct = false; @@ -97,12 +81,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = - make_parallel_split( - make_leaf(n1), - make_parallel_split( - make_leaf(n2), - make_leaf(n3))); + LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); bool correct = false; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 9e34a769ba..88e08e7624 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -11,24 +11,25 @@ TEST_SUITE(FF_TEST_SUITE) { int n3 = 3; int n4 = 4; - auto make_leaf = [](int n) { - return leaf_only_binary_sp_tree_make_leaf(n); + auto make_leaf = [](int n) { + return leaf_only_binary_sp_tree_make_leaf(n); }; - auto make_series_split = [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - 
return leaf_only_binary_sp_tree_make_series_split(l, r); - }; + auto make_series_split = + [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_series_split(l, r); + }; - auto make_parallel_split = [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return leaf_only_binary_sp_tree_make_parallel_split(l, r); - }; + auto make_parallel_split = + [](LeafOnlyBinarySPDecompositionTree const &l, + LeafOnlyBinarySPDecompositionTree const &r) { + return leaf_only_binary_sp_tree_make_parallel_split(l, r); + }; SUBCASE("input is actually right associative") { SUBCASE("just node") { - LeafOnlyBinarySPDecompositionTree input = - make_leaf(n1); + LeafOnlyBinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -37,12 +38,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = - make_series_split( - make_leaf(n1), - make_series_split( - make_leaf(n2), - make_leaf(n3))); + LeafOnlyBinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -51,12 +48,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = - make_parallel_split( - make_leaf(n1), - make_parallel_split( - make_leaf(n2), - make_leaf(n3))); + LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -65,14 +58,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - LeafOnlyBinarySPDecompositionTree input = - make_series_split( - make_parallel_split( - make_leaf(n1), - make_leaf(n2)), - make_parallel_split( - make_leaf(n3), - 
make_leaf(n4))); + LeafOnlyBinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -83,12 +71,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not right associative") { SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = - make_series_split( - make_series_split( - make_leaf(n1), - make_leaf(n2)), - make_leaf(n3)); + LeafOnlyBinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); bool correct = false; @@ -97,12 +81,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = - make_parallel_split( - make_parallel_split( - make_leaf(n1), - make_leaf(n2)), - make_leaf(n3)); + LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); bool correct = false; From 4b180df1f5eba1b1cac18c39b6cba6a99c7179cd Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 5 Oct 2024 11:33:22 -0700 Subject: [PATCH 25/29] Move graph_optimize_state.cc to correct location --- .../src/compiler/graph_optimize_state.cc | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 lib/compiler/src/compiler/graph_optimize_state.cc diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc new file mode 100644 index 0000000000..4b4f323ea4 --- /dev/null +++ b/lib/compiler/src/compiler/graph_optimize_state.cc @@ -0,0 +1,85 @@ +#include "compiler/graph_optimize_state.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +GraphOptimizeState::GraphOptimizeState( + GraphOptimizeResult const 
&graph_optimize_result, float runtime) + : graph_optimize_result(graph_optimize_result), runtime(runtime) {} + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + // Note(@wmdi): This is a hack to implement a partially correct homomorphism + // check. Switch to the homomorphism check used in substitutions right after + // https://github.com/flexflow/FlexFlow/pull/1471 is merged. + auto layers1 = topological_ordering(graph_optimize_result.pcg); + auto layers2 = topological_ordering(other.graph_optimize_result.pcg); + if (layers1.size() != layers2.size()) { + return false; + } + std::unordered_map mapping; + for (size_t i = 0; i < layers1.size(); ++i) { + if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != + get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { + return false; + } + auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); + auto inputs2 = + get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); + if (inputs1.size() != inputs2.size()) { + return false; + } + for (size_t j = 0; j < inputs1.size(); ++j) { + if (inputs1[j] != mapping.at(inputs2[j])) { + return false; + } + } + auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); + auto outputs2 = + get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); + if (outputs1.size() != outputs2.size()) { + return false; + } + for (size_t j = 0; j < outputs1.size(); ++j) { + mapping.emplace(outputs2[j], outputs1[j]); + } + } + return true; +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime < other.runtime; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph 
hash like + // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash + size_t seed = 0; + auto layers = topological_ordering(state.graph_optimize_result.pcg); + ::FlexFlow::hash_combine(seed, layers.size()); + for (auto layer : layers) { + ::FlexFlow::hash_combine( + seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); + auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); + ::FlexFlow::hash_combine(seed, inputs.size()); + for (auto input : inputs) { + for (size_t i = 0; i < layers.size(); ++i) { + if (get_source_layer(input) == layers[i]) { + ::FlexFlow::hash_combine(seed, i); + break; + } + } + } + } + return seed; +} + +} // namespace std From dcd2e130851a7ee7221666f99534961ef99a5425 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 7 Oct 2024 16:20:12 -0700 Subject: [PATCH 26/29] Further code simplification and polishing --- .../export_model_arch/json_sp_model_export.h | 17 ++ .../json_sp_model_export.struct.toml | 10 +- .../src/export_model_arch.cc | 4 +- .../export_model_arch/json_sp_model_export.cc | 20 ++ flake.lock | 6 +- ...tracted_tensor_set_movement_across_split.h | 2 +- .../get_tensor_set_movement_across_split.h | 2 +- .../machine_mapping_constraints.h | 2 +- .../get_machine_mapping_problem_tree.h | 2 +- .../machine_mapping_problem_tree.h | 44 +--- .../machine_mapping_problem_tree.struct.toml | 18 -- .../machine_mapping_problem_tree.variant.toml | 25 ++ .../mm_problem_tree_parallel_split.h | 14 -- ...mm_problem_tree_parallel_split.struct.toml | 23 +- ...blem_tree_parallel_split_label.struct.toml | 11 - .../mm_problem_tree_series_split.h | 16 -- .../mm_problem_tree_series_split.struct.toml | 27 +- ...roblem_tree_series_split_label.struct.toml | 15 -- .../machine_mapping/transitive_reduced_pcg.h | 2 +- ...on_graph_binary_parallel_split.struct.toml | 27 ++ 
...tion_graph_binary_series_split.struct.toml | 27 ++ ...omputation_graph_binary_sp_decomposition.h | 15 +- ...graph_binary_sp_decomposition.variant.toml | 25 ++ ...tion_graph_series_parallel_decomposition.h | 0 ..._graph_binary_sp_decomposition.struct.toml | 16 -- ...get_pcg_balanced_binary_sp_decomposition.h | 0 .../get_pcg_series_parallel_decomposition.h | 0 .../pcg/pcg_binary_parallel_split.h | 13 + .../pcg/pcg_binary_parallel_split.struct.toml | 27 ++ .../pcg/pcg_binary_series_split.h | 13 + .../pcg/pcg_binary_series_split.struct.toml | 27 ++ .../pcg/pcg_binary_sp_decomposition.h | 37 +++ .../pcg_binary_sp_decomposition.variant.toml | 25 ++ .../pcg_binary_parallel_split.h | 14 -- .../pcg_binary_parallel_split.struct.toml | 16 -- .../series_parallel/pcg_binary_series_split.h | 17 -- .../pcg_binary_series_split.struct.toml | 16 -- .../pcg_binary_sp_decomposition.h | 57 ----- .../pcg_binary_sp_decomposition.struct.toml | 16 -- ...racted_tensor_set_movement_across_split.cc | 11 +- .../get_optimal_machine_mapping.cc | 37 ++- .../get_machine_mapping_problem_tree.cc | 36 +-- .../machine_mapping_problem_tree.cc | 120 ++++----- .../mm_problem_tree_parallel_split.cc | 19 -- .../mm_problem_tree_series_split.cc | 22 -- .../machine_mapping/transitive_reduced_pcg.cc | 10 +- ...mputation_graph_binary_sp_decomposition.cc | 145 +++++++++++ ...ion_graph_series_parallel_decomposition.cc | 2 +- ...mputation_graph_binary_sp_decomposition.cc | 80 ------ .../get_pcg_series_parallel_decomposition.cc | 2 +- .../pcg/pcg_binary_sp_decomposition.cc | 93 +++++++ .../pcg_binary_parallel_split.cc | 19 -- .../pcg_binary_series_split.cc | 29 --- .../pcg_binary_sp_decomposition.cc | 81 ------ ...racted_tensor_set_movement_across_split.cc | 72 +++--- .../get_optimal_machine_mapping.cc | 35 ++- .../get_machine_mapping_problem_tree.cc | 82 +++++-- ...ion_graph_series_parallel_decomposition.cc | 34 +-- .../pcg/initializer_attrs.variant.toml | 2 +- .../include/utils/archetypes/value_type.h | 36 +++ 
lib/utils/include/utils/fmt/json.h | 21 ++ .../full_binary_tree/find_paths_to_leaf.h | 46 ++-- .../include/utils/full_binary_tree/fmt.h | 47 ---- .../utils/full_binary_tree/full_binary_tree.h | 106 -------- .../full_binary_tree.variant.toml | 24 ++ ...ull_binary_tree_implementation.struct.toml | 33 +++ .../full_binary_tree_parent_node.struct.toml | 34 +++ .../full_binary_tree_visitor.struct.toml | 10 +- .../full_binary_tree/get_all_leaf_paths.h | 41 ++-- .../utils/full_binary_tree/get_child.h | 16 +- .../utils/full_binary_tree/get_label.h | 16 -- .../utils/full_binary_tree/get_leaves.h | 36 +-- .../utils/full_binary_tree/get_left_child.h | 16 -- .../utils/full_binary_tree/get_node_type.h | 18 +- .../full_binary_tree/get_num_tree_nodes.h | 26 ++ .../utils/full_binary_tree/get_right_child.h | 16 -- .../full_binary_tree/get_subtree_at_path.h | 34 +-- .../include/utils/full_binary_tree/hash.h | 29 --- .../include/utils/full_binary_tree/make.h | 32 --- .../include/utils/full_binary_tree/require.h | 7 +- .../include/utils/full_binary_tree/visit.h | 29 +-- .../binary_parallel_split.h | 14 -- .../binary_parallel_split.struct.toml | 19 +- .../binary_series_split.h | 14 -- .../binary_series_split.struct.toml | 19 +- .../binary_sp_decomposition_tree.h | 38 +-- .../binary_sp_decomposition_tree.struct.toml | 16 -- .../binary_sp_decomposition_tree.variant.toml | 25 ++ .../find_paths_to_leaf.h | 16 +- .../generic_binary_parallel_split.struct.toml | 29 ++- ...ic_binary_parallel_split_label.struct.toml | 16 -- .../generic_binary_series_split.struct.toml | 28 ++- ...eric_binary_series_split_label.struct.toml | 16 -- ...c_binary_sp_decomposition_tree.struct.toml | 29 --- ..._binary_sp_decomposition_tree.variant.toml | 31 +++ ...ary_sp_decomposition_tree_implementation.h | 63 +++++ ...omposition_tree_implementation.struct.toml | 47 ++++ ...osition_tree_transform_visitor.struct.toml | 28 +++ ..._sp_decomposition_tree_visitor.struct.toml | 21 +- .../generic_binary_sp_split_label.h | 59 
----- ...generic_binary_sp_split_label.variant.toml | 26 -- .../get_all_leaf_paths.h | 15 +- .../get_leaves.h | 49 +--- .../get_left_child.h | 28 --- .../get_node_type.h | 29 +-- .../get_num_tree_nodes.h | 44 +--- .../get_right_child.h | 28 --- .../get_subtree_at_path.h | 32 +-- .../is_binary_sp_tree_left_associative.h | 49 ++-- .../is_binary_sp_tree_right_associative.h | 48 ++-- .../make.h | 65 ----- .../require.h | 47 +--- .../transform.h | 230 +++++++++--------- .../visit.h | 30 +-- .../wrap.h | 48 ---- .../get_leaves.h | 24 +- .../get_node_type.h | 24 +- .../is_binary_sp_tree_left_associative.h | 24 +- .../is_binary_sp_tree_right_associative.h | 24 +- .../json.h | 78 ++++++ .../leaf_only_binary_parallel_split.h | 23 -- ...eaf_only_binary_parallel_split.struct.toml | 17 +- .../leaf_only_binary_series_split.h | 23 -- .../leaf_only_binary_series_split.struct.toml | 16 +- ...y_binary_sp_decomposition_tree.struct.toml | 24 -- ..._binary_sp_decomposition_tree.variant.toml | 29 +++ ...osition_tree_transform_visitor.struct.toml | 16 ++ ...nly_binary_sp_decomposition_tree_visitor.h | 35 +++ ..._sp_decomposition_tree_visitor.struct.toml | 14 +- .../make.h | 42 ---- .../require.h | 72 +++--- .../transform.h | 91 ++++--- .../wrap.h | 40 --- .../parallel_split.struct.toml | 32 +++ .../series_parallel/series_parallel_splits.h | 136 +++++------ .../series_parallel/series_split.struct.toml | 31 +++ .../utils/json/check_is_json_deserializable.h | 14 ++ .../utils/json/check_is_json_serializable.h | 14 ++ lib/utils/src/utils/archetypes/value_type.cc | 8 + lib/utils/src/utils/fmt/json.cc | 8 + .../full_binary_tree/find_paths_to_leaf.cc | 14 ++ lib/utils/src/utils/full_binary_tree/fmt.cc | 12 - .../full_binary_tree/get_all_leaf_paths.cc | 11 + .../src/utils/full_binary_tree/get_child.cc | 15 ++ .../src/utils/full_binary_tree/get_label.cc | 8 - .../src/utils/full_binary_tree/get_leaves.cc | 15 ++ .../full_binary_tree/get_num_tree_nodes.cc | 14 ++ 
.../full_binary_tree/get_subtree_at_path.cc | 16 ++ lib/utils/src/utils/full_binary_tree/make.cc | 11 - lib/utils/src/utils/full_binary_tree/visit.cc | 6 +- ...t_transitive_reduced_edges_across_split.cc | 5 +- .../binary_parallel_split.cc | 18 -- .../binary_series_split.cc | 18 -- .../binary_sp_decomposition_tree.cc | 67 ++--- .../find_paths_to_leaf.cc | 14 +- ...ry_sp_decomposition_tree_implementation.cc | 14 ++ .../generic_binary_sp_split_label.cc | 16 -- .../get_all_leaf_paths.cc | 12 +- .../get_leaves.cc | 16 +- .../get_left_child.cc | 10 - .../get_num_tree_nodes.cc | 15 +- .../get_right_child.cc | 10 - .../get_subtree_at_path.cc | 14 +- .../is_binary_sp_tree_left_associative.cc | 12 +- .../is_binary_sp_tree_right_associative.cc | 12 +- .../make.cc | 18 -- .../require.cc | 29 ++- .../transform.cc | 16 +- .../visit.cc | 17 ++ .../wrap.cc | 10 - .../get_leaves.cc | 4 +- .../get_node_type.cc | 4 +- .../is_binary_sp_tree_left_associative.cc | 4 +- .../is_binary_sp_tree_right_associative.cc | 4 +- .../json.cc | 8 + .../leaf_only_binary_parallel_split.cc | 10 - .../leaf_only_binary_series_split.cc | 10 - ...ly_binary_sp_decomposition_tree_visitor.cc | 9 + .../make.cc | 16 -- .../require.cc | 14 +- .../transform.cc | 21 +- .../wrap.cc | 10 - ...ft_associative_binary_sp_tree_from_nary.cc | 16 +- ...ht_associative_binary_sp_tree_from_nary.cc | 16 +- .../get_series_parallel_decomposition.cc | 18 +- .../intermediate_sp_decomposition_tree.cc | 14 +- .../series_parallel_decomposition.cc | 2 +- .../series_parallel/series_parallel_splits.cc | 168 ++++++------- .../json/check_is_json_deserializable.cc | 1 + .../utils/json/check_is_json_serializable.cc | 1 + ...sitive_reduced_boundary_nodes_for_split.cc | 15 +- ...t_transitive_reduced_edges_across_split.cc | 32 ++- ...transitive_reduced_outputs_across_split.cc | 15 +- .../get_leaves.cc | 195 +++++++++------ .../get_num_tree_nodes.cc | 177 ++++++++------ .../is_binary_sp_tree_left_associative.cc | 53 ++-- 
.../is_binary_sp_tree_right_associative.cc | 51 ++-- ...ft_associative_binary_sp_tree_from_nary.cc | 34 ++- .../nary_sp_tree_from_binary.cc | 72 +++--- ...ht_associative_binary_sp_tree_from_nary.cc | 34 ++- .../get_series_parallel_decomposition.cc | 48 ++-- 201 files changed, 2960 insertions(+), 2988 deletions(-) create mode 100644 bin/export-model-arch/include/export_model_arch/json_sp_model_export.h create mode 100644 bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h delete mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml rename lib/compiler/include/compiler/series_parallel/{ => computation_graph}/computation_graph_binary_sp_decomposition.h (58%) create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml rename lib/compiler/include/compiler/series_parallel/{ => computation_graph}/get_computation_graph_series_parallel_decomposition.h (100%) delete mode 100644 
lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml rename lib/compiler/include/compiler/series_parallel/{ => pcg}/get_pcg_balanced_binary_sp_decomposition.h (100%) rename lib/compiler/include/compiler/series_parallel/{ => pcg}/get_pcg_series_parallel_decomposition.h (100%) create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml delete mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h delete mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml delete mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h delete mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml delete mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h delete mode 100644 lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc delete mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc create mode 100644 lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc rename lib/compiler/src/compiler/series_parallel/{ => 
computation_graph}/get_computation_graph_series_parallel_decomposition.cc (97%) delete mode 100644 lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc rename lib/compiler/src/compiler/series_parallel/{ => pcg}/get_pcg_series_parallel_decomposition.cc (70%) create mode 100644 lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc delete mode 100644 lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc delete mode 100644 lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc delete mode 100644 lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc rename lib/compiler/test/src/compiler/series_parallel/{ => computation_graph}/get_computation_graph_series_parallel_decomposition.cc (96%) create mode 100644 lib/utils/include/utils/archetypes/value_type.h create mode 100644 lib/utils/include/utils/fmt/json.h delete mode 100644 lib/utils/include/utils/full_binary_tree/fmt.h delete mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml delete mode 100644 lib/utils/include/utils/full_binary_tree/get_label.h delete mode 100644 lib/utils/include/utils/full_binary_tree/get_left_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h delete mode 100644 lib/utils/include/utils/full_binary_tree/get_right_child.h delete mode 100644 lib/utils/include/utils/full_binary_tree/hash.h delete mode 100644 lib/utils/include/utils/full_binary_tree/make.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h delete mode 100644 
lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml 
delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h delete mode 100644 
lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h create mode 100644 lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/series_split.struct.toml create mode 100644 lib/utils/include/utils/json/check_is_json_deserializable.h create mode 100644 lib/utils/include/utils/json/check_is_json_serializable.h create mode 100644 lib/utils/src/utils/archetypes/value_type.cc create mode 100644 lib/utils/src/utils/fmt/json.cc create mode 100644 lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc delete mode 100644 lib/utils/src/utils/full_binary_tree/fmt.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_child.cc delete mode 100644 lib/utils/src/utils/full_binary_tree/get_label.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_leaves.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc delete mode 100644 lib/utils/src/utils/full_binary_tree/make.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc delete mode 100644 
lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc create mode 100644 lib/utils/src/utils/json/check_is_json_deserializable.cc create mode 100644 lib/utils/src/utils/json/check_is_json_serializable.cc rename lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/{leaf_only_binary_sp_decomposition_tree => generic_binary_sp_decomposition_tree}/is_binary_sp_tree_left_associative.cc (58%) rename lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/{leaf_only_binary_sp_decomposition_tree => generic_binary_sp_decomposition_tree}/is_binary_sp_tree_right_associative.cc (60%) diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.h 
b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.h new file mode 100644 index 0000000000..df4e140b99 --- /dev/null +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_BIN_EXPORT_MODEL_ARCH_INCLUDE_EXPORT_MODEL_ARCH_JSON_SP_MODEL_EXPORT_H +#define _FLEXFLOW_BIN_EXPORT_MODEL_ARCH_INCLUDE_EXPORT_MODEL_ARCH_JSON_SP_MODEL_EXPORT_H + +#include +#include "export_model_arch/json_sp_model_export.dtg.h" + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::JsonSPModelExport> { + static ::FlexFlow::JsonSPModelExport from_json(json const &); + static void to_json(json &, ::FlexFlow::JsonSPModelExport const &); +}; + +} // namespace nlohmann + +#endif diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml index efaf368bc8..3c08bc150e 100644 --- a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -3,24 +3,22 @@ name = "JsonSPModelExport" features = [ "eq", "hash", - "json", "fmt", + "json", ] includes = [ "pcg/file_format/v1/v1_computation_graph.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h", + 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h", ] [[fields]] name = "sp_decomposition" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" [[fields]] name = "computation_graph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 1c2dfd6ea3..0ff4f47a40 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,5 +1,5 @@ -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" diff --git a/bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc b/bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc new file mode 100644 index 0000000000..ca8d76d803 --- /dev/null +++ b/bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc @@ -0,0 +1,20 @@ +#include "export_model_arch/json_sp_model_export.h" + +using namespace ::FlexFlow; + +namespace nlohmann { + +JsonSPModelExport adl_serializer::from_json(json const &j) { + NOT_IMPLEMENTED(); +} + +static void sp_decomposition_to_json(json &j, LeafOnlyBinarySPDecompositionTree const &t) { +} + +void adl_serializer::to_json(json &j, JsonSPModelExport const &m) { + j["computation_graph"] = m.computation_graph; + sp_decomposition_to_json(j["sp_decomposition"], m.sp_decomposition); +} + + +} // namespace nlohmann diff --git a/flake.lock b/flake.lock index c5f86c613c..87fae7f446 100644 --- a/flake.lock +++ b/flake.lock 
@@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1727727609, - "narHash": "sha256-BSnh4wZV7LLXDQ4YIhCHz/uJ4N88vv5cBb1LKWJlltM=", + "lastModified": 1728341842, + "narHash": "sha256-XMS52KBSS6z3k2VaiVcHyZQD6b2QUm1wIvTClel4xwg=", "owner": "lockshaw", "repo": "proj", - "rev": "e17da953eaea9e728e9dfde9c12a2435122253b1", + "rev": "830fb5b1a0c7087752693990e90bbbf021168dfe", "type": "github" }, "original": { diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h index 8390c5b9cb..8567a7a3e6 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -3,7 +3,7 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" -#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h index ee3d2bf159..2aed9a20e4 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -4,7 +4,7 @@ #include "compiler/cost_estimator/tensor_set_movement.dtg.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" -#include 
"compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h index 87c556910f..d314ab493b 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -5,7 +5,7 @@ #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h index 2635f4a318..68d02aaa54 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_view.dtg.h" #include 
"pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 4064b2f0c9..2eccd36719 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -5,32 +5,16 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" namespace FlexFlow { -MachineMappingProblemTree mm_problem_tree_make_series_split( - AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &pre, - MachineMappingProblemTree const &post); -MachineMappingProblemTree - mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs); -MachineMappingProblemTree - mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &); +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree(); SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); -MMProblemTreeSeriesSplit - require_series_split(MachineMappingProblemTree const &); -MMProblemTreeParallelSplit - require_parallel_split(MachineMappingProblemTree const &); -UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &); - -MachineMappingProblemTree 
wrap_series_split(MMProblemTreeSeriesSplit const &); -MachineMappingProblemTree - wrap_parallel_split(MMProblemTreeParallelSplit const &); - std::unordered_multiset get_leaves(MachineMappingProblemTree const &); std::unordered_set @@ -40,28 +24,6 @@ std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, BinaryTreePath const &); -template -Result visit(MachineMappingProblemTree const &t, F &&f) { - SPDecompositionTreeNodeType node_type = get_node_type(t); - switch (node_type) { - case SPDecompositionTreeNodeType::SERIES: { - Result result = f(require_series_split(t)); - return result; - } - case SPDecompositionTreeNodeType::PARALLEL: { - Result result = f(require_parallel_split(t)); - return result; - } - case SPDecompositionTreeNodeType::NODE: { - Result result = f(require_leaf(t)); - return result; - } - default: - throw mk_runtime_error( - fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); - } -} - } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml deleted file mode 100644 index e322133768..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "MachineMappingProblemTree" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", - 
"compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml new file mode 100644 index 0000000000..1949f143cb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::MMProblemTreeSeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::MMProblemTreeParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::UnmappedOpCostEstimateKey" +key = "leaf" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h deleted file mode 100644 index 63e724fa94..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_MM_PROBLEM_TREE_PARALLEL_SPLIT_H -#define 
_FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_MM_PROBLEM_TREE_PARALLEL_SPLIT_H - -#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" - -namespace FlexFlow { - -MachineMappingProblemTree get_lhs_child(MMProblemTreeParallelSplit const &); -MachineMappingProblemTree get_rhs_child(MMProblemTreeParallelSplit const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml index b277ca44bd..5247b2006a 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml @@ -6,13 +6,22 @@ features = [ "fmt", ] -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +fwd_decls = [ + "struct MachineMappingProblemTree", ] +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + [[fields]] -name = "raw_split" -type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, 
::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml deleted file mode 100644 index 367ffb399f..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml +++ /dev/null @@ -1,11 +0,0 @@ -namespace = "FlexFlow" -name = "MMProblemTreeParallelSplitLabel" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [] - -fields = [] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h deleted file mode 100644 index a7faced4d8..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H - -#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" - -namespace FlexFlow { - -MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &); -MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &); -AbstractedTensorSetMovement const & - get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &); - -} // namespace FlexFlow - -#endif diff --git 
a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml index 299114862c..d4f61bb3f5 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml @@ -6,13 +6,28 @@ features = [ "fmt", ] +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", - "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", ] [[fields]] -name = "raw_split" -type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml deleted file mode 100644 index 0887d67b49..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "MMProblemTreeSeriesSplitLabel" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", -] - -[[fields]] -name = "tensor_set_movement" -type = "::FlexFlow::AbstractedTensorSetMovement" diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h index 60c47ba049..2b2bc9bf84 100644 --- a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -3,7 +3,7 @@ #include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h" #include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" -#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..9654a2546e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = 
"ComputationGraphBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml new file mode 100644 index 0000000000..aa66c80b43 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h similarity index 58% rename from lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h rename to lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h index b855fbff07..ea6723a9cd 100644 --- 
a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -1,15 +1,26 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h" #include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" namespace FlexFlow { +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> generic_impl_for_computation_graph_sp_tree(); + SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &); -layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); + +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &); + std::optional get_computation_graph_left_assoc_binary_sp_decomposition( ComputationGraph const &); diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..452470620b --- /dev/null +++ 
b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::ComputationGraphBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::ComputationGraphBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h similarity index 100% rename from lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h rename to lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml deleted file mode 100644 index 2e6bb0b611..0000000000 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphBinarySPDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", -] - -[[fields]] -name = "raw_tree" -type = 
"::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h similarity index 100% rename from lib/compiler/include/compiler/series_parallel/get_pcg_balanced_binary_sp_decomposition.h rename to lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h diff --git a/lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h similarity index 100% rename from lib/compiler/include/compiler/series_parallel/get_pcg_series_parallel_decomposition.h rename to lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h new file mode 100644 index 0000000000..0ac8cee95b --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml new file mode 100644 index 
0000000000..f7f7026716 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h new file mode 100644 index 0000000000..d0bda09229 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml new file mode 100644 index 0000000000..af2c8c4dae --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + 
"compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h new file mode 100644 index 0000000000..e8c02ebfb5 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t> generic_impl_for_pcg_sp_tree(); + +BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); +std::unordered_multiset + 
get_parallel_layers(PCGBinarySPDecomposition const &); + +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &, + parallel_layer_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..52372fb270 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "PCGBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::PCGBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::PCGBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::parallel_layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h deleted file mode 100644 index 0bbc81ded3..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H - -#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" - -namespace FlexFlow { - -PCGBinarySPDecomposition get_left_child(PCGBinaryParallelSplit const &); -PCGBinarySPDecomposition 
get_right_child(PCGBinaryParallelSplit const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml deleted file mode 100644 index f7d80138c5..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PCGBinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", -] - -[[fields]] -name = "raw_split" -type = "::FlexFlow::LeafOnlyBinaryParallelSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h deleted file mode 100644 index 386bfee4f4..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H - -#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" - -namespace FlexFlow { - -BinarySeriesSplit get_raw_graph_series_split(PCGBinarySeriesSplit const &); - -PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &); -PCGBinarySPDecomposition get_right_child(PCGBinarySeriesSplit const &); - -} // namespace FlexFlow - -#endif diff --git 
a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml deleted file mode 100644 index 184b272c55..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PCGBinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", -] - -[[fields]] -name = "raw_split" -type = "::FlexFlow::LeafOnlyBinarySeriesSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h deleted file mode 100644 index e2e170b4d5..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H - -#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" -#include - -namespace FlexFlow { - -std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); -std::unordered_multiset - get_parallel_layers(PCGBinarySPDecomposition const &); - -SPDecompositionTreeNodeType 
get_node_type(PCGBinarySPDecomposition const &); - -PCGBinarySPDecomposition - make_pcg_series_split(PCGBinarySPDecomposition const &, - PCGBinarySPDecomposition const &); -PCGBinarySPDecomposition - make_pcg_parallel_split(PCGBinarySPDecomposition const &, - PCGBinarySPDecomposition const &); -PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &); - -PCGBinarySPDecomposition wrap_series_split(PCGBinarySeriesSplit const &); -PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &); - -PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &); -PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &); -parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &); - -std::unordered_set - find_paths_to_leaf(PCGBinarySPDecomposition const &, - parallel_layer_guid_t const &); - -template -ReturnType visit(PCGBinarySPDecomposition const &d, F &&f) { - SPDecompositionTreeNodeType node_type = get_node_type(d); - switch (node_type) { - case SPDecompositionTreeNodeType::SERIES: - return f(require_series(d)); - case SPDecompositionTreeNodeType::PARALLEL: - return f(require_parallel(d)); - case SPDecompositionTreeNodeType::NODE: - return f(require_leaf(d)); - default: - throw mk_runtime_error(fmt::format("Unknown node type {}", node_type)); - } -} - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml deleted file mode 100644 index bead04b307..0000000000 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PCGBinarySPDecomposition" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 56fafbc5e3..53b8d5bdd6 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,7 +1,6 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_series_split.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" @@ -17,7 +16,9 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( std::unordered_set edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { + auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) + -> AbstractedSingleTensorMovement + { std::unordered_set tensor_edges = filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; @@ -39,13 +40,13 @@ 
AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( transform(src_layers, [&](parallel_layer_guid_t const &l) { return get_only( - find_paths_to_leaf(get_left_child(split), l)); + find_paths_to_leaf(split.get_left_child(), l)); }), /*dst_machine_views=*/ transform(dst_layers, [&](parallel_layer_guid_t const &l) { return get_only( - find_paths_to_leaf(get_right_child(split), l)); + find_paths_to_leaf(split.get_right_child(), l)); }), }; }; diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 20da56eb55..bf44ef0fd7 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -4,13 +4,11 @@ #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h" -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/machine_mapping_result.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/machine_specification.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.dtg.h" @@ -47,9 +45,7 @@ MachineMappingResult } } - MachineMappingResult result = visit( - problem_tree, - overload{ + 
MachineMappingResult result = problem_tree.visit(overload{ [&](MMProblemTreeSeriesSplit const &series_split) { return get_optimal_machine_mapping( result_cache, @@ -89,9 +85,9 @@ MachineMappingResult boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedOpCostEstimateKey leaf = - require_leaf(mm_problem_tree_get_subtree_at_path( - wrap_series_split(series_split), l) - .value()); + mm_problem_tree_get_subtree_at_path( + MachineMappingProblemTree{series_split}, l) + .value().get(); return context.allowed_machine_views(leaf, resources); }); return transform( @@ -110,7 +106,7 @@ MachineMappingResult MachineMappingResult pre_result = get_optimal_machine_mapping(result_cache, context, - get_pre_child(series_split), + series_split.get_left_child(), resources, pre_candidate); @@ -126,7 +122,7 @@ MachineMappingResult MachineMappingResult post_result = get_optimal_machine_mapping(result_cache, context, - get_post_child(series_split), + series_split.get_right_child(), resources, post_candidate); @@ -134,8 +130,7 @@ MachineMappingResult }; MachineMappingResult result = infeasible_machine_mapping_result(); - AbstractedTensorSetMovement tensor_movement = - get_abstracted_tensor_movement(series_split); + AbstractedTensorSetMovement tensor_movement = series_split.tensor_set_movement; for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : @@ -178,15 +173,15 @@ MachineMappingResult get_optimal_machine_mapping( MachineSpecification const &resources, MachineMappingConstraints const &constraints) { - MachineMappingProblemTree lhs = get_lhs_child(parallel_split); - MachineMappingProblemTree rhs = get_rhs_child(parallel_split); + MachineMappingProblemTree lhs = parallel_split.get_left_child(); + MachineMappingProblemTree rhs = parallel_split.get_right_child(); MachineMappingResult series_result = [&] { - MMProblemTreeSeriesSplit series_split = - require_series_split(mm_problem_tree_make_series_split( - 
/*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), - /*pre=*/lhs, - /*post=*/rhs)); + MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, + }; return get_optimal_machine_mapping(result_cache, context, diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 42f1cd3809..ada271580f 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -3,10 +3,7 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_parallel_split.h" -#include "compiler/series_parallel/pcg_binary_series_split.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/overload.h" @@ -22,26 +19,33 @@ MachineMappingProblemTree get_machine_mapping_problem_tree( to_problem_tree = [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { - return visit( - sp, - overload{[&](PCGBinarySeriesSplit const &series) { + return sp.visit(overload{ + [&](PCGBinarySeriesSplit const &series) { AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_set_movement_across_split(tr_pcg, series); - return 
mm_problem_tree_make_series_split( + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ /*tensor_set_movement=*/tensor_movement, - /*lhs=*/to_problem_tree(get_left_child(series)), - /*rhs=*/to_problem_tree(get_right_child(series))); + /*lhs=*/to_problem_tree(series.get_left_child()), + /*rhs=*/to_problem_tree(series.get_right_child()), + }, + }; }, [&](PCGBinaryParallelSplit const ¶llel) { - return mm_problem_tree_make_parallel_split( - to_problem_tree(get_left_child(parallel)), - to_problem_tree(get_right_child(parallel))); + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + to_problem_tree(parallel.get_left_child()), + to_problem_tree(parallel.get_right_child()), + }, + }; }, [&](parallel_layer_guid_t const &leaf) { - return mm_problem_tree_make_leaf( - get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf)); - }}); + return MachineMappingProblemTree{ + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), + }; + }, + }); }; return to_problem_tree(sp_decomposition_tree); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 992a73db03..a5b3cab43e 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -1,105 +1,67 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" namespace FlexFlow { -MachineMappingProblemTree mm_problem_tree_make_series_split( - AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - make_generic_binary_series_split( - MMProblemTreeSeriesSplitLabel{tensor_set_movement}, - lhs.raw_tree, - rhs.raw_tree), - }; -} - -MachineMappingProblemTree - mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - make_generic_binary_parallel_split( - MMProblemTreeParallelSplitLabel{}, lhs.raw_tree, rhs.raw_tree), - }; -} - -MachineMappingProblemTree - mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &leaf_label) { - return MachineMappingProblemTree{ - make_generic_binary_sp_leaf(leaf_label), +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree() { + return GenericBinarySPDecompositionTreeImplementation< + MachineMappingProblemTree, + MMProblemTreeSeriesSplit, + MMProblemTreeParallelSplit, + UnmappedOpCostEstimateKey>{ + /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree 
const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/[](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeSeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/[](MachineMappingProblemTree const &tree) -> UnmappedOpCostEstimateKey const & { + return tree.get(); + }, }; } SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { - return get_node_type(tree.raw_tree); -} - -MMProblemTreeSeriesSplit - require_series_split(MachineMappingProblemTree const &t) { - return MMProblemTreeSeriesSplit{ - require_generic_binary_series_split(t.raw_tree), - }; -} - -MMProblemTreeParallelSplit - require_parallel_split(MachineMappingProblemTree const &t) { - return MMProblemTreeParallelSplit{ - require_generic_binary_parallel_split(t.raw_tree), - }; -} - -UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) { - return require_generic_binary_leaf(t.raw_tree); -} - -MachineMappingProblemTree - wrap_series_split(MMProblemTreeSeriesSplit const &series) { - return MachineMappingProblemTree{ - wrap_series_split(series.raw_split), - }; -} - -MachineMappingProblemTree - wrap_parallel_split(MMProblemTreeParallelSplit const ¶llel) { - return MachineMappingProblemTree{ - wrap_parallel_split(parallel.raw_split), - }; + return tree.visit(overload { + [](MMProblemTreeSeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, + [](MMProblemTreeParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, + [](UnmappedOpCostEstimateKey const &) { return SPDecompositionTreeNodeType::NODE; 
}, + }); } std::unordered_multiset - get_leaves(MachineMappingProblemTree const &t) { - return get_leaves(t.raw_tree); + get_leaves(MachineMappingProblemTree const &tree) { + return get_leaves(tree, generic_binary_sp_impl_for_mm_problem_tree()); } std::unordered_set - get_all_leaf_paths(MachineMappingProblemTree const &t) { - return get_all_leaf_paths(t.raw_tree); + get_all_leaf_paths(MachineMappingProblemTree const &tree) { + return get_all_leaf_paths(tree, generic_binary_sp_impl_for_mm_problem_tree()); } std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, BinaryTreePath const &path) { - std::optional< - GenericBinarySPDecompositionTree> - raw_subtree = get_subtree_at_path(tree.raw_tree, path); - - if (!raw_subtree.has_value()) { - return std::nullopt; - } else { - return MachineMappingProblemTree{raw_subtree.value()}; - } + return get_subtree_at_path(tree, generic_binary_sp_impl_for_mm_problem_tree(), path); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc deleted file mode 100644 index e31613ee25..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" - -namespace FlexFlow { - -MachineMappingProblemTree get_lhs_child(MMProblemTreeParallelSplit const &p) { - return MachineMappingProblemTree{ - get_left_child(p.raw_split), - }; -} - -MachineMappingProblemTree 
get_rhs_child(MMProblemTreeParallelSplit const &p) { - return MachineMappingProblemTree{ - get_right_child(p.raw_split), - }; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc deleted file mode 100644 index ac67baaf47..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.cc +++ /dev/null @@ -1,22 +0,0 @@ -#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" - -namespace FlexFlow { - -MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &s) { - return MachineMappingProblemTree{ - s.raw_split.pre, - }; -} - -MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &s) { - return MachineMappingProblemTree{ - s.raw_split.post, - }; -} - -AbstractedTensorSetMovement const & - get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &s) { - return s.raw_split.label.tensor_set_movement; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index 618d93e9f2..004aca6a81 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -1,6 +1,6 @@ #include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_series_split.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" 
#include "utils/containers/flatmap.h" @@ -41,7 +41,7 @@ std::unordered_set TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = get_raw_graph_series_split(split); + BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); std::unordered_set raw_edges = get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); @@ -57,7 +57,7 @@ std::unordered_set TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = get_raw_graph_series_split(split); + BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); std::unordered_set raw_outputs = get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); @@ -72,7 +72,7 @@ PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = get_raw_graph_series_split(split); + BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); SplitBoundaryNodes raw_boundary = get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc new file mode 100644 index 0000000000..6d6e63429b --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,145 @@ +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> generic_impl_for_computation_graph_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t>{ + /*series_get_left_child=*/[](ComputationGraphBinarySeriesSplit const &split) -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/[](ComputationGraphBinaryParallelSplit const &split) -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/[](ComputationGraphBinarySeriesSplit const &split) -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/[](ComputationGraphBinaryParallelSplit const &split) -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/[](ComputationGraphBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/[](ComputationGraphBinarySPDecomposition const 
&tree) -> ComputationGraphBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/[](ComputationGraphBinarySPDecomposition const &tree) -> ComputationGraphBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/[](ComputationGraphBinarySPDecomposition const &tree) -> layer_guid_t const & { + return tree.get(); + }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &tree) { + return tree.visit(overload { + [](ComputationGraphBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](ComputationGraphBinaryParallelSplit const ¶llel) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](layer_guid_t const &leaf) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &tree) { + return tree.get(); +} + +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &bin) { + return bin.visit(overload { + [](BinarySeriesSplit const &series) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinarySeriesSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp(series.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp(series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinaryParallelSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp(parallel.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp(parallel.get_right_child()), + }, + }; + }, + [](Node const &node) { + return ComputationGraphBinarySPDecomposition{ + layer_guid_t{node}, + }; + }, + }); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + 
get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_left_associative(tree, generic_impl_for_computation_graph_sp_tree()); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_right_associative(tree, generic_impl_for_computation_graph_sp_tree()); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc similarity index 97% rename from lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc rename to lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 184ad93f4d..8f78d423b3 100644 --- 
a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "op-attrs/computation_graph_op_attrs.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph/computation_graph_edge.h" diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc deleted file mode 100644 index 25fda37c1e..0000000000 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" - -namespace FlexFlow { - -SPDecompositionTreeNodeType - get_node_type(ComputationGraphBinarySPDecomposition const &d) { - return get_node_type(d.raw_tree); -} - -layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { - return require_leaf_only_binary_leaf(d.raw_tree); -} - -std::optional - get_computation_graph_left_assoc_binary_sp_decomposition( - ComputationGraph const &cg) { - SeriesParallelDecomposition sp_decomposition = ({ - std::optional result = - get_computation_graph_series_parallel_decomposition(cg); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - BinarySPDecompositionTree raw_binary_tree = - left_associative_binary_sp_tree_from_nary(sp_decomposition); - - auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ - [](Node const &n) { return layer_guid_t{n}; }, - }; - return ComputationGraphBinarySPDecomposition{ - transform(raw_binary_tree.raw_tree, visitor)}; -} - -std::optional - get_computation_graph_right_assoc_binary_sp_decomposition( - ComputationGraph const &cg) { - SeriesParallelDecomposition sp_decomposition = ({ - std::optional result = - get_computation_graph_series_parallel_decomposition(cg); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - BinarySPDecompositionTree raw_binary_tree = - right_associative_binary_sp_tree_from_nary(sp_decomposition); - - auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ - [](Node const &n) { return layer_guid_t{n}; }, - }; - return ComputationGraphBinarySPDecomposition{ - transform(raw_binary_tree.raw_tree, visitor)}; -} - -bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { - return is_binary_sp_tree_left_associative(d.raw_tree); -} - -bool 
is_right_associative(ComputationGraphBinarySPDecomposition const &d) { - return is_binary_sp_tree_right_associative(d.raw_tree); -} - -std::unordered_multiset - get_layers(ComputationGraphBinarySPDecomposition const &d) { - return get_leaves(d.raw_tree); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc similarity index 70% rename from lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc rename to lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc index 95e810fe8f..220614bb8b 100644 --- a/lib/compiler/src/compiler/series_parallel/get_pcg_series_parallel_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_pcg_series_parallel_decomposition.h" +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc new file mode 100644 index 0000000000..d0c54c91aa --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -0,0 +1,93 @@ +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + 
PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t> generic_impl_for_pcg_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t>{ + /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/[](PCGBinaryParallelSplit const &split) -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/[](PCGBinarySeriesSplit const &split) -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/[](PCGBinaryParallelSplit const &split) -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/[](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/[](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/[](PCGBinarySPDecomposition const &tree) -> PCGBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/[](PCGBinarySPDecomposition const &tree) -> parallel_layer_guid_t const & { + return tree.get(); + }, + }; +} + + +BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { + return pcg_tree.visit(overload { + [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_series_split_from_pcg_series_split(series), + }; + }, + [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_parallel_split_from_pcg_parallel_split(parallel), + }; + }, + [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + 
layer.raw_graph_node, + }; + }, + }); +} + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_pcg_sp_tree()); +} + +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &tree) { + return tree.visit(overload { + [](PCGBinarySeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, + [](PCGBinaryParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, + [](parallel_layer_guid_t const &) { return SPDecompositionTreeNodeType::NODE; }, + }); +} + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &tree, + parallel_layer_guid_t const &leaf) { + return find_paths_to_leaf(tree, generic_impl_for_pcg_sp_tree(), leaf); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc deleted file mode 100644 index 0888b5c02d..0000000000 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "compiler/series_parallel/pcg_binary_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" - -namespace FlexFlow { - -PCGBinarySPDecomposition get_left_child(PCGBinaryParallelSplit const &s) { - return PCGBinarySPDecomposition{ - get_left_child(s.raw_split), - }; -} - -PCGBinarySPDecomposition get_right_child(PCGBinaryParallelSplit const &s) { - return PCGBinarySPDecomposition{ - get_right_child(s.raw_split), - }; -} - -} // namespace FlexFlow diff --git 
a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc deleted file mode 100644 index 1d1ac2b9e4..0000000000 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "compiler/series_parallel/pcg_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" - -namespace FlexFlow { - -BinarySeriesSplit get_raw_graph_series_split(PCGBinarySeriesSplit const &s) { - auto visitor = - LeafOnlyBinarySPDecompositionTreeVisitor{ - [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }}; - - return BinarySeriesSplit{ - transform(s.raw_split, visitor), - }; -} - -PCGBinarySPDecomposition get_left_child(PCGBinarySeriesSplit const &s) { - return PCGBinarySPDecomposition{ - get_left_child(s.raw_split), - }; -} - -PCGBinarySPDecomposition get_right_child(PCGBinarySeriesSplit const &s) { - return PCGBinarySPDecomposition{ - get_right_child(s.raw_split), - }; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc deleted file mode 100644 index 1b53a3c047..0000000000 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ /dev/null @@ -1,81 +0,0 @@ -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" - -namespace FlexFlow { - -std::optional - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { - NOT_IMPLEMENTED(); -} - -std::unordered_multiset - get_parallel_layers(PCGBinarySPDecomposition const &d) { - return get_leaves(d.raw_tree); -} - -SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &d) { - return get_node_type(d.raw_tree); -} - -PCGBinarySPDecomposition - make_pcg_series_split(PCGBinarySPDecomposition const &lhs, - PCGBinarySPDecomposition const &rhs) { - return PCGBinarySPDecomposition{ - leaf_only_binary_sp_tree_make_series_split(lhs.raw_tree, rhs.raw_tree), - }; -} - -PCGBinarySPDecomposition - make_pcg_parallel_split(PCGBinarySPDecomposition const &lhs, - PCGBinarySPDecomposition const &rhs) { - return PCGBinarySPDecomposition{ - leaf_only_binary_sp_tree_make_parallel_split(lhs.raw_tree, rhs.raw_tree), - }; -} - -PCGBinarySPDecomposition make_pcg_leaf_node(parallel_layer_guid_t const &l) { - return PCGBinarySPDecomposition{ - leaf_only_binary_sp_tree_make_leaf(l), - }; -} - -PCGBinarySPDecomposition wrap_series_split(PCGBinarySeriesSplit const &s) { - return PCGBinarySPDecomposition{ - wrap_series_split(s.raw_split), - }; -} - -PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &p) { - return PCGBinarySPDecomposition{ - wrap_parallel_split(p.raw_split), - }; -} - -PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &d) { - return PCGBinarySeriesSplit{ - require_leaf_only_binary_series_split(d.raw_tree), - }; -} - 
-PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &d) { - return PCGBinaryParallelSplit{ - require_leaf_only_binary_parallel_split(d.raw_tree), - }; -} - -parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &d) { - return require_leaf_only_binary_leaf(d.raw_tree); -} - -std::unordered_set - find_paths_to_leaf(PCGBinarySPDecomposition const &spd, - parallel_layer_guid_t const &l) { - return find_paths_to_leaf(spd.raw_tree, l); -} - -} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 3587316e4b..b63ce95ae0 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -1,6 +1,5 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/get_only.h" #include "utils/full_binary_tree/binary_tree_path.h" @@ -10,6 +9,18 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_abstracted_tensor_set_movement_across_split") { + auto make_series_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + 
+ auto make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + ParallelComputationGraph pcg = empty_parallel_computation_graph(); ParallelTensorShape input_shape = ParallelTensorShape{ @@ -58,9 +69,10 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); - PCGBinarySeriesSplit split = require_series( - make_pcg_series_split(make_pcg_leaf_node(input1.parallel_layer), - make_pcg_leaf_node(input2.parallel_layer))); + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_leaf(input1.parallel_layer), + make_leaf(input2.parallel_layer), + }; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( @@ -81,10 +93,11 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult layer_2 = add_parallel_layer( pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(layer_1.parallel_layer)), - make_pcg_leaf_node(layer_2.parallel_layer))); + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_leaf(layer_2.parallel_layer), + }; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( @@ -126,13 +139,14 @@ TEST_SUITE(FF_TEST_SUITE) { {get_only(layer_1.outputs), get_only(layer_2.outputs)}, {relu_output_attrs}); - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_series_split( - make_pcg_leaf_node(layer_1.parallel_layer), - make_pcg_leaf_node(layer_2.parallel_layer))), - make_pcg_leaf_node(layer_3.parallel_layer))); + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + 
make_leaf(input.parallel_layer), + make_series_split( + make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_leaf(layer_3.parallel_layer), + }; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( @@ -172,11 +186,12 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult layer_3 = add_parallel_layer( pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), - make_pcg_leaf_node(layer_1.parallel_layer)), - make_pcg_parallel_split(make_pcg_leaf_node(layer_2.parallel_layer), - make_pcg_leaf_node(layer_3.parallel_layer)))); + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_parallel_split(make_leaf(layer_2.parallel_layer), + make_leaf(layer_3.parallel_layer)), + }; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( @@ -226,14 +241,15 @@ TEST_SUITE(FF_TEST_SUITE) { {get_only(layer_1.outputs), get_only(layer_2.outputs)}, {relu_output_attrs}); - PCGBinarySeriesSplit split = require_series(make_pcg_series_split( - make_pcg_series_split( - make_pcg_leaf_node(input.parallel_layer), - make_pcg_parallel_split( - make_pcg_leaf_node(layer_1.parallel_layer), - make_pcg_leaf_node(layer_2.parallel_layer))), - make_pcg_parallel_split(make_pcg_leaf_node(layer_3.parallel_layer), - make_pcg_leaf_node(layer_4.parallel_layer)))); + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_parallel_split( + make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_parallel_split(make_leaf(layer_3.parallel_layer), + make_leaf(layer_4.parallel_layer)) + }; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( diff --git 
a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index de26a5f2ad..7194fc038c 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -15,6 +15,32 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping") { + auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto make_series_split = [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + auto make_parallel_split = [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); MachineView mv2 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(3)); @@ -117,7 +143,7 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingCache cache = empty_machine_mapping_cache(); SUBCASE("single layer") { - MachineMappingProblemTree problem_tree = mm_problem_tree_make_leaf(k1); + MachineMappingProblemTree problem_tree = make_leaf(k1); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( @@ -140,9 +166,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("pair of layers in sequence") { MachineMappingProblemTree problem_tree = - mm_problem_tree_make_series_split(movement1, - mm_problem_tree_make_leaf(k1), - mm_problem_tree_make_leaf(k2)); + make_series_split(movement1, make_leaf(k1), make_leaf(k2)); MachineMappingConstraints constraints = 
get_unconstrained_solution_for_layers( @@ -176,8 +200,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("pair of layers in parallel") { MachineMappingProblemTree problem_tree = - mm_problem_tree_make_parallel_split(mm_problem_tree_make_leaf(k1), - mm_problem_tree_make_leaf(k2)); + make_parallel_split(make_leaf(k1), make_leaf(k2)); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 1940e6d8a3..09d4af7756 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,6 +1,5 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/get_only.h" #include @@ -9,6 +8,58 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_machine_mapping_problem_tree") { + auto pcg_make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinarySeriesSplit{ + lhs, + rhs, + }, + }; + }; + + auto pcg_make_parallel = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + + return PCGBinarySPDecomposition{ + PCGBinaryParallelSplit{ + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_leaf = [](UnmappedOpCostEstimateKey 
const &k) { + return MachineMappingProblemTree{k}; + }; + + auto mm_problem_tree_make_series = [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + tensor_set_movement, + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_parallel = [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + lhs, + rhs, + }, + }; + }; + ParallelComputationGraph pcg = empty_parallel_computation_graph(); ParallelTensorShape input_shape = ParallelTensorShape{ @@ -62,12 +113,11 @@ TEST_SUITE(FF_TEST_SUITE) { UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); - PCGBinarySPDecomposition sp_decomposition = - make_pcg_leaf_node(input_layer); + PCGBinarySPDecomposition sp_decomposition = PCGBinarySPDecomposition{input_layer}; MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_leaf(input_key); + MachineMappingProblemTree correct = MachineMappingProblemTree{input_key}; CHECK(result == correct); } @@ -105,13 +155,13 @@ TEST_SUITE(FF_TEST_SUITE) { /*output_shapes=*/{relu_output_shape}, }; - PCGBinarySPDecomposition sp_decomposition = make_pcg_series_split( - make_pcg_leaf_node(input_layer), make_pcg_leaf_node(relu_layer)); + PCGBinarySPDecomposition sp_decomposition = pcg_make_series( + pcg_make_leaf(input_layer), pcg_make_leaf(relu_layer)); MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_series_split( + MachineMappingProblemTree correct = mm_problem_tree_make_series( AbstractedTensorSetMovement{{ AbstractedSingleTensorMovement{ /*parallel_tensor_shape=*/input_shape, @@ -142,13 +192,13 @@ TEST_SUITE(FF_TEST_SUITE) { 
parallel_layer_guid_t input2_layer = input2_added.parallel_layer; UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); - PCGBinarySPDecomposition sp_decomposition = make_pcg_parallel_split( - make_pcg_leaf_node(input1_layer), make_pcg_leaf_node(input2_layer)); + PCGBinarySPDecomposition sp_decomposition = pcg_make_parallel( + pcg_make_leaf(input1_layer), pcg_make_leaf(input2_layer)); MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_parallel_split( + MachineMappingProblemTree correct = mm_problem_tree_make_parallel( mm_problem_tree_make_leaf(input1_key), mm_problem_tree_make_leaf(input2_key)); @@ -190,15 +240,15 @@ TEST_SUITE(FF_TEST_SUITE) { /*output_shapes=*/{ew_op_output_shape}, }; - PCGBinarySPDecomposition sp_decomposition = make_pcg_series_split( - make_pcg_parallel_split(make_pcg_leaf_node(input1_layer), - make_pcg_leaf_node(input2_layer)), - make_pcg_leaf_node(ew_op_layer)); + PCGBinarySPDecomposition sp_decomposition = pcg_make_series( + pcg_make_parallel(pcg_make_leaf(input1_layer), + pcg_make_leaf(input2_layer)), + pcg_make_leaf(ew_op_layer)); MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_series_split( + MachineMappingProblemTree correct = mm_problem_tree_make_series( AbstractedTensorSetMovement{{ AbstractedSingleTensorMovement{ /*parallel_tensor_shape=*/input_shape, @@ -228,7 +278,7 @@ TEST_SUITE(FF_TEST_SUITE) { }, }}, /*pre=*/ - mm_problem_tree_make_parallel_split( + mm_problem_tree_make_parallel( mm_problem_tree_make_leaf(input1_key), mm_problem_tree_make_leaf(input2_key)), /*post=*/mm_problem_tree_make_leaf(ew_op_key)); diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc 
b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc similarity index 96% rename from lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc rename to lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 564cffaebe..2b59669aad 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" @@ -89,14 +89,14 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ - ParallelSplit{ + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ input_layer.raw_node, projection_weights_layer.raw_node, bias_weights_layer.raw_node, - }, + }}, operator_layer.raw_node, - }}; + }}}; CHECK(result == correct); } @@ -159,17 +159,17 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ - ParallelSplit{ + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ w1.raw_node, input.raw_node, w2.raw_node, - }, - ParallelSplit{ + }}, + ParallelSplit{{ op1.raw_node, op2.raw_node, - }, - }}; + }}, + }}}; } SUBCASE("SP with or without preprocessing, but preprocessing would SP " @@ -214,16 +214,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = 
get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{ParallelSplit{ - SeriesSplit{ + SeriesParallelDecomposition{ParallelSplit{{ + SeriesSplit{{ input1.raw_node, op1.raw_node, - }, - SeriesSplit{ + }}, + SeriesSplit{{ input2.raw_node, op2.raw_node, - }, - }}; + }}, + }}}; } SUBCASE("not SP with or without weight nodes") { diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 2e878c5c53..c67999ff32 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -41,7 +41,7 @@ key = "normal" [[values]] type = "::FlexFlow::TruncatedNormalInitializerAttrs" -key = "normal" +key = "truncated_normal" [[values]] type = "::FlexFlow::ConstantInitializerAttrs" diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h new file mode 100644 index 0000000000..4831afa408 --- /dev/null +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H + +#include +#include + +namespace FlexFlow { + +template +struct value_type { + value_type() = delete; + + value_type(value_type const &) { assert(false); } + value_type &operator=(value_type const &) { assert(false); } + + value_type(value_type &&) { assert(false); } + value_type &operator=(value_type &&) { assert(false); } + + bool operator==(value_type const &) const { assert(false); } + bool operator!=(value_type const &) const { assert(false); } +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::value_type> { + size_t operator()(::FlexFlow::value_type const &) const { + assert (false); + }; +}; + +} + +#endif diff --git a/lib/utils/include/utils/fmt/json.h b/lib/utils/include/utils/fmt/json.h new file mode 100644 index 
0000000000..15ad0de4e0 --- /dev/null +++ b/lib/utils/include/utils/fmt/json.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H + +#include +#include + +namespace fmt { + +template +struct formatter<::nlohmann::json, Char> : formatter { + template + auto format(::nlohmann::json const &j, FormatContext &ctx) { + std::ostringstream oss; + oss << j; + return formatter::format(oss.str(), ctx); + } +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 11c9e1db81..07928f7871 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -5,34 +5,36 @@ #include "utils/containers/transform.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" #include namespace FlexFlow { -template +template std::unordered_set - find_paths_to_leaf(FullBinaryTree const &tree, - LeafLabel const &leaf) { - return visit>( - tree, - overload{ - [&](LeafLabel const &l) -> std::unordered_set { - if (l == leaf) { - return {binary_tree_root_path()}; - } else { - return {}; - } - }, - [&](FullBinaryTreeParentNode const &parent) { - return set_union( - transform(find_paths_to_leaf(get_left_child(parent), leaf), - nest_inside_left_child), - transform(find_paths_to_leaf(get_right_child(parent), leaf), - nest_inside_right_child)); - }}); + find_paths_to_leaf(Tree const &tree, FullBinaryTreeImplementation const &impl, Leaf const &needle) { + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(find_paths_to_leaf(impl.get_left_child(parent), impl, needle), 
+ [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(find_paths_to_leaf(impl.get_right_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + if (leaf == needle) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h deleted file mode 100644 index 4450b70596..0000000000 --- a/lib/utils/include/utils/full_binary_tree/fmt.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H - -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_left_child.h" -#include "utils/full_binary_tree/get_right_child.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" -#include - -namespace FlexFlow { - -template -std::string - format_as(FullBinaryTreeParentNode const &t) { - return fmt::format( - "<{} ({} {})>", t.label, get_left_child(t), get_right_child(t)); -} - -template -std::string format_as(FullBinaryTree const &t) { - auto visitor = FullBinaryTreeVisitor{ - [](FullBinaryTreeParentNode const &parent) { - return fmt::to_string(parent); - }, - [](LeafLabel const &leaf) { return fmt::format("{}", leaf); }, - }; - - return visit(t, visitor); -} - -template -std::ostream & - operator<<(std::ostream &s, - FullBinaryTreeParentNode const &t) { - return (s << fmt::to_string(t)); -} - -template -std::ostream &operator<<(std::ostream &s, - FullBinaryTree const &t) { - return (s << fmt::to_string(t)); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h deleted file 
mode 100644 index 562edf52c1..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h +++ /dev/null @@ -1,106 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H - -#include -#include -#include - -namespace FlexFlow { - -template -struct FullBinaryTree; - -template -struct FullBinaryTreeParentNode { - explicit FullBinaryTreeParentNode( - ParentLabel const &label, - FullBinaryTree const &lhs, - FullBinaryTree const &rhs) - : label(label), - left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; - - bool operator==(FullBinaryTreeParentNode const &other) const { - if (this->tie_ptr() == other.tie_ptr()) { - return true; - } - - return this->tie() == other.tie(); - } - - bool operator!=(FullBinaryTreeParentNode const &other) const { - if (this->tie_ptr() == other.tie_ptr()) { - return false; - } - - return this->tie() != other.tie(); - } - - bool operator<(FullBinaryTreeParentNode const &other) const { - return this->tie() < other.tie(); - } - -public: - ParentLabel label; - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple> const &, - std::shared_ptr> const &> - tie_ptr() const { - return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); - } - - std::tuple const &, - FullBinaryTree const &> - tie() const { - return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct FullBinaryTree { -public: - FullBinaryTree() = delete; - explicit FullBinaryTree( - FullBinaryTreeParentNode const &t) - : root{t} {} - - explicit FullBinaryTree(LeafLabel const &t) : root{t} {} - - bool operator==(FullBinaryTree const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(FullBinaryTree const 
&other) const { - return this->tie() != other.tie(); - } - - bool operator<(FullBinaryTree const &other) const { - return this->tie() < other.tie(); - } - -public: - std::variant, LeafLabel> - root; - -private: - std::tuple tie() const { - return std::tie(this->root); - } - - friend std::hash; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml new file mode 100644 index 0000000000..2183abe900 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "FullBinaryTree" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "ParentLabel", + "LeafLabel", +] + +includes = [ + "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h", +] + +[[values]] +type = "::FlexFlow::FullBinaryTreeParentNode" +key = "parent" + +[[values]] +type = "LeafLabel" +key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml new file mode 100644 index 0000000000..bf08701840 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "get_left_child" +type = "std::function" + +[[fields]] +name = "get_right_child" +type = "std::function" + +[[fields]] +name = "is_leaf" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" + +[[fields]] +name = "require_parent" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml 
b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml new file mode 100644 index 0000000000..3403271621 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeParentNode" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "template struct FullBinaryTree", +] + +post_includes = [ + "utils/full_binary_tree/full_binary_tree.dtg.h", +] + +template_params = [ + "ParentLabel", + "LeafLabel", +] + +[[fields]] +name = "label" +type = "ParentLabel" + +[[fields]] +name = "left_child" +type = "::FlexFlow::FullBinaryTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::FullBinaryTree" +indirect = true diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml index cb637057db..7418d7a016 100644 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -4,19 +4,19 @@ features = [] template_params = [ "Result", - "ParentLabel", - "LeafLabel", + "Tree", + "Parent", + "Leaf", ] includes = [ "", - "utils/full_binary_tree/full_binary_tree.h", ] [[fields]] name = "parent_func" -type = "std::function const &)>" +type = "std::function" [[fields]] name = "leaf_func" -type = "std::function" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 23008b4cc0..20c2eb8b62 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -5,33 +5,34 @@ #include "utils/containers/transform.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/binary_tree_path.h" 
-#include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/visit.h" #include "utils/overload.h" #include namespace FlexFlow { -template +template std::unordered_set - get_all_leaf_paths(FullBinaryTree const &tree) { - return visit>( - tree, - overload{ - [](LeafLabel const &) { - return std::unordered_set{binary_tree_root_path()}; - }, - [](FullBinaryTreeParentNode const &parent) { - return set_union( - transform(get_all_leaf_paths(get_left_child(parent)), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(get_all_leaf_paths(get_right_child(parent)), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - }}); + get_all_leaf_paths(Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(get_all_leaf_paths(impl.get_left_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(impl.get_right_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + return {binary_tree_root_path()}; + }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index db7ea95a04..5c1e21014d 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -3,22 +3,20 @@ #include "utils/exception.h" #include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_left_child.h" -#include "utils/full_binary_tree/get_right_child.h" +#include 
"utils/full_binary_tree/full_binary_tree_implementation.dtg.h" #include namespace FlexFlow { -template -FullBinaryTree - get_child(FullBinaryTreeParentNode const &t, - BinaryTreePathEntry const &e) { +template +Tree get_child(Parent const &parent, + FullBinaryTreeImplementation const &impl, + BinaryTreePathEntry const &e) { switch (e) { case BinaryTreePathEntry::LEFT_CHILD: - return get_left_child(t); + return impl.get_left_child(parent); case BinaryTreePathEntry::RIGHT_CHILD: - return get_right_child(t); + return impl.get_right_child(parent); default: throw mk_runtime_error( fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); diff --git a/lib/utils/include/utils/full_binary_tree/get_label.h b/lib/utils/include/utils/full_binary_tree/get_label.h deleted file mode 100644 index e89fdab98e..0000000000 --- a/lib/utils/include/utils/full_binary_tree/get_label.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H - -#include "utils/full_binary_tree/full_binary_tree.h" - -namespace FlexFlow { - -template -ParentLabel get_full_binary_tree_parent_label( - FullBinaryTreeParentNode const &p) { - return p.label; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index 8ebc945db7..87633f29a9 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -2,26 +2,32 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H #include "utils/containers/multiset_union.h" -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" #include namespace FlexFlow { -template -std::unordered_multiset - get_leaves(FullBinaryTree const &t) { 
- return visit>( - t, - overload{ - [](FullBinaryTreeParentNode const &parent) { - return multiset_union(get_leaves(get_left_child(parent)), - get_leaves(get_right_child(parent))); - }, - [](ChildLabel const &leaf) { - return std::unordered_multiset{leaf}; - }}); +template +std::unordered_multiset + get_leaves(Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) + -> std::unordered_multiset + { + return multiset_union(get_leaves(impl.get_left_child(parent), impl), + get_leaves(impl.get_right_child(parent), impl)); + }, + [](Leaf const &leaf) + -> std::unordered_multiset + { + return {leaf}; + }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h deleted file mode 100644 index 5d5148d594..0000000000 --- a/lib/utils/include/utils/full_binary_tree/get_left_child.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H - -#include "utils/full_binary_tree/full_binary_tree.h" - -namespace FlexFlow { - -template -FullBinaryTree const & - get_left_child(FullBinaryTreeParentNode const &t) { - return *t.left_child_ptr; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h index 1a73ce8743..d49faa3694 100644 --- a/lib/utils/include/utils/full_binary_tree/get_node_type.h +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -1,23 +1,19 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include 
"utils/full_binary_tree/full_binary_tree.dtg.h" #include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" +#include "utils/overload.h" namespace FlexFlow { template FullBinaryTreeNodeType - get_node_type(FullBinaryTree const &t) { - if (std::holds_alternative(t.root)) { - return FullBinaryTreeNodeType::LEAF; - } else { - bool is_parent = std::holds_alternative< - FullBinaryTreeParentNode>(t.root); - assert(is_parent); - - return FullBinaryTreeNodeType::PARENT; - } + get_node_type(FullBinaryTree const &tree) { + return tree.template visit(overload { + [](FullBinaryTreeParentNode const &) { return FullBinaryTreeNodeType::PARENT; }, + [](LeafLabel const &) { return FullBinaryTreeNodeType::LEAF; }, + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..69d4e2ea49 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H + +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes(Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = FullBinaryTreeVisitor{ + [&](Parent const &parent) -> int { + return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + get_num_tree_nodes(impl.get_right_child(parent), impl); + }, + [](Leaf const &) -> int { return 1; }, + }; + + return visit(tree, impl, visitor); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h deleted file mode 100644 index 937e803395..0000000000 --- 
a/lib/utils/include/utils/full_binary_tree/get_right_child.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H - -#include "utils/full_binary_tree/full_binary_tree.h" - -namespace FlexFlow { - -template -FullBinaryTree const & - get_right_child(FullBinaryTreeParentNode const &t) { - return *t.right_child_ptr; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index 0a6fba4a77..bbdc74850c 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -3,32 +3,34 @@ #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_child.h" #include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" +#include "utils/full_binary_tree/get_child.h" #include namespace FlexFlow { -template -std::optional> - get_subtree_at_path(FullBinaryTree const &t, +template +std::optional + get_subtree_at_path(Tree const &tree, + FullBinaryTreeImplementation const &impl, BinaryTreePath const &p) { if (p == binary_tree_root_path()) { - return t; + return tree; } - return visit>>( - t, - overload{ - [&](FullBinaryTreeParentNode const &parent) { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::optional { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); + }, + [](Leaf const 
&leaf) -> std::optional { + return std::nullopt; + }, + }; - return get_subtree_at_path(get_child(parent, curr), rest); - }, - [&](LeafLabel const &leaf) { return std::nullopt; }}); + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h deleted file mode 100644 index 6893b990c7..0000000000 --- a/lib/utils/include/utils/full_binary_tree/hash.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H - -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace std { - -template -struct hash<::FlexFlow::FullBinaryTreeParentNode> { - size_t operator()( - ::FlexFlow::FullBinaryTreeParentNode const &t) - const { - return get_std_hash(t.tie()); - } -}; - -template -struct hash<::FlexFlow::FullBinaryTree> { - size_t operator()( - ::FlexFlow::FullBinaryTree const &t) const { - return get_std_hash(t.tie()); - } -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/make.h b/lib/utils/include/utils/full_binary_tree/make.h deleted file mode 100644 index 488f7f83fd..0000000000 --- a/lib/utils/include/utils/full_binary_tree/make.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H - -#include "utils/full_binary_tree/full_binary_tree.h" - -namespace FlexFlow { - -template -FullBinaryTree make_full_binary_tree_parent( - ParentLabel const &label, - FullBinaryTree const &lhs, - FullBinaryTree const &rhs) { - return FullBinaryTree{ - FullBinaryTreeParentNode{ - label, - lhs, - rhs, - }, - }; -} - -template -FullBinaryTree - make_full_binary_tree_leaf(LeafLabel const &label) { - return FullBinaryTree{ - label, - }; -} - -} // namespace FlexFlow - 
-#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h index 65bcb9b3bd..7b2625a9e8 100644 --- a/lib/utils/include/utils/full_binary_tree/require.h +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -1,7 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H -#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" namespace FlexFlow { @@ -9,13 +10,13 @@ template FullBinaryTreeParentNode const & require_full_binary_tree_parent_node( FullBinaryTree const &t) { - return std::get>(t.root); + return t.template get>(); } template LeafLabel const &require_full_binary_tree_leaf( FullBinaryTree const &t) { - return std::get(t.root); + return t.template get(); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 502165f2ab..f9349069a9 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -3,30 +3,19 @@ #include "utils/exception.h" #include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" #include "utils/full_binary_tree/get_node_type.h" -#include "utils/full_binary_tree/require.h" namespace FlexFlow { -template -Result visit(FullBinaryTree const &tt, F f) { - auto visitor = FullBinaryTreeVisitor{f, f}; - - return visit(tt, visitor); -} - -template -Result visit(FullBinaryTree const &t, - FullBinaryTreeVisitor const &v) { - FullBinaryTreeNodeType node_type = get_node_type(t); - switch (node_type) { - case FullBinaryTreeNodeType::PARENT: - return v.parent_func(require_full_binary_tree_parent_node(t)); - case FullBinaryTreeNodeType::LEAF: - return 
v.leaf_func(require_full_binary_tree_leaf(t)); - default: - throw mk_runtime_error( - fmt::format("Unhandled FullBinaryTreeNodeType value: {}", node_type)); +template +Result visit(Tree const &tree, + FullBinaryTreeImplementation const &impl, + FullBinaryTreeVisitor const &visitor) { + if (impl.is_leaf(tree)) { + return visitor.leaf_func(impl.require_leaf(tree)); + } else { + return visitor.parent_func(impl.require_parent(tree)); } } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h deleted file mode 100644 index 0d66f80f35..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_PARALLEL_SPLIT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_PARALLEL_SPLIT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -BinarySPDecompositionTree get_left_child(BinaryParallelSplit const &); -BinarySPDecompositionTree get_right_child(BinaryParallelSplit const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml index 0dcae5177a..37e3bbee09 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -6,11 
+6,20 @@ features = [ "fmt", ] -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", - "utils/graph/node/node.dtg.h", +fwd_decls = [ + "struct BinarySPDecompositionTree", ] +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + [[fields]] -name = "raw_split" -type = "::FlexFlow::LeafOnlyBinaryParallelSplit<::FlexFlow::Node>" +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h deleted file mode 100644 index efd77a89bd..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SERIES_SPLIT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SERIES_SPLIT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -BinarySPDecompositionTree get_left_child(BinarySeriesSplit const &); -BinarySPDecompositionTree get_right_child(BinarySeriesSplit const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml index 45472cb243..7e6e86ba76 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -6,11 +6,20 @@ features = [ "fmt", ] -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", - "utils/graph/node/node.dtg.h", +fwd_decls = [ + "struct BinarySPDecompositionTree", ] +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + [[fields]] -name = "raw_split" -type = "::FlexFlow::LeafOnlyBinarySeriesSplit<::FlexFlow::Node>" +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index b87516e88a..28e9beeebd 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -4,51 +4,25 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include namespace FlexFlow { 
-BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &, - BinarySPDecompositionTree const &); -BinarySPDecompositionTree - make_parallel_split(BinarySPDecompositionTree const &, - BinarySPDecompositionTree const &); -BinarySPDecompositionTree make_leaf_node(Node const &); +GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> generic_impl_for_binary_sp_tree(); bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); -BinarySeriesSplit require_series(BinarySPDecompositionTree const &); -BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &); -Node require_leaf(BinarySPDecompositionTree const &); - SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); -template -Return visit(BinarySPDecompositionTree const &tree, F &&f) { - SPDecompositionTreeNodeType node_type = get_node_type(tree); - switch (node_type) { - case SPDecompositionTreeNodeType::SERIES: { - Return result = f(require_series(tree)); - return result; - } - case SPDecompositionTreeNodeType::PARALLEL: { - Return result = f(require_parallel(tree)); - return result; - } - case SPDecompositionTreeNodeType::NODE: { - Return result = f(require_leaf(tree)); - return result; - } - default: - throw mk_runtime_error(fmt::format( - "Unhandled SPDecompositionTreeNodeType value: {}", node_type)); - } -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml deleted file mode 100644 index 0000213398..0000000000 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "BinarySPDecompositionTree" -features = [ - "eq", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml new file mode 100644 index 0000000000..c586b49d9d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::Node" +key = "node" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h index b2d50676b9..9eaf84149f 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h 
+++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -2,17 +2,19 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H #include "utils/full_binary_tree/find_paths_to_leaf.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template +template std::unordered_set - find_paths_to_leaf(GenericBinarySPDecompositionTree const &tree, - LeafLabel const &leaf) { - return find_paths_to_leaf(tree.raw_tree, leaf); + find_paths_to_leaf(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + Leaf const &needle) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return find_paths_to_leaf(tree, full_binary_impl, needle); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml index f613d2f04e..139074299f 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml @@ -4,27 +4,32 @@ features = [ "eq", "hash", "fmt", - "json", ] -template_params = [ - "SeriesSplitLabel", - 
"ParallelSplitLabel", - "LeafLabel", +fwd_decls = [ + "template struct GenericBinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" ] -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", +template_params = [ + "SeriesLabel", + "ParallelLabel", + "LeafLabel", ] [[fields]] name = "label" -type = "ParallelSplitLabel" +type = "ParallelLabel" [[fields]] -name = "lhs" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +name = "left_child" +type = "::FlexFlow::GenericBinarySPDecompositionTree" +indirect = true [[fields]] -name = "rhs" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +name = "right_child" +type = "::FlexFlow::GenericBinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml deleted file mode 100644 index d187b7c93a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinaryParallelSplitLabel" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -template_params = [ - "ParallelSplitLabel" -] - -[[fields]] -name = "raw_label" -type = "ParallelSplitLabel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml index 025dca1826..054532e2e0 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml @@ -4,28 +4,32 @@ features = [ "eq", "hash", "fmt", - "json", ] -template_params = [ - "SeriesSplitLabel", - "ParallelSplitLabel", - "LeafLabel", +fwd_decls = [ + "template struct GenericBinarySPDecompositionTree", ] +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +] -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", +template_params = [ + "SeriesLabel", + "ParallelLabel", + "LeafLabel", ] [[fields]] name = "label" -type = "SeriesSplitLabel" +type = "SeriesLabel" [[fields]] -name = "pre" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +name = "left_child" +type = "::FlexFlow::GenericBinarySPDecompositionTree" +indirect = true [[fields]] -name = "post" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +name = "right_child" +type = "::FlexFlow::GenericBinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml deleted file mode 100644 index 74e00ada81..0000000000 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySeriesSplitLabel" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -template_params = [ - "SeriesSplitLabel" -] - -[[fields]] -name = "raw_label" -type = "SeriesSplitLabel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml deleted file mode 100644 index 00c49992ef..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySPDecompositionTree" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -template_params = [ - "SeriesSplitLabel", - "ParallelSplitLabel", - "LeafLabel", -] - -includes = [ - "utils/full_binary_tree/full_binary_tree.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h", -] - -src_includes = [ - "utils/full_binary_tree/json.h", - "utils/full_binary_tree/hash.h", - "utils/full_binary_tree/fmt.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::FullBinaryTree<::FlexFlow::GenericBinarySPSplitLabel, LeafLabel>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml new file mode 100644 index 0000000000..5edaa80a09 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +template_params = [ + "SeriesLabel", + "ParallelLabel", + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::GenericBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::GenericBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "LeafLabel" +key = "leaf" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h new file mode 100644 index 0000000000..fd29b69567 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h @@ -0,0 +1,63 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include +#include "utils/overload.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl(GenericBinarySPDecompositionTreeImplementation const &impl) { + + using Parent = std::variant; + + auto full_binary_impl = FullBinaryTreeImplementation{ + /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload { + [&](Series const &series) -> Tree const & { + return impl.series_get_left_child(series); + }, + [&](Parallel const &parallel) -> Tree const & { + return impl.parallel_get_left_child(parallel); + }, + }, parent); + }, + /*get_right_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload { + [&](Series const &series) -> Tree const & { + return impl.series_get_right_child(series); + }, + [&](Parallel const &parallel) -> Tree const & { + return impl.parallel_get_right_child(parallel); + }, + }, parent); + }, + /*is_leaf=*/[impl](Tree const &tree) -> bool { + return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; + }, + /*require_leaf=*/[impl](Tree const &tree) -> Leaf const & { + return impl.require_leaf(tree); + }, + /*require_parent=*/[impl](Tree const &tree) -> Parent { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return Parent{impl.require_series(tree)}; + case SPDecompositionTreeNodeType::PARALLEL: + return Parent{impl.require_parallel(tree)}; + default: + throw
mk_runtime_error(fmt::format("Unexpected SPDecompositionTreeNodeType: {}", node_type)); + } + } + }; + + return full_binary_impl; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml new file mode 100644 index 0000000000..3ccbfd959b --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", + "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h", +] + +[[fields]] +name = "series_get_left_child" +type = "std::function" + +[[fields]] +name = "parallel_get_left_child" +type = "std::function" + +[[fields]] +name = "series_get_right_child" +type = "std::function" + +[[fields]] +name = "parallel_get_right_child" +type = "std::function" + +[[fields]] +name = "get_node_type" +type = "std::function<::FlexFlow::SPDecompositionTreeNodeType(Tree const &)>" + +[[fields]] +name = "require_series" +type = "std::function" + +[[fields]] +name = "require_parallel" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml new file mode 100644 index 0000000000..df0b0f2ea7 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeTransformVisitor" +features = [] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", + "SeriesSplitLabel2", + "ParallelSplitLabel2", + "LeafLabel2", +] + +includes = [ + "", +] + +[[fields]] +name = "series_split_func" +type = "std::function" + +[[fields]] +name = "parallel_split_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml index 7c491ad49d..6275c82a0c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml @@ -3,12 +3,11 @@ name = "GenericBinarySPDecompositionTreeVisitor" features = [] template_params = [ - "SeriesSplitLabel", - "ParallelSplitLabel", - "LeafLabel", - "SeriesSplitLabel2", - "ParallelSplitLabel2", - "LeafLabel2", + "ReturnType", + "Tree", + "Series", + "Parallel", + "Leaf", ] includes = [ @@ -16,13 +15,13 @@ includes = [ ] 
[[fields]] -name = "series_split_func" -type = "std::function" +name = "series_func" +type = "std::function" [[fields]] -name = "parallel_split_func" -type = "std::function" +name = "parallel_func" +type = "std::function" [[fields]] name = "leaf_func" -type = "std::function" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h deleted file mode 100644 index 2cafb4e5b9..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_SPLIT_LABEL_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_SPLIT_LABEL_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h" -#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -SPDecompositionTreeNodeType get_node_type( - GenericBinarySPSplitLabel const &label) { - return label.template visit(overload{ - [](GenericBinarySeriesSplitLabel const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](GenericBinaryParallelSplitLabel const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - }); -} - -template -GenericBinarySPSplitLabel - make_generic_binary_series_split_label(SeriesLabel const &label) { - return GenericBinarySPSplitLabel{ - GenericBinarySeriesSplitLabel{ - label, - }, - }; -} - -template 
-GenericBinarySPSplitLabel - make_generic_binary_parallel_split_label(ParallelLabel const &label) { - return GenericBinarySPSplitLabel{ - GenericBinaryParallelSplitLabel{ - label, - }, - }; -} - -template -SeriesLabel require_generic_binary_series_split_label( - GenericBinarySPSplitLabel const &label) { - return label.template get>() - .raw_label; -} - -template -ParallelLabel require_generic_binary_parallel_split_label( - GenericBinarySPSplitLabel const &label) { - return label.template get>() - .raw_label; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml deleted file mode 100644 index c528c61f37..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySPSplitLabel" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -template_params = [ - "SeriesSplitLabel", - "ParallelSplitLabel", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.dtg.h", -] - -[[values]] -type = "::FlexFlow::GenericBinarySeriesSplitLabel" -key = "series" - -[[values]] -type = "::FlexFlow::GenericBinaryParallelSplitLabel" -key = "parallel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h index 6eb9166df0..4637cbd81c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -2,16 +2,19 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H #include "utils/full_binary_tree/get_all_leaf_paths.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template +template std::unordered_set get_all_leaf_paths( - GenericBinarySPDecompositionTree const &tree) { - return get_all_leaf_paths(tree.raw_tree); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_all_leaf_paths(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index c5d0e1bd30..7bbc5cf603 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ 
-1,51 +1,20 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H -#include "utils/containers/multiset_union.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" -#include +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template -std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](LeafLabel const &t) { return std::unordered_multiset{t}; }, - [](GenericBinarySeriesSplit const &s) { - return get_leaves(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_leaves(p); - }, - }); -} +template +std::unordered_multiset + get_leaves(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { -template -std::unordered_multiset get_leaves( - GenericBinarySeriesSplit const &s) { - return multiset_union(get_leaves(get_left_child(s)), - get_leaves(get_right_child(s))); -} + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); -template -std::unordered_multiset get_leaves( - GenericBinaryParallelSplit const - &p) { - return 
multiset_union(get_leaves(get_left_child(p)), - get_leaves(get_right_child(p))); + return get_leaves(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h deleted file mode 100644 index 95a75a835c..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - get_left_child( - GenericBinarySeriesSplit const - &s) { - return s.pre; -} - -template -GenericBinarySPDecompositionTree - get_left_child( - GenericBinaryParallelSplit const - &p) { - return p.lhs; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index 8f80c32dbf..c4e86a252d 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -1,10 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#include "utils/full_binary_tree/get_label.h" -#include "utils/full_binary_tree/visit.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include "utils/overload.h" @@ -14,20 +11,18 @@ template SPDecompositionTreeNodeType get_node_type(GenericBinarySPDecompositionTree const &tt) { - auto visitor = FullBinaryTreeVisitor< - SPDecompositionTreeNodeType, - GenericBinarySPSplitLabel, - LeafLabel>{ - [](FullBinaryTreeParentNode< - GenericBinarySPSplitLabel, - LeafLabel> const &parent) { - return get_node_type(get_full_binary_tree_parent_label(parent)); - }, - [](LeafLabel const &) { return SPDecompositionTreeNodeType::NODE; }, - }; - - return visit(tt.raw_tree, visitor); + LeafLabel> const &tree) { + return tree.template visit(overload { + [](GenericBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](GenericBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](LeafLabel const &) { + return SPDecompositionTreeNodeType::NODE; + } + }); } } // namespace FlexFlow diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index f9619df862..b5fe0d4131 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -1,47 +1,19 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template -int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { - return visit(tt, - overload{ - [](LeafLabel const &t) { return 1; }, - [](GenericBinarySeriesSplit const &s) { - return get_num_tree_nodes(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_num_tree_nodes(p); - }, - }); -} +template +int get_num_tree_nodes(Tree const &tree, + 
GenericBinarySPDecompositionTreeImplementation const &impl) { -template -int get_num_tree_nodes( - GenericBinarySeriesSplit const &s) { - return 1 + get_num_tree_nodes(get_left_child(s)) + - get_num_tree_nodes(get_right_child(s)); -} + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); -template -int get_num_tree_nodes( - GenericBinaryParallelSplit const - &p) { - return 1 + get_num_tree_nodes(get_left_child(p)) + - get_num_tree_nodes(get_right_child(p)); + return get_num_tree_nodes(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h deleted file mode 100644 index 4820bfdc7a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - get_right_child( - GenericBinarySeriesSplit const - &s) { - return s.post; -} - -template 
-GenericBinarySPDecompositionTree - get_right_child( - GenericBinaryParallelSplit const - &p) { - return p.rhs; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h index 5ec0c03c3a..8a687d9702 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -1,37 +1,21 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H -#include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/get_subtree_at_path.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" #include namespace FlexFlow { -template -std::optional> - get_subtree_at_path(GenericBinarySPDecompositionTree const &tree, +template +std::optional + get_subtree_at_path(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, BinaryTreePath const &path) { - std::optional, - LeafLabel>> - raw_subtree = get_subtree_at_path(tree.raw_tree, path); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); - if 
(!raw_subtree.has_value()) { - return std::nullopt; - } else { - return GenericBinarySPDecompositionTree{ - raw_subtree.value(), - }; - } + return get_subtree_at_path(tree, full_binary_impl, path); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 5331a10c86..17ff9c5dd1 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,38 +1,33 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" namespace FlexFlow { -template +template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](LeafLabel const &) { return true; }, - [](GenericBinarySeriesSplit const &s) { - return !is_series_split(get_right_child(s)) && - is_binary_sp_tree_left_associative(get_left_child(s)) && - is_binary_sp_tree_left_associative(get_right_child(s)); - }, - [](GenericBinaryParallelSplit const &p) { - return !is_parallel_split(get_right_child(p)) && - is_binary_sp_tree_left_associative(get_left_child(p)) && - is_binary_sp_tree_left_associative(get_right_child(p)); - }, - }); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_right_child(split)) != SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_left_associative(impl.series_get_left_child(split), impl) && + is_binary_sp_tree_left_associative(impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_right_child(split)) != SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_left_associative(impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_left_associative(impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { + return true; + }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index e7a03b1e0e..b284ce763e 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,38 +1,32 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" namespace FlexFlow { -template +template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](LeafLabel const &t) { return true; }, - [](GenericBinarySeriesSplit const &s) { - return !is_series_split(get_left_child(s)) && - is_binary_sp_tree_right_associative(get_left_child(s)) && - is_binary_sp_tree_right_associative(get_right_child(s)); - }, - [](GenericBinaryParallelSplit const &p) { - return !is_parallel_split(get_left_child(p)) && - 
is_binary_sp_tree_right_associative(get_left_child(p)) && - is_binary_sp_tree_right_associative(get_right_child(p)); - }, - }); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_left_child(split)) != SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_right_associative(impl.series_get_left_child(split), impl) && + is_binary_sp_tree_right_associative(impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_left_child(split)) != SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_right_associative(impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_right_associative(impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { + return true; + }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h deleted file mode 100644 index b1f635389c..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ /dev/null @@ -1,65 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H - -#include "utils/full_binary_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - make_generic_binary_series_split( - SeriesLabel const &label, - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - make_generic_binary_series_split_label( - label), - lhs.raw_tree, - rhs.raw_tree), - }; -} - -template -GenericBinarySPDecompositionTree - make_generic_binary_parallel_split( - ParallelLabel const &label, - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - make_generic_binary_parallel_split_label( - label), - lhs.raw_tree, - rhs.raw_tree), - }; -} - -template -GenericBinarySPDecompositionTree - make_generic_binary_sp_leaf(LeafLabel const &leaf) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_leaf< - GenericBinarySPSplitLabel>(leaf), - }; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index 4dae420449..315fca844d 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -1,11 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H #define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#include "utils/full_binary_tree/get_label.h" -#include "utils/full_binary_tree/require.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" namespace FlexFlow { @@ -14,24 +11,8 @@ GenericBinarySeriesSplit require_generic_binary_series_split( GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode< - GenericBinarySPSplitLabel, - LeafLabel> - parent = require_full_binary_tree_parent_node(t.raw_tree); - - return GenericBinarySeriesSplit{ - /*label=*/require_generic_binary_series_split_label( - get_full_binary_tree_parent_label(parent)), - /*pre=*/ - GenericBinarySPDecompositionTree{ - get_left_child(parent), - }, - /*post=*/ - GenericBinarySPDecompositionTree{ - get_right_child(parent), - }, - }; + LeafLabel> const &tree) { + return tree.template get>(); } template @@ -39,32 +20,16 @@ GenericBinaryParallelSplit require_generic_binary_parallel_split( GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode< - GenericBinarySPSplitLabel, - LeafLabel> - parent = require_full_binary_tree_parent_node(t.raw_tree); - - return GenericBinaryParallelSplit{ - /*label=*/require_generic_binary_parallel_split_label( - get_full_binary_tree_parent_label(parent)), - /*lhs=*/ - GenericBinarySPDecompositionTree{ - get_left_child(parent), - }, - /*rhs=*/ - GenericBinarySPDecompositionTree{ - get_right_child(parent), - }, - }; + LeafLabel> const &tree) { + return tree.template get>(); } template LeafLabel require_generic_binary_leaf( GenericBinarySPDecompositionTree 
const &t) { - return require_full_binary_tree_leaf(t.raw_tree); + LeafLabel> const &tree) { + return tree.template get(); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h index 045bd41652..00419ace55 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -1,116 +1,124 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template 
-GenericBinarySPDecompositionTree - transform( - GenericBinarySPDecompositionTree const &tt, - GenericBinarySPDecompositionTreeVisitor const &visitor); - -template -GenericBinarySeriesSplit transform( - GenericBinarySeriesSplit const &s, - GenericBinarySPDecompositionTreeVisitor const &visitor) { - return GenericBinarySeriesSplit{ - visitor.series_split_func(s.label), - transform(get_left_child(s), visitor), - transform(get_right_child(s), visitor), - }; -}; - -template -GenericBinaryParallelSplit transform( - GenericBinaryParallelSplit const &s, - GenericBinarySPDecompositionTreeVisitor const &visitor) { - return GenericBinaryParallelSplit{ - visitor.parallel_split_func(s.label), - transform(get_left_child(s), visitor), - transform(get_right_child(s), visitor), - }; -}; - -template -GenericBinarySPDecompositionTree - transform( - GenericBinarySPDecompositionTree const &tt, - GenericBinarySPDecompositionTreeVisitor const &visitor) { - return visit>( - tt, - overload{ - [&](GenericBinarySeriesSplit const &s) { - return wrap_series_split(transform(s, visitor)); - }, - [&](GenericBinaryParallelSplit const &s) { - return wrap_parallel_split(transform(s, visitor)); - }, - [&](LeafLabel const &t) { - return make_generic_binary_sp_leaf( - visitor.leaf_func(t)); - }, - }); -} - -} // namespace FlexFlow +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +// #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" +// #include "utils/overload.h" +// +// namespace FlexFlow { +// +// template +// GenericBinarySPDecompositionTree +// transform( +// GenericBinarySPDecompositionTree const &tt, +// GenericBinarySPDecompositionTreeTransformVisitor const &visitor); +// +// template +// GenericBinarySeriesSplit transform( +// GenericBinarySeriesSplit const &s, +// GenericBinarySPDecompositionTreeTransformVisitor const &visitor) { +// return GenericBinarySeriesSplit{ +// visitor.series_split_func(s.label), +// transform(get_left_child(s), visitor), +// transform(get_right_child(s), visitor), +// }; +// }; +// +// template +// GenericBinaryParallelSplit transform( +// GenericBinaryParallelSplit const &s, +// GenericBinarySPDecompositionTreeTransformVisitor const &visitor) { +// return GenericBinaryParallelSplit{ +// visitor.parallel_split_func(s.label), +// transform(get_left_child(s), visitor), +// transform(get_right_child(s), visitor), +// }; +// }; +// +// template +// GenericBinarySPDecompositionTree +// transform( +// GenericBinarySPDecompositionTree const &tree, +// GenericBinarySPDecompositionTreeTransformVisitor const &transform_visitor) { +// +// using ResultType = +// GenericBinarySPDecompositionTree; +// +// auto visitor = GenericBinarySPDecompositionTreeVisitor< +// ResultType, +// SeriesLabel, +// ParallelLabel, +// LeafLabel> +// { +// [&](GenericBinarySeriesSplit const &s) -> ResultType { +// return generic_binary_sp_tree_wrap_series_split(transform(s, transform_visitor)); +// }, +// [&](GenericBinaryParallelSplit 
const &s) -> ResultType { +// return generic_binary_sp_tree_wrap_parallel_split(transform(s, transform_visitor)); +// }, +// [&](LeafLabel const &t) -> ResultType { +// return generic_binary_sp_tree_wrap_leaf( +// transform_visitor.leaf_func(t)); +// }, +// }; +// +// return visit(tree, visitor); +// } +// +// } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 2688c1dd55..89bb45f0fb 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -1,39 +1,33 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" namespace FlexFlow { -template -Result visit(GenericBinarySPDecompositionTree 
const &tt, - F f) { - SPDecompositionTreeNodeType node_type = get_node_type(tt); +template +ReturnType visit( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + GenericBinarySPDecompositionTreeVisitor const &visitor) { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); switch (node_type) { case SPDecompositionTreeNodeType::SERIES: { - Result result = f(require_generic_binary_series_split(tt)); + ReturnType result = visitor.series_func(impl.require_series(tree)); return result; } case SPDecompositionTreeNodeType::PARALLEL: { - Result result = f(require_generic_binary_parallel_split(tt)); + ReturnType result = visitor.parallel_func(impl.require_parallel(tree)); return result; } case SPDecompositionTreeNodeType::NODE: { - Result result = f(require_generic_binary_leaf(tt)); + ReturnType result = visitor.leaf_func(impl.require_leaf(tree)); return result; } default: - throw mk_runtime_error( - fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); + throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType value: {}", node_type)); } } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h deleted file mode 100644 index ba9c5c496c..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_WRAP_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_WRAP_H - -#include "utils/full_binary_tree/make.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - wrap_series_split( - GenericBinarySeriesSplit const - &series_split) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - /*label=*/make_generic_binary_series_split_label( - series_split.label), - /*lhs=*/series_split.pre.raw_tree, - /*rhs=*/series_split.post.raw_tree), - }; -} - -template -GenericBinarySPDecompositionTree - wrap_parallel_split( - GenericBinaryParallelSplit const - &parallel_split) { - return GenericBinarySPDecompositionTree{ - make_full_binary_tree_parent( - /*label=*/make_generic_binary_parallel_split_label( - parallel_split.label), - /*lhs=*/parallel_split.lhs.raw_tree, - /*rhs=*/parallel_split.rhs.raw_tree), - }; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h index 9a8a744771..738b4c013d 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h @@ -1,17 +1,17 @@ #ifndef
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -std::unordered_multiset - get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { - return get_leaves(t.raw_tree); -} - -} // namespace FlexFlow +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +// +// namespace FlexFlow { +// +// template +// std::unordered_multiset +// get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { +// return get_leaves(t.raw_tree); +// } +// +// } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h index f83103b4de..78dd8990c7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h @@ -1,17 +1,17 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H #define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -SPDecompositionTreeNodeType - get_node_type(LeafOnlyBinarySPDecompositionTree const &tree) { - return get_node_type(tree.raw_tree); -} - -} // namespace FlexFlow +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +// +// namespace FlexFlow { +// +// template +// SPDecompositionTreeNodeType +// get_node_type(LeafOnlyBinarySPDecompositionTree const &tree) { +// return get_node_type(tree.raw_tree); +// } +// +// } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 7d6242030a..df365360ca 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,17 +1,17 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -bool is_binary_sp_tree_left_associative( - LeafOnlyBinarySPDecompositionTree const &t) { - return is_binary_sp_tree_left_associative(t.raw_tree); -} - -} // namespace FlexFlow +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +// +// namespace FlexFlow { +// +// template +// bool is_binary_sp_tree_left_associative( +// LeafOnlyBinarySPDecompositionTree const &t) { +// return is_binary_sp_tree_left_associative(t.raw_tree); +// } +// +// } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index 8fbc6d38a0..8eb09f3654 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,17 +1,17 @@ #ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -bool is_binary_sp_tree_right_associative( - LeafOnlyBinarySPDecompositionTree const &t) { - return is_binary_sp_tree_right_associative(t.raw_tree); -} - -} // namespace FlexFlow +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +// +// namespace FlexFlow { +// +// template +// bool is_binary_sp_tree_right_associative( +// LeafOnlyBinarySPDecompositionTree const &t) { +// return is_binary_sp_tree_right_associative(t.raw_tree); +// } +// +// } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h new file mode 100644 index 0000000000..b5c1ed3f95 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h @@ -0,0 +1,78 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_JSON_H 
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_JSON_H + +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/visit.h" +// #include "utils/json/check_is_json_deserializable.h" +// #include "utils/json/check_is_json_serializable.h" +// #include "utils/fmt/json.h" +// #include + +namespace nlohmann { + +// template +// struct adl_serializer<::FlexFlow::LeafOnlyBinarySPDecompositionTree> { +// static ::FlexFlow::LeafOnlyBinarySPDecompositionTree from_json(json const &j) { +// CHECK_IS_JSON_SERIALIZABLE(LeafLabel); +// +// using namespace ::FlexFlow; +// +// using Tree = LeafOnlyBinarySPDecompositionTree; +// +// std::string type = j.at("type").get(); +// +// if (type == "series") { +// return leaf_only_binary_sp_tree_wrap_series_split( +// LeafOnlyBinarySeriesSplit{ +// /*lhs=*/j.at("left_child").get(), +// /*rhs=*/j.at("right_child").get(), +// }); +// } else if (type == "parallel") { +// return leaf_only_binary_sp_tree_wrap_parallel_split( +// LeafOnlyBinaryParallelSplit{ +// /*lhs=*/j.at("left_child").get(), +// /*rhs=*/j.at("right_child").get(), +// }); +// } else if (type == "leaf") { +// return leaf_only_binary_sp_tree_wrap_leaf(j.at("label").get()); +// } else { +// throw mk_runtime_error(fmt::format("Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" in json object: {}", type, j)); +// } +// } +// +// static void to_json(json &j, 
::FlexFlow::LeafOnlyBinarySPDecompositionTree const &tree) { +// CHECK_IS_JSON_DESERIALIZABLE(LeafLabel); +// +// using namespace FlexFlow; +// +// using Tree = LeafOnlyBinarySPDecompositionTree; +// +// auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ +// /*series_func=*/[&](LeafOnlyBinarySeriesSplit const &split) { +// j["type"] = "series"; +// j["left_child"] = split.lhs; +// j["right_child"] = split.rhs; +// return std::monostate{}; +// }, +// /*parallel_func=*/[&](LeafOnlyBinaryParallelSplit const &split) { +// j["type"] = "parallel"; +// j["left_child"] = split.lhs; +// j["right_child"] = split.rhs; +// return std::monostate{}; +// }, +// /*leaf_func=*/[&](LeafLabel const &leaf_label) { +// j["type"] = "leaf"; +// j["label"] = leaf_label; +// return std::monostate{}; +// }, +// }; +// +// visit(tree, visitor); +// } +// }; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h deleted file mode 100644 index 81fbe0c1fa..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_PARALLEL_SPLIT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_PARALLEL_SPLIT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -LeafOnlyBinarySPDecompositionTree - get_left_child(LeafOnlyBinaryParallelSplit const &s) { - return s.lhs; -} - -template -LeafOnlyBinarySPDecompositionTree - get_right_child(LeafOnlyBinaryParallelSplit const &s) { - return s.rhs; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml index d5579fd58c..802ee854ab 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml @@ -6,18 +6,25 @@ features = [ "fmt", ] -template_params = [ - "LeafLabel", +fwd_decls = [ + "template struct LeafOnlyBinarySPDecompositionTree", ] -includes = [ +post_includes = [ "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] + +template_params = [ + "LeafLabel", +] + [[fields]] -name = "lhs" +name = "left_child" type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" +indirect = true [[fields]] -name = "rhs" +name = "right_child" type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h deleted file mode 100644 index d95e741516..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SERIES_SPLIT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SERIES_SPLIT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -LeafOnlyBinarySPDecompositionTree - get_left_child(LeafOnlyBinarySeriesSplit const &s) { - return s.pre; -} - -template -LeafOnlyBinarySPDecompositionTree - get_right_child(LeafOnlyBinarySeriesSplit const &s) { - return s.post; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml index a7ff2dcc70..95a647aea3 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml @@ -6,18 +6,24 @@ features = [ "fmt", ] -template_params = [ - "LeafLabel", +fwd_decls = [ + "template struct LeafOnlyBinarySPDecompositionTree", ] -includes = [ +post_includes = [ "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] +template_params = [ + "LeafLabel", +] + [[fields]] -name = "pre" +name = "left_child" type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" +indirect = true [[fields]] -name = "post" +name = "right_child" type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml deleted file mode 100644 index bf52ecc6df..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySPDecompositionTree" -features = [ - "eq", - "hash", - "fmt" -] - -template_params = [ - "LeafLabel", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", - "", -] - -src_includes = [ - "utils/fmt/monostate.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree" diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml new file mode 100644 index 0000000000..fdab932039 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt" +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", + "", +] + +[[values]] +type = "::FlexFlow::LeafOnlyBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::LeafOnlyBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "LeafLabel" +key = "leaf" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml new file mode 100644 index 0000000000..b94972c323 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySPDecompositionTreeTransformVisitor" +features = [] 
+ +template_params = [ + "LeafLabel", + "LeafLabel2", +] + +includes = [ + "", +] + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h new file mode 100644 index 0000000000..f1dbcb97af --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_VISITOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_VISITOR_H + +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" + +namespace FlexFlow { + +// template +// GenericBinarySPDecompositionTreeVisitor +// generic_visitor_from_leaf_only_visitor(LeafOnlyBinarySPDecompositionTreeVisitor const &leaf_only) { +// return GenericBinarySPDecompositionTreeVisitor{ +// [leaf_only](GenericBinarySeriesSplit const &split) { +// return leaf_only.series_func( +// LeafOnlyBinarySeriesSplit{ +// LeafOnlyBinarySPDecompositionTree{split.lhs}, +// LeafOnlyBinarySPDecompositionTree{split.rhs}, +// }); +// }, +// [leaf_only](GenericBinaryParallelSplit const &split) { +// return leaf_only.parallel_func( +// 
LeafOnlyBinaryParallelSplit{ +// LeafOnlyBinarySPDecompositionTree{split.lhs}, +// LeafOnlyBinarySPDecompositionTree{split.rhs}, +// }); +// }, +// [leaf_only](LeafLabel const &leaf_label) { +// return leaf_only.leaf_func(leaf_label); +// }, +// }; +// } + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml index 27203b8b05..7174afcaa7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml @@ -3,14 +3,24 @@ name = "LeafOnlyBinarySPDecompositionTreeVisitor" features = [] template_params = [ + "ReturnType", "LeafLabel", - "LeafLabel2", ] includes = [ "", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", ] +[[fields]] +name = "series_func" +type = "std::function const &)>" + +[[fields]] +name = "parallel_func" +type = "std::function const &)>" + [[fields]] name = "leaf_func" -type = "std::function" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h deleted file mode 100644 index 
c82a4560ae..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_MAKE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -LeafOnlyBinarySPDecompositionTree - leaf_only_binary_sp_tree_make_series_split( - LeafOnlyBinarySPDecompositionTree const &pre, - LeafOnlyBinarySPDecompositionTree const &post) { - return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_series_split( - std::monostate{}, pre.raw_tree, post.raw_tree), - }; -} - -template -LeafOnlyBinarySPDecompositionTree - leaf_only_binary_sp_tree_make_parallel_split( - LeafOnlyBinarySPDecompositionTree const &lhs, - LeafOnlyBinarySPDecompositionTree const &rhs) { - return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_parallel_split( - std::monostate{}, lhs.raw_tree, rhs.raw_tree), - }; -} - -template -LeafOnlyBinarySPDecompositionTree - leaf_only_binary_sp_tree_make_leaf(LeafLabel const &label) { - return LeafOnlyBinarySPDecompositionTree{ - make_generic_binary_sp_leaf( - label), - }; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h index 77d7c2fd8d..e83120a0f9 100644 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -1,45 +1,45 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +// #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { -template -LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( - LeafOnlyBinarySPDecompositionTree const &t) { - GenericBinarySeriesSplit raw = - require_generic_binary_series_split(t.raw_tree); - - return LeafOnlyBinarySeriesSplit{ - LeafOnlyBinarySPDecompositionTree{raw.pre}, - LeafOnlyBinarySPDecompositionTree{raw.post}, - }; -} - -template -LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split( - LeafOnlyBinarySPDecompositionTree const &t) { - GenericBinaryParallelSplit raw = - require_generic_binary_parallel_split(t.raw_tree); - - return LeafOnlyBinaryParallelSplit{ - LeafOnlyBinarySPDecompositionTree{raw.lhs}, - LeafOnlyBinarySPDecompositionTree{raw.rhs}, - }; -} - -template -LeafLabel require_leaf_only_binary_leaf( - LeafOnlyBinarySPDecompositionTree const &t) { - return require_generic_binary_leaf(t.raw_tree); -} - +// template +// LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( +// LeafOnlyBinarySPDecompositionTree const &t) { +// GenericBinarySeriesSplit raw = +// require_generic_binary_series_split(t.raw_tree); +// +// return LeafOnlyBinarySeriesSplit{ +// LeafOnlyBinarySPDecompositionTree{raw.lhs}, +// LeafOnlyBinarySPDecompositionTree{raw.rhs}, +// }; +// } +// +// template +// LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split( +// LeafOnlyBinarySPDecompositionTree const &t) { +// GenericBinaryParallelSplit raw = +// require_generic_binary_parallel_split(t.raw_tree); +// +// return LeafOnlyBinaryParallelSplit{ +// LeafOnlyBinarySPDecompositionTree{raw.lhs}, +// LeafOnlyBinarySPDecompositionTree{raw.rhs}, +// }; +// } +// +// template +// LeafLabel require_leaf_only_binary_leaf( +// 
LeafOnlyBinarySPDecompositionTree const &t) { +// return require_generic_binary_leaf(t.raw_tree); +// } +// } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h index a18ce37899..58a60c91ba 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -1,57 +1,52 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +// #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" namespace FlexFlow { -template -LeafOnlyBinarySeriesSplit transform( - LeafOnlyBinarySeriesSplit const &t, - LeafOnlyBinarySPDecompositionTreeVisitor const - &visitor) { - return LeafOnlyBinarySeriesSplit{ - transform(t.pre, visitor), - transform(t.post, visitor), - }; -} - -template -LeafOnlyBinaryParallelSplit transform( - LeafOnlyBinaryParallelSplit const &t, - LeafOnlyBinarySPDecompositionTreeVisitor const - &visitor) { - return LeafOnlyBinaryParallelSplit{ - transform(t.lhs, visitor), - transform(t.rhs, visitor), - }; -} - -template -LeafOnlyBinarySPDecompositionTree transform( - LeafOnlyBinarySPDecompositionTree const &t, - LeafOnlyBinarySPDecompositionTreeVisitor const - &visitor) { - using GenericVisitor = GenericBinarySPDecompositionTreeVisitor; - - GenericVisitor generic_visitor = GenericVisitor{ - [&](std::monostate const &x) { return x; }, - [&](std::monostate const &x) { return x; }, - [&](LeafLabel const &t) { return visitor.leaf_func(t); }, - }; - - return LeafOnlyBinarySPDecompositionTree{ - transform(t.raw_tree, generic_visitor), - }; -} +// template +// LeafOnlyBinarySeriesSplit transform( +// LeafOnlyBinarySeriesSplit const &t, +// LeafOnlyBinarySPDecompositionTreeVisitor const +// &visitor) { +// return transform(leaf_only_binary_sp_tree_wrap_series_split(t), visitor); +// } +// +// template +// LeafOnlyBinaryParallelSplit transform( +// LeafOnlyBinaryParallelSplit const &t, +// LeafOnlyBinarySPDecompositionTreeVisitor const +// &visitor) { +// return transform(leaf_only_binary_sp_tree_wrap_parallel_split(t), visitor); +// } +// +// template +// LeafOnlyBinarySPDecompositionTree transform( +// LeafOnlyBinarySPDecompositionTree const &t, +// 
LeafOnlyBinarySPDecompositionTreeVisitor const +// &visitor) { +// using GenericVisitor = GenericBinarySPDecompositionTreeTransformVisitor; +// +// GenericVisitor generic_visitor = GenericVisitor{ +// [&](std::monostate const &x) { return x; }, +// [&](std::monostate const &x) { return x; }, +// [&](LeafLabel const &t) { return visitor.leaf_func(t); }, +// }; +// +// return LeafOnlyBinarySPDecompositionTree{ +// transform(t.raw_tree, generic_visitor), +// }; +// } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h deleted file mode 100644 index e13cea0fdb..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_WRAP_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_WRAP_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template 
-LeafOnlyBinarySPDecompositionTree - wrap_series_split(LeafOnlyBinarySeriesSplit const &split) { - return LeafOnlyBinarySPDecompositionTree{ - wrap_series_split( - GenericBinarySeriesSplit{ - std::monostate{}, - split.pre.raw_tree, - split.post.raw_tree, - }), - }; -} - -template -LeafOnlyBinarySPDecompositionTree - wrap_parallel_split(LeafOnlyBinaryParallelSplit const &split) { - return LeafOnlyBinarySPDecompositionTree{ - wrap_parallel_split( - GenericBinaryParallelSplit{ - std::monostate{}, - split.lhs.raw_tree, - split.rhs.raw_tree, - }), - }; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml new file mode 100644 index 0000000000..dd68adf3f6 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct SeriesSplit" +] + +post_includes = [ + "utils/graph/series_parallel/series_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "children" +type = "std::unordered_multiset>" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 18434d2b67..98eb913aeb 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,79 +1,75 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#include "utils/graph/node/node.dtg.h" 
-#include -#include +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" namespace FlexFlow { -struct SeriesSplit; -struct ParallelSplit; - -struct SeriesSplit { -public: - SeriesSplit() = delete; - explicit SeriesSplit(std::vector> const &); - explicit SeriesSplit( - std::initializer_list> const &); - - bool operator==(SeriesSplit const &) const; - bool operator!=(SeriesSplit const &) const; - -public: - std::vector> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(SeriesSplit const &); -std::ostream &operator<<(std::ostream &, SeriesSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::SeriesSplit> { - size_t operator()(::FlexFlow::SeriesSplit const &) const; -}; - -} // namespace std - -namespace FlexFlow { - -struct ParallelSplit { -public: - ParallelSplit() = delete; - explicit ParallelSplit( - std::unordered_multiset> const &); - explicit ParallelSplit( - std::initializer_list> const &); - - bool operator==(ParallelSplit const &) const; - bool operator!=(ParallelSplit const &) const; - -public: - std::unordered_multiset> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(ParallelSplit const &); -std::ostream &operator<<(std::ostream &, ParallelSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::ParallelSplit> { - size_t operator()(::FlexFlow::ParallelSplit const &) const; -}; +// struct SeriesSplit { +// public: +// SeriesSplit() = delete; +// explicit SeriesSplit(std::vector> const &); +// explicit SeriesSplit( +// std::initializer_list> const &); +// +// bool operator==(SeriesSplit const &) const; +// bool operator!=(SeriesSplit const &) const; +// +// public: +// std::vector> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(SeriesSplit const 
&); +// std::ostream &operator<<(std::ostream &, SeriesSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::SeriesSplit> { +// size_t operator()(::FlexFlow::SeriesSplit const &) const; +// }; +// +// } // namespace std +// +// namespace FlexFlow { +// +// struct ParallelSplit { +// public: +// ParallelSplit() = delete; +// explicit ParallelSplit( +// std::unordered_multiset> const &); +// explicit ParallelSplit( +// std::initializer_list> const &); +// +// bool operator==(ParallelSplit const &) const; +// bool operator!=(ParallelSplit const &) const; +// +// public: +// std::unordered_multiset> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(ParallelSplit const &); +// std::ostream &operator<<(std::ostream &, ParallelSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::ParallelSplit> { +// size_t operator()(::FlexFlow::ParallelSplit const &) const; +// }; } // namespace std diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml new file mode 100644 index 0000000000..fdb0a29972 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "SeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ParallelSplit" +] + +post_includes = [ + "utils/graph/series_parallel/parallel_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h b/lib/utils/include/utils/json/check_is_json_deserializable.h new file mode 100644 index 
0000000000..f72485dcbd --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H + +#include "utils/json/is_json_deserializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h new file mode 100644 index 0000000000..f3d1a058f8 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H + +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_serializable::value, \ + #TYPENAME " should be json serializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/archetypes/value_type.cc b/lib/utils/src/utils/archetypes/value_type.cc new file mode 100644 index 0000000000..9c197112a1 --- /dev/null +++ b/lib/utils/src/utils/archetypes/value_type.cc @@ -0,0 +1,8 @@ +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template + struct value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/json.cc b/lib/utils/src/utils/fmt/json.cc new file mode 100644 index 0000000000..783b75973c --- /dev/null +++ b/lib/utils/src/utils/fmt/json.cc @@ -0,0 +1,8 @@ +#include "utils/fmt/json.h" + +namespace fmt { + +template + struct formatter<::nlohmann::json, char>; + +} diff --git a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc 
b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..b3ddab6cbc --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc @@ -0,0 +1,14 @@ +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template + std::unordered_set + find_paths_to_leaf(Tree const &, FullBinaryTreeImplementation const &, Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/fmt.cc b/lib/utils/src/utils/full_binary_tree/fmt.cc deleted file mode 100644 index 9e4d328be3..0000000000 --- a/lib/utils/src/utils/full_binary_tree/fmt.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/full_binary_tree/fmt.h" - -namespace FlexFlow { - -template std::string format_as(FullBinaryTreeParentNode const &); -template std::string format_as(FullBinaryTree const &); -template std::ostream &operator<<(std::ostream &, - FullBinaryTreeParentNode const &); -template std::ostream &operator<<(std::ostream &, - FullBinaryTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..cbbffb0b4a --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc @@ -0,0 +1,11 @@ +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template + std::unordered_set + get_all_leaf_paths(value_type<0> const &, + FullBinaryTreeImplementation, value_type<1>, value_type<2>> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_child.cc b/lib/utils/src/utils/full_binary_tree/get_child.cc new file mode 100644 index 0000000000..3283db398b --- /dev/null +++ 
b/lib/utils/src/utils/full_binary_tree/get_child.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_child.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template + Tree get_child(Parent const &, + FullBinaryTreeImplementation const &, + BinaryTreePathEntry const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_label.cc b/lib/utils/src/utils/full_binary_tree/get_label.cc deleted file mode 100644 index 1270dcbc9d..0000000000 --- a/lib/utils/src/utils/full_binary_tree/get_label.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/full_binary_tree/get_label.h" - -namespace FlexFlow { - -template int get_full_binary_tree_parent_label( - FullBinaryTreeParentNode const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_leaves.cc b/lib/utils/src/utils/full_binary_tree/get_leaves.cc new file mode 100644 index 0000000000..18221cd98a --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_leaves.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template + std::unordered_multiset + get_leaves(Tree const &, + FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..b651309c32 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc @@ -0,0 +1,14 @@ +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template + int get_num_tree_nodes(Tree const &, + 
FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..689237752a --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc @@ -0,0 +1,16 @@ +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template + std::optional + get_subtree_at_path(Tree const &, + FullBinaryTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/make.cc b/lib/utils/src/utils/full_binary_tree/make.cc deleted file mode 100644 index 8de1e60eb7..0000000000 --- a/lib/utils/src/utils/full_binary_tree/make.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "utils/full_binary_tree/make.h" - -namespace FlexFlow { - -template FullBinaryTree - make_full_binary_tree_parent(int const &, - FullBinaryTree const &, - FullBinaryTree const &); -template FullBinaryTree make_full_binary_tree_leaf(int const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc index 1b75630269..c8a36dff66 100644 --- a/lib/utils/src/utils/full_binary_tree/visit.cc +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -2,7 +2,9 @@ namespace FlexFlow { -template int visit(FullBinaryTree const &, - FullBinaryTreeVisitor const &); +template + int visit(std::string const &, + FullBinaryTreeImplementation const &, + FullBinaryTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc 
b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc index fd2154d6c0..8a4adf0b3a 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -2,7 +2,6 @@ #include "utils/containers/flatmap.h" #include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" #include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" namespace FlexFlow { @@ -11,9 +10,9 @@ std::unordered_set get_transitive_reduced_edges_across_split( TransitiveReducedDataflowGraphView const &tr_g, BinarySeriesSplit const &split) { std::unordered_set src_subgraph = - unordered_set_of(get_leaves(get_left_child(split))); + unordered_set_of(get_leaves(split.get_left_child())); std::unordered_set dst_subgraph = - unordered_set_of(get_leaves(get_right_child(split))); + unordered_set_of(get_leaves(split.get_right_child())); std::unordered_set raw_edges = get_edges_from_subgraph_to_subgraph( diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc deleted file mode 100644 index 248522a3a3..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h" -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" - -namespace FlexFlow { - -BinarySPDecompositionTree get_left_child(BinaryParallelSplit const &s) { - return BinarySPDecompositionTree{ - get_left_child(s.raw_split), - }; -} - -BinarySPDecompositionTree get_right_child(BinaryParallelSplit const &s) { - return BinarySPDecompositionTree{ - get_right_child(s.raw_split), - }; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc deleted file mode 100644 index 1e80bd68b6..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" - -namespace FlexFlow { - -BinarySPDecompositionTree get_left_child(BinarySeriesSplit const &split) { - return BinarySPDecompositionTree{ - get_left_child(split.raw_split), - }; -} - -BinarySPDecompositionTree get_right_child(BinarySeriesSplit const &split) { - return BinarySPDecompositionTree{ - get_right_child(split.raw_split), - }; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 79042fd061..be7115ee9f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,65 +1,28 @@ #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" namespace FlexFlow { -BinarySPDecompositionTree - make_series_split(BinarySPDecompositionTree const &lhs, - BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{ - leaf_only_binary_sp_tree_make_series_split(lhs.raw_tree, rhs.raw_tree), - }; +GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> generic_impl_for_binary_sp_tree() { + NOT_IMPLEMENTED(); } -BinarySPDecompositionTree - make_parallel_split(BinarySPDecompositionTree const &lhs, - BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{ - leaf_only_binary_sp_tree_make_parallel_split(lhs.raw_tree, rhs.raw_tree), - }; +bool 
is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_left_associative(tree, generic_impl_for_binary_sp_tree()); } -BinarySPDecompositionTree make_leaf_node(Node const &n) { - return BinarySPDecompositionTree{ - leaf_only_binary_sp_tree_make_leaf(n), - }; +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_right_associative(tree, generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tt) { - return is_binary_sp_tree_left_associative(tt.raw_tree); -} - -bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tt) { - return is_binary_sp_tree_right_associative(tt.raw_tree); -} - -std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { - return get_leaves(tt.raw_tree); -} - -BinarySeriesSplit require_series(BinarySPDecompositionTree const &tt) { - return BinarySeriesSplit{ - require_leaf_only_binary_series_split(tt.raw_tree), - }; -} - -BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &tt) { - return BinaryParallelSplit{ - require_leaf_only_binary_parallel_split(tt.raw_tree), - }; -} - -Node require_leaf(BinarySPDecompositionTree const &tt) { - return require_leaf_only_binary_leaf(tt.raw_tree); -} - -SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tt) { - return get_node_type(tt.raw_tree); +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tree) { + return get_leaves(tree, generic_impl_for_binary_sp_tree()); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc index d14dd7641c..e30b9f97a6 100644 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -1,9 +1,17 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template std::unordered_set - find_paths_to_leaf(GenericBinarySPDecompositionTree const &, - int const &); +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template + std::unordered_set + find_paths_to_leaf(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + Leaf const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc new file mode 100644 index 0000000000..bc6b4b1ccf --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc @@ -0,0 +1,14 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl(GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc deleted file mode 100644 index 16ca73e01d..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" - -namespace FlexFlow { - -template SPDecompositionTreeNodeType - get_node_type(GenericBinarySPSplitLabel const &); -template GenericBinarySPSplitLabel - make_generic_binary_series_split_label(int const &); -template GenericBinarySPSplitLabel - make_generic_binary_parallel_split_label(int const &); -template int require_generic_binary_series_split_label( - GenericBinarySPSplitLabel const &); -template int require_generic_binary_parallel_split_label( - GenericBinarySPSplitLabel const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc index 970f401584..7bc9c4bfe4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -1,8 +1,16 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template std::unordered_set - 
get_all_leaf_paths(GenericBinarySPDecompositionTree const &); +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template + std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index ccbd4a8c10..6c80f4ba9b 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1,12 +1,16 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &); -template std::unordered_multiset - get_leaves(GenericBinarySeriesSplit const &); -template std::unordered_multiset - get_leaves(GenericBinaryParallelSplit const &); +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template + std::unordered_multiset + get_leaves(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc deleted file mode 100644 index 697fb417d4..0000000000 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" - -namespace FlexFlow { - -template GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &); -template GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 694d981733..89e8deb437 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1,12 +1,15 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template int - get_num_tree_nodes(GenericBinarySPDecompositionTree const &); -template int - get_num_tree_nodes(GenericBinarySeriesSplit const &); -template int - get_num_tree_nodes(GenericBinaryParallelSplit const &); +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template + int get_num_tree_nodes(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc deleted file mode 100644 index ec56627455..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" - -namespace FlexFlow { - -template GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &); -template GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc index ac5509045a..e95284fa5e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -1,9 +1,17 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template std::optional> - get_subtree_at_path(GenericBinarySPDecompositionTree const &, - BinaryTreePath const &); +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template + std::optional + get_subtree_at_path(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + BinaryTreePath const &); } // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 18bc1f3030..2b478edb20 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1,8 +1,16 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &); +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template + bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index fc6a5ee041..e50a861219 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1,8 +1,16 @@ #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &); +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template + bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc deleted file mode 100644 index 27219bd4d8..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" - -namespace FlexFlow { - -template GenericBinarySPDecompositionTree - make_generic_binary_series_split( - int const &, - GenericBinarySPDecompositionTree const &, - GenericBinarySPDecompositionTree const &); -template GenericBinarySPDecompositionTree - make_generic_binary_parallel_split( - int const &label, - GenericBinarySPDecompositionTree const &, - GenericBinarySPDecompositionTree const &); -template GenericBinarySPDecompositionTree - make_generic_binary_sp_leaf(int const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc index 10029ceedd..9a3fa879d4 100644 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -1,14 +1,27 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/archetypes/value_type.h" namespace FlexFlow { -template GenericBinarySeriesSplit - require_generic_binary_series_split( - GenericBinarySPDecompositionTree const &); -template GenericBinaryParallelSplit - require_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &); -template int require_generic_binary_leaf( - GenericBinarySPDecompositionTree const &); +using SeriesLabel = value_type<0>; +using ParallelLabel = value_type<1>; +using LeafLabel = value_type<2>; +template + GenericBinarySeriesSplit + require_generic_binary_series_split( + GenericBinarySPDecompositionTree const &); +template + GenericBinaryParallelSplit + require_generic_binary_parallel_split( + GenericBinarySPDecompositionTree const &); +template + LeafLabel require_generic_binary_leaf( + GenericBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc index 3193f8828c..bdb59887a1 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -2,13 +2,13 @@ namespace FlexFlow { -template GenericBinarySeriesSplit - transform(GenericBinarySeriesSplit const &, - GenericBinarySPDecompositionTreeVisitor const &); +// template 
GenericBinarySeriesSplit +// transform(GenericBinarySeriesSplit const &, +// GenericBinarySPDecompositionTreeTransformVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc index 25409333f2..b7175e0e1b 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using ReturnType = value_type<0>; +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template + ReturnType visit( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + GenericBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc deleted file mode 100644 index 007f1dbd52..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" - -namespace FlexFlow { - -template GenericBinarySPDecompositionTree - wrap_series_split(GenericBinarySeriesSplit const &); -template GenericBinarySPDecompositionTree - 
wrap_parallel_split(GenericBinaryParallelSplit const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc index 61e5c8a9fa..79aac82e12 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template std::unordered_multiset - get_leaves(LeafOnlyBinarySPDecompositionTree const &); +// template std::unordered_multiset +// get_leaves(LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc index 90fed4010d..9f5516fbb3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template SPDecompositionTreeNodeType - get_node_type(LeafOnlyBinarySPDecompositionTree const &); +// template SPDecompositionTreeNodeType +// get_node_type(LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 9e00926b58..393189f092 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template bool is_binary_sp_tree_left_associative( - LeafOnlyBinarySPDecompositionTree const &); +// template bool is_binary_sp_tree_left_associative( +// LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index bec3410841..0a7724b13a 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -template bool is_binary_sp_tree_right_associative( - LeafOnlyBinarySPDecompositionTree const &); +// template bool is_binary_sp_tree_right_associative( +// LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..6d8ead4165 --- /dev/null +++ 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h" + +namespace nlohmann { + +// template +// struct adl_serializer<::FlexFlow::LeafOnlyBinarySPDecompositionTree>; + +} // namespace nlohmann diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc deleted file mode 100644 index 62e21510a8..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" - -namespace FlexFlow { - -template LeafOnlyBinarySPDecompositionTree - get_left_child(LeafOnlyBinaryParallelSplit const &); -template LeafOnlyBinarySPDecompositionTree - get_right_child(LeafOnlyBinaryParallelSplit const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc deleted file mode 100644 index efb5d779a8..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" - -namespace FlexFlow { - -template LeafOnlyBinarySPDecompositionTree - get_left_child(LeafOnlyBinarySeriesSplit const &); -template LeafOnlyBinarySPDecompositionTree - get_right_child(LeafOnlyBinarySeriesSplit const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc new file mode 100644 index 0000000000..fe671a8e8f --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h" + +namespace FlexFlow { + +// template +// GenericBinarySPDecompositionTreeVisitor +// generic_visitor_from_leaf_only_visitor(LeafOnlyBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc deleted file mode 100644 index 07ae9e604a..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" - -namespace FlexFlow { - -template LeafOnlyBinarySPDecompositionTree - 
leaf_only_binary_sp_tree_make_series_split( - LeafOnlyBinarySPDecompositionTree const &, - LeafOnlyBinarySPDecompositionTree const &); -template LeafOnlyBinarySPDecompositionTree - leaf_only_binary_sp_tree_make_parallel_split( - LeafOnlyBinarySPDecompositionTree const &, - LeafOnlyBinarySPDecompositionTree const &); -template LeafOnlyBinarySPDecompositionTree - leaf_only_binary_sp_tree_make_leaf(int const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc index 75c568fa4a..cefb44d0c4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc @@ -2,12 +2,12 @@ namespace FlexFlow { -template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( - LeafOnlyBinarySPDecompositionTree const &); -template LeafOnlyBinaryParallelSplit - require_leaf_only_binary_parallel_split( - LeafOnlyBinarySPDecompositionTree const &); -template int require_leaf_only_binary_leaf( - LeafOnlyBinarySPDecompositionTree const &); +// template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( +// LeafOnlyBinarySPDecompositionTree const &); +// template LeafOnlyBinaryParallelSplit +// require_leaf_only_binary_parallel_split( +// LeafOnlyBinarySPDecompositionTree const &); +// template int require_leaf_only_binary_leaf( +// LeafOnlyBinarySPDecompositionTree const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc index c7fb0811df..2421ffdc43 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc @@ -2,15 +2,14 @@ namespace FlexFlow { -template LeafOnlyBinarySeriesSplit transform( - LeafOnlyBinarySeriesSplit const &, - LeafOnlyBinarySPDecompositionTreeVisitor const &); -template LeafOnlyBinaryParallelSplit transform( - LeafOnlyBinaryParallelSplit const &, - LeafOnlyBinarySPDecompositionTreeVisitor const &); - -template LeafOnlyBinarySPDecompositionTree transform( - LeafOnlyBinarySPDecompositionTree const &, - LeafOnlyBinarySPDecompositionTreeVisitor const &); - +// template LeafOnlyBinarySeriesSplit transform( LeafOnlyBinarySeriesSplit const &, +// LeafOnlyBinarySPDecompositionTreeVisitor const &); +// template LeafOnlyBinaryParallelSplit transform( +// LeafOnlyBinaryParallelSplit const &, +// LeafOnlyBinarySPDecompositionTreeVisitor const &); +// +// template LeafOnlyBinarySPDecompositionTree transform( +// LeafOnlyBinarySPDecompositionTree const &, +// LeafOnlyBinarySPDecompositionTreeVisitor const &); +// } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc deleted file mode 100644 index 5d417271b4..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" - -namespace FlexFlow { - -template 
LeafOnlyBinarySPDecompositionTree - wrap_series_split(LeafOnlyBinarySeriesSplit const &); -template LeafOnlyBinarySPDecompositionTree - wrap_parallel_split(LeafOnlyBinaryParallelSplit const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index feb4749d0c..f40acce2ee 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -18,7 +18,7 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( from_parallel_child; auto from_node = [](Node const &n) -> BinarySPDecompositionTree { - return make_leaf_node(n); + return BinarySPDecompositionTree{n}; }; auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { @@ -26,19 +26,23 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( transform(s.children, from_series_child); return foldl1(children, [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) { - return make_series_split(accum, x); + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{accum, x}, + }; }); }; auto from_parallel = [&](ParallelSplit const &s) -> BinarySPDecompositionTree { std::vector children = - transform(vector_of(s.children), from_parallel_child); + transform(vector_of(s.get_children()), from_parallel_child); return foldl1(children, [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) { - return make_parallel_split(accum, x); + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{accum, x}, 
+ }; }); }; diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index a4f6000900..2477140d71 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,25 +16,29 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( std::variant const &)> from_parallel_child; - auto from_node = [](Node const &n) { return make_leaf_node(n); }; + auto from_node = [](Node const &n) { return BinarySPDecompositionTree{n}; }; auto from_series = [&](SeriesSplit const &s) { std::vector children = transform(s.children, from_series_child); return foldr1(children, [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) { - return make_series_split(x, accum); + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{x, accum}, + }; }); }; auto from_parallel = [&](ParallelSplit const &s) { std::vector children = - transform(vector_of(s.children), from_parallel_child); + transform(vector_of(s.get_children()), from_parallel_child); return foldr1(children, [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) { - return make_parallel_split(x, accum); + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{x, accum}, + }; }); }; diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index ab231f256c..84ef2fc106 100644 --- 
a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -36,7 +36,7 @@ std::optional ttsp_edge_to_sp_tree = map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict .as_unordered_map(), - [](Node const &n) { return make_leaf_node(n); }); + [](Node const &n) { return BinarySPDecompositionTree{n}; }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -46,8 +46,12 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - BinarySPDecompositionTree new_tree = make_parallel_split( - ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -62,8 +66,12 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - BinarySPDecompositionTree new_tree = make_series_split( - ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 0ad586d499..07df693ae1 100644 --- 
a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,7 +1,5 @@ #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/containers/extend.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" @@ -51,16 +49,14 @@ std::variant flatten_ast( std::variant from_binary_sp_tree(BinarySPDecompositionTree const &binary) { - return visit>( - binary, - overload{ + return binary.template visit>(overload{ [](Node const &n) { return n; }, [](BinarySeriesSplit const &s) { return IntermediateSpDecompositionTree{ SplitType::SERIES, { - from_binary_sp_tree(get_left_child(s)), - from_binary_sp_tree(get_right_child(s)), + from_binary_sp_tree(s.get_left_child()), + from_binary_sp_tree(s.get_right_child()), }, }; }, @@ -68,8 +64,8 @@ std::variant return IntermediateSpDecompositionTree{ SplitType::PARALLEL, { - from_binary_sp_tree(get_left_child(p)), - from_binary_sp_tree(get_right_child(p)), + from_binary_sp_tree(p.get_left_child()), + from_binary_sp_tree(p.get_right_child()), }, }; }, diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index e697533054..b7a84b871a 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -64,7 +64,7 @@ std::unordered_multiset get_nodes(SeriesSplit const &serial) { std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { return multiset_union(transform( - vector_of(parallel.children), + 
vector_of(parallel.get_children()), [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc index 0e04a4f904..7d36371e49 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc @@ -1,85 +1,85 @@ #include "utils/graph/series_parallel/series_parallel_splits.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" -#include "utils/hash-utils.h" -#include "utils/hash/unordered_multiset.h" -#include "utils/hash/vector.h" - -namespace FlexFlow { - -SeriesSplit::SeriesSplit( - std::vector> const &children) - : children(children) {} - -SeriesSplit::SeriesSplit( - std::initializer_list> const &children) - : children(children) {} - -bool SeriesSplit::operator==(SeriesSplit const &other) const { - return this->tie() == other.tie(); -} - -bool SeriesSplit::operator!=(SeriesSplit const &other) const { - return this->tie() != other.tie(); -} - -SeriesSplit::Tie SeriesSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(SeriesSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { - return s << fmt::to_string(split); -} - -ParallelSplit::ParallelSplit( - std::unordered_multiset> const &children) - : children(children) {} - -ParallelSplit::ParallelSplit( - std::initializer_list> const &children) - : children(children) {} - -bool ParallelSplit::operator==(ParallelSplit const &other) const { - return this->tie() == other.tie(); -} - -bool ParallelSplit::operator!=(ParallelSplit const &other) const { - return this->tie() != other.tie(); -} - -ParallelSplit::Tie ParallelSplit::tie() const { - return std::tie(this->children); -} - 
-std::string format_as(ParallelSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { - return s << fmt::to_string(split); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::SeriesSplit>::operator()( - ::FlexFlow::SeriesSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -size_t hash<::FlexFlow::ParallelSplit>::operator()( - ::FlexFlow::ParallelSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -} // namespace std +// #include "utils/fmt/unordered_multiset.h" +// #include "utils/fmt/variant.h" +// #include "utils/fmt/vector.h" +// #include "utils/hash-utils.h" +// #include "utils/hash/unordered_multiset.h" +// #include "utils/hash/vector.h" +// +// namespace FlexFlow { +// +// SeriesSplit::SeriesSplit( +// std::vector> const &children) +// : children(children) {} +// +// SeriesSplit::SeriesSplit( +// std::initializer_list> const &children) +// : children(children) {} +// +// bool SeriesSplit::operator==(SeriesSplit const &other) const { +// return this->tie() == other.tie(); +// } +// +// bool SeriesSplit::operator!=(SeriesSplit const &other) const { +// return this->tie() != other.tie(); +// } +// +// SeriesSplit::Tie SeriesSplit::tie() const { +// return std::tie(this->children); +// } +// +// std::string format_as(SeriesSplit const &split) { +// return fmt::format("", split.children); +// } +// +// std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { +// return s << fmt::to_string(split); +// } +// +// ParallelSplit::ParallelSplit( +// std::unordered_multiset> const &children) +// : children(children) {} +// +// ParallelSplit::ParallelSplit( +// std::initializer_list> const &children) +// : children(children) {} +// +// bool ParallelSplit::operator==(ParallelSplit const &other) const { +// return this->tie() == 
other.tie(); +// } +// +// bool ParallelSplit::operator!=(ParallelSplit const &other) const { +// return this->tie() != other.tie(); +// } +// +// ParallelSplit::Tie ParallelSplit::tie() const { +// return std::tie(this->children); +// } +// +// std::string format_as(ParallelSplit const &split) { +// return fmt::format("", split.children); +// } +// +// std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { +// return s << fmt::to_string(split); +// } +// +// } // namespace FlexFlow +// +// namespace std { +// +// size_t hash<::FlexFlow::SeriesSplit>::operator()( +// ::FlexFlow::SeriesSplit const &s) const { +// size_t result = 0; +// ::FlexFlow::hash_combine(result, s.children); +// return result; +// } +// +// size_t hash<::FlexFlow::ParallelSplit>::operator()( +// ::FlexFlow::ParallelSplit const &s) const { +// size_t result = 0; +// ::FlexFlow::hash_combine(result, s.children); +// return result; +// } +// +// } // namespace std diff --git a/lib/utils/src/utils/json/check_is_json_deserializable.cc b/lib/utils/src/utils/json/check_is_json_deserializable.cc new file mode 100644 index 0000000000..7e17ced7e5 --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/check_is_json_serializable.cc b/lib/utils/src/utils/json/check_is_json_serializable.cc new file mode 100644 index 0000000000..1c9af4d3cb --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_serializable.h" diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc index 60e109faa9..9364e02afc 100644 --- 
a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -10,6 +10,14 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); @@ -31,9 +39,10 @@ TEST_SUITE(FF_TEST_SUITE) { TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(make_series_split( - make_series_split(make_leaf_node(n1), make_leaf_node(n2)), - make_series_split(make_leaf_node(n3), make_leaf_node(n4)))); + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; SplitBoundaryNodes result = get_transitive_reduced_boundary_nodes_for_split(tr_g, split); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc index ed66292462..1b49c7218d 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -12,6 +12,18 @@ TEST_SUITE(FF_TEST_SUITE) 
{ TEST_CASE("get_transitive_reduced_edges_across_split") { DataflowGraph g = DataflowGraph::create(); + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; /* parallel node, distinct from make_series_split */ + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + SUBCASE("multiple nodes with edges across") { NodeAddedResult n1_added = g.add_node({}, 1); Node n1 = n1_added.node; @@ -32,9 +44,10 @@ TEST_SUITE(FF_TEST_SUITE) { TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(make_series_split( - make_parallel_split(make_leaf_node(n1), make_leaf_node(n2)), - make_parallel_split(make_leaf_node(n3), make_leaf_node(n4)))); + BinarySeriesSplit split = BinarySeriesSplit{ + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4)), + }; std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); @@ -68,8 +81,10 @@ TEST_SUITE(FF_TEST_SUITE) { TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series( - make_series_split(make_leaf_node(n1), make_leaf_node(n2))); + BinarySeriesSplit split = BinarySeriesSplit{ + make_leaf(n1), + make_leaf(n2), + }; std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); @@ -111,9 +126,10 @@ TEST_SUITE(FF_TEST_SUITE) { TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(make_series_split( - make_series_split(make_leaf_node(n1), make_leaf_node(n2)), - make_series_split(make_leaf_node(n3), make_leaf_node(n4)))); + BinarySeriesSplit split = 
BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; std::unordered_set result = get_transitive_reduced_edges_across_split(tr_g, split); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc index 1bd27c5f35..222e9b20bb 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -10,6 +10,14 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_outputs_across_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); @@ -31,9 +39,10 @@ TEST_SUITE(FF_TEST_SUITE) { TransitiveReducedDataflowGraphView tr_g = get_dataflow_graph_transitive_reduction(g); - BinarySeriesSplit split = require_series(make_series_split( - make_series_split(make_leaf_node(n1), make_leaf_node(n2)), - make_series_split(make_leaf_node(n3), make_leaf_node(n4)))); + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; std::unordered_set result = get_transitive_reduced_outputs_across_split(tr_g, split); diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index af58bfb777..8981312c4b 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1,87 +1,126 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "test/utils/doctest/fmt/unordered_multiset.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { - CHECK("TODO"); - // SUBCASE("leaf") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_sp_leaf(5); - // - // std::unordered_multiset result = get_leaves(input); - // std::unordered_multiset correct = {5}; - // - // CHECK(result == correct); - // } - // - // SUBCASE("series split") { - // SUBCASE("children are not the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // - // std::unordered_multiset result = get_leaves(input); - // std::unordered_multiset correct = {5, 6}; - // - // CHECK(result == correct); - // } - // - // SUBCASE("children are the same") { - // GenericBinarySPDecompositionTree input = - // 
make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(5)); - // - // std::unordered_multiset result = get_leaves(input); - // std::unordered_multiset correct = {5, 5}; - // - // CHECK(result == correct); - // } - // } - // - // SUBCASE("parallel split") { - // SUBCASE("children are not the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // - // std::unordered_multiset result = get_leaves(input); - // std::unordered_multiset correct = {5, 6}; - // - // CHECK(result == correct); - // } - // - // SUBCASE("children are the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(5)); - // - // std::unordered_multiset result = get_leaves(input); - // std::unordered_multiset correct = {5, 5}; - // - // CHECK(result == correct); - // } - // } - // - // SUBCASE("nested") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split( - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(4), - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(2), - // make_generic_binary_sp_leaf(5))), - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(4), - // make_generic_binary_sp_leaf(2))); - // - // std::unordered_multiset result = get_leaves(input); - // std::unordered_multiset correct = {2, 2, 4, 4, 5}; - // - // CHECK(result == correct); - // } + TEST_CASE("get_leaves") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> impl = generic_impl_for_binary_sp_tree(); + + auto generic_get_leaves = [&](BinarySPDecompositionTree const &tree) { + return get_leaves(tree, impl); + }; + + SUBCASE("leaf") { 
+ BinarySPDecompositionTree input = BinarySPDecompositionTree{n1}; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1}; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = + BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = + BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; + + CHECK(result == correct); + } + } + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return 
BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; /* make_parallel_split builds a parallel node, not a series node */ + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + + SUBCASE("nested") { + BinarySPDecompositionTree input = + make_parallel_split( + make_series_split( + make_leaf(n1), + make_series_split( + make_leaf(n2), + make_leaf(n3))), + make_parallel_split( + make_leaf(n2), + make_leaf(n1))); + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1, n2, n2, n3}; + + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 0d36ccbe92..f6424df03c 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1,85 +1,108 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { - // SUBCASE("leaf") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_sp_leaf(5); - // - // int result = get_num_tree_nodes(input); - // int correct = 1; - // - // CHECK(result == correct); - // } - // - // SUBCASE("series split") { 
- // SUBCASE("children are not the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // - // int result = get_num_tree_nodes(input); - // int correct = 3; - // - // CHECK(result == correct); - // } - // - // SUBCASE("children are the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(5)); - // - // int result = get_num_tree_nodes(input); - // int correct = 3; - // - // CHECK(result == correct); - // } - // } - // - // SUBCASE("parallel split") { - // SUBCASE("children are not the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(6)); - // - // int result = get_num_tree_nodes(input); - // int correct = 3; - // - // CHECK(result == correct); - // } - // - // SUBCASE("children are the same") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - // make_generic_binary_sp_leaf(5)); - // - // int result = get_num_tree_nodes(input); - // int correct = 3; - // - // CHECK(result == correct); - // } - // } - // - // SUBCASE("nested") { - // GenericBinarySPDecompositionTree input = - // make_generic_binary_parallel_split( - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(4), - // make_generic_binary_series_split( - // make_generic_binary_sp_leaf(2), - // make_generic_binary_sp_leaf(5))), - // make_generic_binary_parallel_split( - // make_generic_binary_sp_leaf(4), - // make_generic_binary_sp_leaf(2))); - // - // int result = get_num_tree_nodes(input); - // int correct = 9; - // - // CHECK(result == correct); - // } + TEST_CASE("get_num_tree_nodes") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + 
GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; /* parallel node, distinct from make_series_split */ + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + + auto generic_get_num_tree_nodes = [&](BinarySPDecompositionTree const &tree) { + return get_num_tree_nodes(tree, impl); + }; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = + make_leaf(n1); + + int result = generic_get_num_tree_nodes(input); + int correct = 1; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n2)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n1)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n2)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n1)); + + int result = generic_get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = + 
make_parallel_split( + make_series_split( + make_leaf(n1), + make_series_split( + make_leaf(n2), + make_leaf(n3))), + make_parallel_split( + make_leaf(n2), + make_leaf(n1))); + + int result = generic_get_num_tree_nodes(input); + int correct = 9; + + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc similarity index 58% rename from lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc rename to lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 8f1b8efaf7..f31de9bccf 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1,35 +1,38 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { 
TEST_CASE("is_binary_sp_tree_left_associative") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; - - auto make_leaf = [](int n) { - return leaf_only_binary_sp_tree_make_leaf(n); + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_series_split = - [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return leaf_only_binary_sp_tree_make_series_split(l, r); - }; + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; /* parallel node, distinct from make_series_split */ + }; - auto make_parallel_split = - [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return leaf_only_binary_sp_tree_make_parallel_split(l, r); - }; + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; SUBCASE("input is actually left associative") { SUBCASE("just node") { - LeafOnlyBinarySPDecompositionTree input = make_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -38,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = make_series_split( + BinarySPDecompositionTree input = make_series_split( make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); @@ -48,7 +51,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + 
BinarySPDecompositionTree input = make_parallel_split( make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); @@ -58,7 +61,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - LeafOnlyBinarySPDecompositionTree input = make_series_split( + BinarySPDecompositionTree input = make_series_split( make_parallel_split(make_leaf(n1), make_leaf(n2)), make_parallel_split(make_leaf(n3), make_leaf(n4))); @@ -71,7 +74,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not left associative") { SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = make_series_split( + BinarySPDecompositionTree input = make_series_split( make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); @@ -81,7 +84,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + BinarySPDecompositionTree input = make_parallel_split( make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc similarity index 60% rename from lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc rename to lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 88e08e7624..6ee0a72c23 100644 --- 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1,35 +1,38 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_binary_sp_tree_right_associative") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; - - auto make_leaf = [](int n) { - return leaf_only_binary_sp_tree_make_leaf(n); + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_series_split = - [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return leaf_only_binary_sp_tree_make_series_split(l, r); - }; + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; /* parallel node, distinct from make_series_split */ + }; - auto make_parallel_split = - [](LeafOnlyBinarySPDecompositionTree const &l, - LeafOnlyBinarySPDecompositionTree const &r) { - return 
leaf_only_binary_sp_tree_make_parallel_split(l, r); - }; + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; SUBCASE("input is actually right associative") { SUBCASE("just node") { - LeafOnlyBinarySPDecompositionTree input = make_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -38,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = make_series_split( + BinarySPDecompositionTree input = make_series_split( make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); @@ -48,7 +51,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + BinarySPDecompositionTree input = make_parallel_split( make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); @@ -58,7 +61,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - LeafOnlyBinarySPDecompositionTree input = make_series_split( + BinarySPDecompositionTree input = make_series_split( make_parallel_split(make_leaf(n1), make_leaf(n2)), make_parallel_split(make_leaf(n3), make_leaf(n4))); @@ -71,7 +74,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not right associative") { SUBCASE("just series") { - LeafOnlyBinarySPDecompositionTree input = make_series_split( + BinarySPDecompositionTree input = make_series_split( make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); @@ -81,7 +84,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - LeafOnlyBinarySPDecompositionTree input = make_parallel_split( + BinarySPDecompositionTree input = make_parallel_split( make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = 
is_binary_sp_tree_right_associative(input); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 1e3217a2de..db7ef92507 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -18,34 +18,46 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; /* parallel node, distinct from make_series_split */ + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); - BinarySPDecompositionTree correct = make_leaf_node(n1); + BinarySPDecompositionTree correct = make_leaf(n1); CHECK(result == correct); } SUBCASE("only serial") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - SeriesSplit{n1, n2, n3}, + SeriesSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_series_split(make_leaf_node(n1), make_leaf_node(n2)), - make_leaf_node(n3)); + make_series_split(make_leaf(n1), make_leaf(n2)), + make_leaf(n3)); CHECK(result == correct); } SUBCASE("only parallel") { SeriesParallelDecomposition input = 
SeriesParallelDecomposition{ - ParallelSplit{n1, n2, n3}, + ParallelSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = @@ -64,20 +76,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nested") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; BinarySPDecompositionTree result = diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc index 0befbde5cc..217ee5305a 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -14,8 +14,20 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + SUBCASE("leaf") { - BinarySPDecompositionTree input = make_leaf_node(n1); + BinarySPDecompositionTree input = make_leaf(n1); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; @@ -25,35 +37,35 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative series") { BinarySPDecompositionTree input = make_series_split( - make_series_split(make_leaf_node(n2), make_leaf_node(n1)), - make_leaf_node(n3)); + make_series_split(make_leaf(n2), 
make_leaf(n1)), + make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("right associative series") { BinarySPDecompositionTree input = make_series_split( - make_leaf_node(n2), - make_series_split(make_leaf_node(n1), make_leaf_node(n3))); + make_leaf(n2), + make_series_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("series with duplicate children") { BinarySPDecompositionTree input = - make_series_split(make_leaf_node(n1), make_leaf_node(n1)); + make_series_split(make_leaf(n1), make_leaf(n1)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n1, n1}}; + SeriesParallelDecomposition{SeriesSplit{{n1, n1}}}; CHECK(get_nodes(result).size() == 2); CHECK(result == correct); @@ -61,35 +73,35 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_parallel_split(make_leaf_node(n2), make_leaf_node(n1)), - make_leaf_node(n3)); + make_parallel_split(make_leaf(n2), make_leaf(n1)), + make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("right associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_leaf_node(n2), - make_parallel_split(make_leaf_node(n1), make_leaf_node(n3))); + make_leaf(n2), + 
make_parallel_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("parallel with duplicate children") { BinarySPDecompositionTree input = - make_parallel_split(make_leaf_node(n1), make_leaf_node(n1)); + make_parallel_split(make_leaf(n1), make_leaf(n1)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n1, n1}}; + SeriesParallelDecomposition{ParallelSplit{{n1, n1}}}; CHECK(get_nodes(result).size() == 2); CHECK(result == correct); @@ -99,31 +111,31 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree input = make_parallel_split( make_parallel_split( make_parallel_split( - make_leaf_node(n1), + make_leaf(n1), make_series_split( - make_series_split(make_series_split(make_leaf_node(n2), - make_leaf_node(n3)), - make_leaf_node(n3)), - make_leaf_node(n5))), - make_series_split(make_leaf_node(n6), make_leaf_node(n4))), - make_leaf_node(n5)); + make_series_split(make_series_split(make_leaf(n2), + make_leaf(n3)), + make_leaf(n3)), + make_leaf(n5))), + make_series_split(make_leaf(n6), make_leaf(n4))), + make_leaf(n5)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index db1b440481..53234eb513 
100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,34 +16,46 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { + return BinarySPDecompositionTree{n}; + }; + SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); - BinarySPDecompositionTree correct = make_leaf_node(n1); + BinarySPDecompositionTree correct = make_leaf(n1); CHECK(result == correct); } SUBCASE("only serial") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - SeriesSplit{n1, n2, n3}, + SeriesSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_leaf_node(n1), - make_series_split(make_leaf_node(n2), make_leaf_node(n3))); + make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))); CHECK(result == correct); } SUBCASE("only parallel") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{n1, n2, n3}, + ParallelSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = @@ -62,20 +74,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nested") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - 
SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; BinarySPDecompositionTree result = diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 45f796c824..e5b9045739 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -24,10 +24,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_series_parallel_decomposition(g); std::optional correct = - SeriesParallelDecomposition{ParallelSplit{ + SeriesParallelDecomposition{ParallelSplit{{ n.at(0), n.at(1), - }}; + }}}; CHECK(result == correct); } @@ -39,10 +39,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_series_parallel_decomposition(g); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ + SeriesParallelDecomposition{SeriesSplit{{ n.at(0), n.at(1), - }}; + }}}; CHECK(result == correct); } @@ -59,13 +59,13 @@ TEST_SUITE(FF_TEST_SUITE) { get_series_parallel_decomposition(g); std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ + SeriesSplit{{ n.at(0), - ParallelSplit{ + ParallelSplit{{ n.at(1), n.at(2), - }, - }, + }}, + }}, }; CHECK(result == correct); } @@ -86,20 +86,20 @@ TEST_SUITE(FF_TEST_SUITE) { }); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ + SeriesParallelDecomposition{SeriesSplit{{ n.at(0), - ParallelSplit{ - SeriesSplit{ + ParallelSplit{{ + SeriesSplit{{ n.at(1), n.at(3), - }, - SeriesSplit{ + }}, + SeriesSplit{{ n.at(2), n.at(4), - }, - }, + }}, + }}, n.at(5), - }}; + }}}; std::optional result = get_series_parallel_decomposition(g); @@ -122,16 +122,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ - ParallelSplit{ + SeriesSplit{{ + ParallelSplit{{ n.at(0), n.at(1), - }, - ParallelSplit{ + }}, + 
ParallelSplit{{ n.at(2), n.at(3), - }, - }, + }}, + }}, }; std::optional result = @@ -177,12 +177,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ + SeriesSplit{{ n.at(0), n.at(1), n.at(2), n.at(3), - }, + }}, }; std::optional result = get_series_parallel_decomposition(g); From 39c8f1ce781ec0a360b6284f14f0c05ec1b9eec3 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 7 Oct 2024 21:16:19 -0700 Subject: [PATCH 27/29] Pass all tests --- .../export_model_arch/json_sp_model_export.h | 17 -- .../json_sp_model_export.struct.toml | 6 +- .../src/export_model_arch.cc | 4 +- .../export_model_arch/json_sp_model_export.cc | 20 -- ...omputation_graph_binary_sp_decomposition.h | 5 + .../pcg/pcg_binary_parallel_split.h | 4 +- .../pcg/pcg_binary_series_split.h | 4 +- ...mputation_graph_binary_sp_decomposition.cc | 27 +++ .../pcg/pcg_binary_parallel_split.cc | 13 ++ .../pcg/pcg_binary_series_split.cc | 13 ++ .../pcg/pcg_binary_sp_decomposition.cc | 6 +- .../v1/v1_binary_sp_decomposition/json.h | 29 +++ .../v1_binary_parallel_split.struct.toml | 25 +++ .../v1_binary_series_split.struct.toml | 25 +++ .../v1_binary_sp_decomposition.variant.toml | 24 +++ .../v1/v1_binary_sp_decomposition/json.cc | 75 ++++++++ .../v1/v1_binary_sp_decomposition/json.cc | 177 ++++++++++++++++++ .../binary_sp_decomposition_tree.cc | 41 +++- ...ft_associative_binary_sp_tree_from_nary.cc | 1 - .../get_num_tree_nodes.cc | 2 +- .../is_binary_sp_tree_left_associative.cc | 2 +- .../is_binary_sp_tree_right_associative.cc | 2 +- ...ft_associative_binary_sp_tree_from_nary.cc | 2 +- .../nary_sp_tree_from_binary.cc | 2 +- ...ht_associative_binary_sp_tree_from_nary.cc | 2 +- 25 files changed, 471 insertions(+), 57 deletions(-) delete mode 100644 bin/export-model-arch/include/export_model_arch/json_sp_model_export.h delete mode 100644 bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc create mode 100644 
lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.h b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.h deleted file mode 100644 index df4e140b99..0000000000 --- a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_BIN_EXPORT_MODEL_ARCH_INCLUDE_EXPORT_MODEL_ARCH_JSON_SP_MODEL_EXPORT_H -#define _FLEXFLOW_BIN_EXPORT_MODEL_ARCH_INCLUDE_EXPORT_MODEL_ARCH_JSON_SP_MODEL_EXPORT_H - -#include -#include "export_model_arch/json_sp_model_export.dtg.h" - -namespace nlohmann { - -template <> -struct adl_serializer<::FlexFlow::JsonSPModelExport> { - static ::FlexFlow::JsonSPModelExport from_json(json const &); - static void to_json(json &, ::FlexFlow::JsonSPModelExport const &); -}; - -} // namespace nlohmann - -#endif diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml index 3c08bc150e..efaf10c255 100644 --- a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml +++ 
b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -9,16 +9,16 @@ features = [ includes = [ "pcg/file_format/v1/v1_computation_graph.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", ] src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/json.h", ] [[fields]] name = "sp_decomposition" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" +type = "::FlexFlow::V1BinarySPDecomposition" [[fields]] name = "computation_graph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 0ff4f47a40..8b3b90b21b 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -105,9 +105,7 @@ tl::expected to_v1_including_node_numbering(computation_graph); V1ComputationGraph v1_cg = v1_result.first; bidict layer_numbering = v1_result.second; - GenericBinarySPDecompositionTree v1_sp_decomposition = - transform(sp_decomposition.raw_tree, - [&](layer_guid_t const &l) { return layer_numbering.at_r(l); }); + V1BinarySPDecomposition v1_sp_decomposition = to_v1(sp_decomposition, layer_numbering); return JsonSPModelExport{ v1_sp_decomposition, diff --git a/bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc b/bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc deleted file mode 100644 index ca8d76d803..0000000000 --- a/bin/export-model-arch/src/export_model_arch/json_sp_model_export.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "export_model_arch/json_sp_model_export.h" - -using namespace ::FlexFlow; - -namespace nlohmann { - -JsonSPModelExport adl_serializer::from_json(json const &j) { - NOT_IMPLEMENTED(); -} - 
-static void sp_decomposition_to_json(json &j, LeafOnlyBinarySPDecompositionTree const &t) { -} - -void adl_serializer::to_json(json &j, JsonSPModelExport const &m) { - j["computation_graph"] = m.computation_graph; - sp_decomposition_to_json(j["sp_decomposition"], m.sp_decomposition); -} - - -} // namespace nlohmann diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h index ea6723a9cd..eb50ee365e 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -3,9 +3,11 @@ #include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h" #include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" namespace FlexFlow { @@ -32,6 +34,9 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &); std::unordered_multiset get_layers(ComputationGraphBinarySPDecomposition const &); +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, + bidict const &layer_numbering); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h index 0ac8cee95b..05a1ae1169 100644 --- 
a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_PARALLEL_SPLIT_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H #include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h index d0bda09229..83e53e3d41 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SERIES_SPLIT_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H #include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc index 6d6e63429b..f26b899109 100644 --- 
a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -142,4 +142,31 @@ std::unordered_multiset return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); } +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &tree, + bidict const &layer_numbering) { + return tree.visit(overload { + [&](ComputationGraphBinarySeriesSplit const &series) { + return V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + to_v1(series.get_left_child(), layer_numbering), + to_v1(series.get_right_child(), layer_numbering), + }, + }; + }, + [&](ComputationGraphBinaryParallelSplit const ¶llel) { + return V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + to_v1(parallel.get_left_child(), layer_numbering), + to_v1(parallel.get_right_child(), layer_numbering), + }, + }; + }, + [&](layer_guid_t const &layer) { + return V1BinarySPDecomposition{ + layer_numbering.at_r(layer), + }; + } + }); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc new file mode 100644 index 0000000000..7e6327d06a --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc @@ -0,0 +1,13 @@ +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &pcg_split) { + return BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc 
b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc new file mode 100644 index 0000000000..b0fec5f6ce --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc @@ -0,0 +1,13 @@ +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &pcg_split) { + return BinarySeriesSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index d0c54c91aa..0555c2a14d 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -1,5 +1,4 @@ #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" -#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" @@ -55,7 +54,10 @@ BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecompositi }, [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { return BinarySPDecompositionTree{ - binary_parallel_split_from_pcg_parallel_split(parallel), + BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), + }, }; }, [](parallel_layer_guid_t const &layer) -> 
BinarySPDecompositionTree { diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h new file mode 100644 index 0000000000..62cfd6ec62 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H + +#include +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::V1BinarySPDecomposition> { + static ::FlexFlow::V1BinarySPDecomposition from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySPDecomposition const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinarySeriesSplit> { + static ::FlexFlow::V1BinarySeriesSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySeriesSplit const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinaryParallelSplit> { + static ::FlexFlow::V1BinaryParallelSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinaryParallelSplit const &); +}; + +} // namespace nlohmann + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..d2d0c3bc77 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", 
+] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml new file mode 100644 index 0000000000..317fa8b6ce --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..0fe0b1761f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "V1BinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::V1BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::V1BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "int" +key = "leaf" diff --git 
a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..3adb79eb8f --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,75 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include "utils/exception.h" +#include "utils/overload.h" +#include "utils/fmt/json.h" + +using namespace ::FlexFlow; + +namespace nlohmann { + +V1BinarySPDecomposition adl_serializer::from_json(json const &j) { + std::string type = j.at("type").get(); + + if (type == "series") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "parallel") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "leaf") { + return V1BinarySPDecomposition{ + j.at("value").get(), + }; + } else { + throw mk_runtime_error(fmt::format("Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" in json object: {}", type, j)); + } +} + +void adl_serializer::to_json(json &j, V1BinarySPDecomposition const &tree) { + tree.visit(overload { + [&](V1BinarySeriesSplit const &split) { + j = split; + j["type"] = "series"; + return std::monostate{}; + }, + [&](V1BinaryParallelSplit const &split) { + j = split; + j["type"] = "parallel"; + return std::monostate{}; + }, + [&](int leaf) { + j["value"] = leaf; + j["type"] = "leaf"; + return std::monostate{}; + }, + }); +} + +V1BinarySeriesSplit adl_serializer::from_json(json const &j) { + return V1BinarySeriesSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json(json &j, V1BinarySeriesSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +V1BinaryParallelSplit adl_serializer::from_json(json const &j) { + return V1BinaryParallelSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void 
adl_serializer::to_json(json &j, V1BinaryParallelSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..e9f2573914 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,177 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer") { + V1BinarySPDecomposition example_tree = V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }, + }; + + nlohmann::json example_json = { + {"type", "series"}, + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_tree; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySPDecomposition result = example_json.get(); + V1BinarySPDecomposition correct = example_tree; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinarySeriesSplit example_split = V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", 
+ { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySeriesSplit result = example_json.get(); + V1BinarySeriesSplit correct = example_split; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinaryParallelSplit example_split = V1BinaryParallelSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinaryParallelSplit result = example_json.get(); + V1BinaryParallelSplit correct = example_split; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index be7115ee9f..56718fa71f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -10,7 +10,38 @@ GenericBinarySPDecompositionTreeImplementation< BinarySeriesSplit, BinaryParallelSplit, Node> generic_impl_for_binary_sp_tree() { - NOT_IMPLEMENTED(); + + return 
GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node> + { + /*series_get_left_child=*/[](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/[](BinaryParallelSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/[](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/[](BinaryParallelSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/[](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/[](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { + return tree.require_series(); + }, + /*require_parallel=*/[](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { + return tree.require_parallel(); + }, + /*require_leaf=*/[](BinarySPDecompositionTree const &tree) -> Node const & { + return tree.require_node(); + }, + }; } bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { @@ -25,4 +56,12 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tree) return get_leaves(tree, generic_impl_for_binary_sp_tree()); } +SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tree) { + return tree.visit(overload { + [](BinarySeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, + [](BinaryParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, + [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index f40acce2ee..33ac5f00e9 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -3,7 +3,6 @@ #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index f6424df03c..f61ff83bf9 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -22,7 +22,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; auto make_leaf = [](Node const &n) { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index f31de9bccf..05ff0b4aaa 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -23,7 +23,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; auto make_leaf = [](Node const &n) { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 6ee0a72c23..4e889ceab0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -23,7 +23,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; auto make_leaf = [](Node const &n) { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index db7ef92507..20f939a8f0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -23,7 +23,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; auto make_leaf = [](Node const &n) { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc index 217ee5305a..5db50ab2ef 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -19,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; auto make_leaf = [](Node const &n) { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 53234eb513..19b9cfd944 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -21,7 +21,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; auto make_leaf = [](Node const &n) { From 75f7e98d4d10940acb4974164a61bda09b765aa4 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 7 Oct 2024 21:25:49 -0700 Subject: [PATCH 28/29] Remove a bunch of unnecessary code --- .../src/export_model_arch.cc | 1 - .../full_binary_tree.variant.toml | 24 ---- .../full_binary_tree_parent_node.struct.toml | 34 ----- .../utils/full_binary_tree/get_node_type.h | 21 --- .../include/utils/full_binary_tree/json.h | 84 ------------ .../include/utils/full_binary_tree/require.h | 24 ---- .../utils/full_binary_tree/transform.h | 49 ------- .../include/utils/full_binary_tree/visit.h | 1 - .../generic_binary_parallel_split.struct.toml | 35 ----- .../generic_binary_series_split.struct.toml | 35 ----- ..._binary_sp_decomposition_tree.variant.toml | 31 ----- ...osition_tree_transform_visitor.struct.toml | 28 ---- .../get_node_type.h | 30 ----- .../generic_binary_sp_decomposition_tree/is.h | 31 ----- .../require.h | 37 ------ .../transform.h | 124 ------------------ .../find_paths_to_leaf.h | 18 --- .../get_leaves.h | 17 --- .../get_node_type.h | 17 --- .../is_binary_sp_tree_left_associative.h | 17 --- .../is_binary_sp_tree_right_associative.h | 17 --- .../json.h | 78 ----------- ...eaf_only_binary_parallel_split.struct.toml | 30 ----- .../leaf_only_binary_series_split.struct.toml | 29 ---- ..._binary_sp_decomposition_tree.variant.toml | 29 ---- ...osition_tree_transform_visitor.struct.toml | 16 --- ...nly_binary_sp_decomposition_tree_visitor.h | 35 ----- ..._sp_decomposition_tree_visitor.struct.toml | 26 ---- .../require.h | 45 
------- .../transform.h | 53 -------- .../utils/full_binary_tree/get_node_type.cc | 7 - .../src/utils/full_binary_tree/require.cc | 10 -- .../get_node_type.cc | 8 -- .../is.cc | 11 -- .../require.cc | 27 ---- .../transform.cc | 14 -- .../get_leaves.cc | 8 -- .../get_node_type.cc | 8 -- .../is_binary_sp_tree_left_associative.cc | 8 -- .../is_binary_sp_tree_right_associative.cc | 8 -- .../json.cc | 8 -- ...ly_binary_sp_decomposition_tree_visitor.cc | 9 -- .../require.cc | 13 -- .../transform.cc | 15 --- .../is_binary_sp_tree_right_associative.cc | 2 +- 45 files changed, 1 insertion(+), 1171 deletions(-) delete mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml delete mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml delete mode 100644 lib/utils/include/utils/full_binary_tree/get_node_type.h delete mode 100644 lib/utils/include/utils/full_binary_tree/json.h delete mode 100644 lib/utils/include/utils/full_binary_tree/require.h delete mode 100644 lib/utils/include/utils/full_binary_tree/transform.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h delete mode 100644 
lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml delete mode 100644 
lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h delete mode 100644 lib/utils/src/utils/full_binary_tree/get_node_type.cc delete mode 100644 lib/utils/src/utils/full_binary_tree/require.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc 
delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 8b3b90b21b..9da33023a0 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -13,7 +13,6 @@ #include "utils/cli/cli_parse.h" #include "utils/cli/cli_parse_result.h" #include "utils/cli/cli_spec.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/get_series_parallel_decomposition.h" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml deleted file mode 100644 index 2183abe900..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree.variant.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTree" -features = [ - "eq", - "hash", - "fmt", -] - -template_params = [ - "ParentLabel", - "LeafLabel", -] - -includes = [ - "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h", -] - 
-[[values]] -type = "::FlexFlow::FullBinaryTreeParentNode" -key = "parent" - -[[values]] -type = "LeafLabel" -key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml deleted file mode 100644 index 3403271621..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml +++ /dev/null @@ -1,34 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTreeParentNode" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "template struct FullBinaryTree", -] - -post_includes = [ - "utils/full_binary_tree/full_binary_tree.dtg.h", -] - -template_params = [ - "ParentLabel", - "LeafLabel", -] - -[[fields]] -name = "label" -type = "ParentLabel" - -[[fields]] -name = "left_child" -type = "::FlexFlow::FullBinaryTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::FullBinaryTree" -indirect = true diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h deleted file mode 100644 index d49faa3694..0000000000 --- a/lib/utils/include/utils/full_binary_tree/get_node_type.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H - -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -FullBinaryTreeNodeType - get_node_type(FullBinaryTree const &tree) { - return tree.template visit(overload { - [](FullBinaryTreeParentNode const &) { return FullBinaryTreeNodeType::PARENT; }, - [](LeafLabel const &) { return FullBinaryTreeNodeType::LEAF; }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/json.h 
b/lib/utils/include/utils/full_binary_tree/json.h deleted file mode 100644 index a589c541da..0000000000 --- a/lib/utils/include/utils/full_binary_tree/json.h +++ /dev/null @@ -1,84 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H - -#include "utils/exception.h" -#include "utils/full_binary_tree/full_binary_tree.h" -#include "utils/full_binary_tree/get_left_child.h" -#include "utils/full_binary_tree/get_right_child.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" -#include - -namespace nlohmann { - -template -struct adl_serializer< - ::FlexFlow::FullBinaryTreeParentNode> { - static ::FlexFlow::FullBinaryTreeParentNode - from_json(json const &j) { - return ::FlexFlow::FullBinaryTreeParentNode{ - j.at("left_child") - .template get< - ::FlexFlow::FullBinaryTreeParentNode>(), - j.at("right_child") - .template get< - ::FlexFlow::FullBinaryTreeParentNode>(), - }; - } - - static void to_json( - json &j, - ::FlexFlow::FullBinaryTreeParentNode const &v) { - j["__type"] = "FullBinaryTreeParentNode"; - j["left_child"] = get_left_child(v); - j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::FullBinaryTree> { - static ::FlexFlow::FullBinaryTree - from_json(json const &j) { - std::string key = j.at("type").get(); - - if (key == "parent") { - return ::FlexFlow::FullBinaryTree{ - j.at("value") - .get<::FlexFlow::FullBinaryTreeParentNode>(), - }; - } else if (key == "leaf") { - return ::FlexFlow::FullBinaryTree{ - j.at("value").get(), - }; - } else { - throw ::FlexFlow::mk_runtime_error( - fmt::format("Unknown json type key: {}", key)); - } - } - - static void - to_json(json &j, - ::FlexFlow::FullBinaryTree const &v) { - j["__type"] = "FullBinaryTree"; - ::FlexFlow::visit( - v, - ::FlexFlow::overload{ - [&](::FlexFlow::FullBinaryTreeParentNode const &s) { - j["type"] = "parent"; - j["value"] = s; - return std::monostate{}; 
- }, - [&](LeafLabel const &t) { - j["type"] = "leaf"; - j["value"] = t; - return std::monostate{}; - }, - }); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h deleted file mode 100644 index 7b2625a9e8..0000000000 --- a/lib/utils/include/utils/full_binary_tree/require.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H - -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" - -namespace FlexFlow { - -template -FullBinaryTreeParentNode const & - require_full_binary_tree_parent_node( - FullBinaryTree const &t) { - return t.template get>(); -} - -template -LeafLabel const &require_full_binary_tree_leaf( - FullBinaryTree const &t) { - return t.template get(); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/transform.h b/lib/utils/include/utils/full_binary_tree/transform.h deleted file mode 100644 index 6e33064025..0000000000 --- a/lib/utils/include/utils/full_binary_tree/transform.h +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H - -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/get_left_child.h" -#include "utils/full_binary_tree/get_right_child.h" -#include "utils/full_binary_tree/visit.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template , - typename LeafLabel2 = std::invoke_result_t> -FullBinaryTreeParentNode - transform(FullBinaryTreeParentNode const &t, F f) { - return FullBinaryTreeParentNode{ - transform(get_left_child(t), f), - transform(get_right_child(t), f), - }; -} - -template , - typename LeafLabel2 = std::invoke_result_t> 
-FullBinaryTree - transform(FullBinaryTree const &t, F f) { - return visit>( - t, - overload{ - [&](FullBinaryTreeParentNode const &parent) { - return FullBinaryTree{ - transform(parent, f), - }; - }, - [&](LeafLabel const &leaf) { - return FullBinaryTree{ - f(leaf), - }; - }}); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index f9349069a9..87aa115c8c 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -4,7 +4,6 @@ #include "utils/exception.h" #include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" -#include "utils/full_binary_tree/get_node_type.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml deleted file mode 100644 index 139074299f..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml +++ /dev/null @@ -1,35 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "template struct GenericBinarySPDecompositionTree", -] - -post_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -] - -template_params = [ - "SeriesLabel", - "ParallelLabel", - "LeafLabel", -] - -[[fields]] -name = "label" -type = "ParallelLabel" - -[[fields]] -name = "left_child" -type = "::FlexFlow::GenericBinarySPDecompositionTree" -indirect = true - -[[fields]] 
-name = "right_child" -type = "::FlexFlow::GenericBinarySPDecompositionTree" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml deleted file mode 100644 index 054532e2e0..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml +++ /dev/null @@ -1,35 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "template struct GenericBinarySPDecompositionTree", -] - -post_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -] - -template_params = [ - "SeriesLabel", - "ParallelLabel", - "LeafLabel", -] - -[[fields]] -name = "label" -type = "SeriesLabel" - -[[fields]] -name = "left_child" -type = "::FlexFlow::GenericBinarySPDecompositionTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::GenericBinarySPDecompositionTree" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml deleted file mode 100644 index 5edaa80a09..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.variant.toml +++ /dev/null @@ -1,31 +0,0 @@ -namespace = "FlexFlow" -name = 
"GenericBinarySPDecompositionTree" -features = [ - "eq", - "hash", - "fmt", - "json", -] - -template_params = [ - "SeriesLabel", - "ParallelLabel", - "LeafLabel", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", -] - -[[values]] -type = "::FlexFlow::GenericBinarySeriesSplit" -key = "series" - -[[values]] -type = "::FlexFlow::GenericBinaryParallelSplit" -key = "parallel" - -[[values]] -type = "LeafLabel" -key = "leaf" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml deleted file mode 100644 index df0b0f2ea7..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.struct.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "GenericBinarySPDecompositionTreeTransformVisitor" -features = [] - -template_params = [ - "SeriesSplitLabel", - "ParallelSplitLabel", - "LeafLabel", - "SeriesSplitLabel2", - "ParallelSplitLabel2", - "LeafLabel2", -] - -includes = [ - "", -] - -[[fields]] -name = "series_split_func" -type = "std::function" - -[[fields]] -name = "parallel_split_func" -type = "std::function" - -[[fields]] -name = "leaf_func" -type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h deleted file mode 100644 index c4e86a252d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &tree) { - return tree.template visit(overload { - [](GenericBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](GenericBinaryParallelSplit const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](LeafLabel const &) { - return SPDecompositionTreeNodeType::NODE; - } - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h deleted file mode 100644 index a7046bedbe..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H -#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" - -namespace FlexFlow { - -template -bool is_series_split(GenericBinarySPDecompositionTree const &t) { - return get_node_type(t) == SPDecompositionTreeNodeType::SERIES; -} - -template -bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { - return get_node_type(t) == SPDecompositionTreeNodeType::PARALLEL; -} - -template -bool is_leaf(GenericBinarySPDecompositionTree const &t) { - return get_node_type(t) == SPDecompositionTreeNodeType::NODE; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h deleted file mode 100644 index 315fca844d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" - -namespace FlexFlow { - -template -GenericBinarySeriesSplit - require_generic_binary_series_split( - GenericBinarySPDecompositionTree const &tree) { - return tree.template get>(); -} - -template 
-GenericBinaryParallelSplit - require_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &tree) { - return tree.template get>(); -} - -template -LeafLabel require_generic_binary_leaf( - GenericBinarySPDecompositionTree const &tree) { - return tree.template get(); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h deleted file mode 100644 index 00419ace55..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_transform_visitor.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -// #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" -// #include "utils/overload.h" -// -// namespace FlexFlow { -// -// template -// GenericBinarySPDecompositionTree -// transform( -// GenericBinarySPDecompositionTree const &tt, -// GenericBinarySPDecompositionTreeTransformVisitor const &visitor); -// -// template -// GenericBinarySeriesSplit transform( -// GenericBinarySeriesSplit const &s, -// GenericBinarySPDecompositionTreeTransformVisitor const &visitor) { -// return GenericBinarySeriesSplit{ -// visitor.series_split_func(s.label), -// transform(get_left_child(s), visitor), -// transform(get_right_child(s), visitor), -// }; -// }; -// -// template -// GenericBinaryParallelSplit transform( -// GenericBinaryParallelSplit const &s, -// GenericBinarySPDecompositionTreeTransformVisitor const &visitor) { -// return GenericBinaryParallelSplit{ -// visitor.parallel_split_func(s.label), -// transform(get_left_child(s), visitor), -// transform(get_right_child(s), visitor), -// }; -// }; -// -// template -// GenericBinarySPDecompositionTree -// transform( -// GenericBinarySPDecompositionTree const &tree, -// GenericBinarySPDecompositionTreeTransformVisitor const &transform_visitor) { -// -// using ResultType = -// GenericBinarySPDecompositionTree; -// -// auto visitor = GenericBinarySPDecompositionTreeVisitor< -// ResultType, -// SeriesLabel, -// ParallelLabel, -// LeafLabel> -// { -// [&](GenericBinarySeriesSplit const &s) -> ResultType { -// return generic_binary_sp_tree_wrap_series_split(transform(s, transform_visitor)); -// }, -// [&](GenericBinaryParallelSplit const &s) -> ResultType { -// return generic_binary_sp_tree_wrap_parallel_split(transform(s, transform_visitor)); -// }, -// [&](LeafLabel const &t) -> ResultType { -// return generic_binary_sp_tree_wrap_leaf( -// 
transform_visitor.leaf_func(t)); -// }, -// }; -// -// return visit(tree, visitor); -// } -// -// } // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h deleted file mode 100644 index 1d7f9ae88c..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/find_paths_to_leaf.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -template -std::unordered_set - find_paths_to_leaf(LeafOnlyBinarySPDecompositionTree const &tree, - LeafLabel const &leaf) { - return find_paths_to_leaf(tree.raw_tree, leaf); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h deleted file mode 100644 index 738b4c013d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" -// -// namespace FlexFlow { -// -// template -// std::unordered_multiset -// get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { -// return get_leaves(t.raw_tree); -// } -// -// } // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h deleted file mode 100644 index 78dd8990c7..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" -// -// namespace FlexFlow { -// -// template -// SPDecompositionTreeNodeType -// get_node_type(LeafOnlyBinarySPDecompositionTree const 
&tree) { -// return get_node_type(tree.raw_tree); -// } -// -// } // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h deleted file mode 100644 index df365360ca..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" -// -// namespace FlexFlow { -// -// template -// bool is_binary_sp_tree_left_associative( -// LeafOnlyBinarySPDecompositionTree const &t) { -// return is_binary_sp_tree_left_associative(t.raw_tree); -// } -// -// } // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h deleted file mode 100644 index 8eb09f3654..0000000000 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" -// -// namespace FlexFlow { -// -// template -// bool is_binary_sp_tree_right_associative( -// LeafOnlyBinarySPDecompositionTree const &t) { -// return is_binary_sp_tree_right_associative(t.raw_tree); -// } -// -// } // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h deleted file mode 100644 index b5c1ed3f95..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h +++ /dev/null @@ -1,78 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_JSON_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_JSON_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" -// 
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/visit.h" -// #include "utils/json/check_is_json_deserializable.h" -// #include "utils/json/check_is_json_serializable.h" -// #include "utils/fmt/json.h" -// #include - -namespace nlohmann { - -// template -// struct adl_serializer<::FlexFlow::LeafOnlyBinarySPDecompositionTree> { -// static ::FlexFlow::LeafOnlyBinarySPDecompositionTree from_json(json const &j) { -// CHECK_IS_JSON_SERIALIZABLE(LeafLabel); -// -// using namespace ::FlexFlow; -// -// using Tree = LeafOnlyBinarySPDecompositionTree; -// -// std::string type = j.at("type").get(); -// -// if (type == "series") { -// return leaf_only_binary_sp_tree_wrap_series_split( -// LeafOnlyBinarySeriesSplit{ -// /*lhs=*/j.at("left_child").get(), -// /*rhs=*/j.at("right_child").get(), -// }); -// } else if (type == "parallel") { -// return leaf_only_binary_sp_tree_wrap_parallel_split( -// LeafOnlyBinaryParallelSplit{ -// /*lhs=*/j.at("left_child").get(), -// /*rhs=*/j.at("right_child").get(), -// }); -// } else if (type == "leaf") { -// return leaf_only_binary_sp_tree_wrap_leaf(j.at("label").get()); -// } else { -// throw mk_runtime_error(fmt::format("Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" in json object: {}", type, j)); -// } -// } -// -// static void to_json(json &j, ::FlexFlow::LeafOnlyBinarySPDecompositionTree const &tree) { -// CHECK_IS_JSON_DESERIALIZABLE(LeafLabel); -// -// using namespace FlexFlow; -// -// using Tree = LeafOnlyBinarySPDecompositionTree; -// -// auto visitor = LeafOnlyBinarySPDecompositionTreeVisitor{ -// /*series_func=*/[&](LeafOnlyBinarySeriesSplit const &split) { -// j["type"] = 
"series"; -// j["left_child"] = split.lhs; -// j["right_child"] = split.rhs; -// return std::monostate{}; -// }, -// /*parallel_func=*/[&](LeafOnlyBinaryParallelSplit const &split) { -// j["type"] = "parallel"; -// j["left_child"] = split.lhs; -// j["right_child"] = split.rhs; -// return std::monostate{}; -// }, -// /*leaf_func=*/[&](LeafLabel const &leaf_label) { -// j["type"] = "leaf"; -// j["label"] = leaf_label; -// return std::monostate{}; -// }, -// }; -// -// visit(tree, visitor); -// } -// }; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml deleted file mode 100644 index 802ee854ab..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml +++ /dev/null @@ -1,30 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinaryParallelSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "template struct LeafOnlyBinarySPDecompositionTree", -] - -post_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", -] - - -template_params = [ - "LeafLabel", -] - -[[fields]] -name = "left_child" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml deleted file mode 100644 index 95a647aea3..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySeriesSplit" -features = [ - "eq", - "hash", - "fmt", -] - -fwd_decls = [ - "template struct LeafOnlyBinarySPDecompositionTree", -] - -post_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", -] - -template_params = [ - "LeafLabel", -] - -[[fields]] -name = "left_child" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" -indirect = true - -[[fields]] -name = "right_child" -type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" -indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml deleted file mode 100644 index fdab932039..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.variant.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySPDecompositionTree" -features = [ - "eq", - "hash", - "fmt" -] - -template_params = [ - "LeafLabel", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", - 
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", - "", -] - -[[values]] -type = "::FlexFlow::LeafOnlyBinarySeriesSplit" -key = "series" - -[[values]] -type = "::FlexFlow::LeafOnlyBinaryParallelSplit" -key = "parallel" - -[[values]] -type = "LeafLabel" -key = "leaf" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml deleted file mode 100644 index b94972c323..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_transform_visitor.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySPDecompositionTreeTransformVisitor" -features = [] - -template_params = [ - "LeafLabel", - "LeafLabel2", -] - -includes = [ - "", -] - -[[fields]] -name = "leaf_func" -type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h deleted file mode 100644 index f1dbcb97af..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_VISITOR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_VISITOR_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" - -namespace FlexFlow { - -// template -// GenericBinarySPDecompositionTreeVisitor -// generic_visitor_from_leaf_only_visitor(LeafOnlyBinarySPDecompositionTreeVisitor const &leaf_only) { -// return GenericBinarySPDecompositionTreeVisitor{ -// [leaf_only](GenericBinarySeriesSplit const &split) { -// return leaf_only.series_func( -// LeafOnlyBinarySeriesSplit{ -// LeafOnlyBinarySPDecompositionTree{split.lhs}, -// LeafOnlyBinarySPDecompositionTree{split.rhs}, -// }); -// }, -// [leaf_only](GenericBinaryParallelSplit const &split) { -// return leaf_only.parallel_func( -// LeafOnlyBinaryParallelSplit{ -// LeafOnlyBinarySPDecompositionTree{split.lhs}, -// LeafOnlyBinarySPDecompositionTree{split.rhs}, -// }); -// }, -// [leaf_only](LeafLabel const &leaf_label) { -// return leaf_only.leaf_func(leaf_label); -// }, -// }; -// } - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml deleted file mode 100644 index 7174afcaa7..0000000000 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySPDecompositionTreeVisitor" -features = [] - -template_params = [ - "ReturnType", - "LeafLabel", -] - -includes = [ - "", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", -] - -[[fields]] -name = "series_func" -type = "std::function const &)>" - -[[fields]] -name = "parallel_func" -type = "std::function const &)>" - -[[fields]] -name = "leaf_func" -type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h deleted file mode 100644 index e83120a0f9..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" -// #include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -// template -// LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( -// LeafOnlyBinarySPDecompositionTree const &t) { -// GenericBinarySeriesSplit raw = -// require_generic_binary_series_split(t.raw_tree); -// -// return LeafOnlyBinarySeriesSplit{ -// LeafOnlyBinarySPDecompositionTree{raw.lhs}, -// LeafOnlyBinarySPDecompositionTree{raw.rhs}, -// }; -// } -// -// template -// LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split( -// LeafOnlyBinarySPDecompositionTree const &t) { -// GenericBinaryParallelSplit raw = -// require_generic_binary_parallel_split(t.raw_tree); -// -// return LeafOnlyBinaryParallelSplit{ -// LeafOnlyBinarySPDecompositionTree{raw.lhs}, -// LeafOnlyBinarySPDecompositionTree{raw.rhs}, -// }; -// } -// -// template -// LeafLabel require_leaf_only_binary_leaf( -// LeafOnlyBinarySPDecompositionTree const &t) { -// return require_generic_binary_leaf(t.raw_tree); -// } -// -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h deleted file mode 100644 index 58a60c91ba..0000000000 --- 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H - -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" -// #include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.dtg.h" - -namespace FlexFlow { - -// template -// LeafOnlyBinarySeriesSplit transform( -// LeafOnlyBinarySeriesSplit const &t, -// LeafOnlyBinarySPDecompositionTreeVisitor const -// &visitor) { -// return transform(leaf_only_binary_sp_tree_wrap_series_split(t), visitor); -// } -// -// template -// LeafOnlyBinaryParallelSplit transform( -// LeafOnlyBinaryParallelSplit const &t, -// LeafOnlyBinarySPDecompositionTreeVisitor const -// &visitor) { -// return transform(leaf_only_binary_sp_tree_wrap_parallel_split(t), visitor); -// } -// -// template -// LeafOnlyBinarySPDecompositionTree transform( -// LeafOnlyBinarySPDecompositionTree const &t, -// LeafOnlyBinarySPDecompositionTreeVisitor const -// &visitor) { -// using GenericVisitor = GenericBinarySPDecompositionTreeTransformVisitor; -// -// GenericVisitor generic_visitor = GenericVisitor{ -// 
[&](std::monostate const &x) { return x; }, -// [&](std::monostate const &x) { return x; }, -// [&](LeafLabel const &t) { return visitor.leaf_func(t); }, -// }; -// -// return LeafOnlyBinarySPDecompositionTree{ -// transform(t.raw_tree, generic_visitor), -// }; -// } - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/src/utils/full_binary_tree/get_node_type.cc b/lib/utils/src/utils/full_binary_tree/get_node_type.cc deleted file mode 100644 index a4c88a03f3..0000000000 --- a/lib/utils/src/utils/full_binary_tree/get_node_type.cc +++ /dev/null @@ -1,7 +0,0 @@ -#include "utils/full_binary_tree/get_node_type.h" - -namespace FlexFlow { - -template FullBinaryTreeNodeType get_node_type(FullBinaryTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/require.cc b/lib/utils/src/utils/full_binary_tree/require.cc deleted file mode 100644 index e4454927a4..0000000000 --- a/lib/utils/src/utils/full_binary_tree/require.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include "utils/full_binary_tree/require.h" - -namespace FlexFlow { - -template FullBinaryTreeParentNode const & - require_full_binary_tree_parent_node(FullBinaryTree const &); -template int const & - require_full_binary_tree_leaf(FullBinaryTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc deleted file mode 100644 index e66b996721..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" - -namespace FlexFlow { - -template SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree 
const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc deleted file mode 100644 index 056435531f..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" - -namespace FlexFlow { - -template bool - is_series_split(GenericBinarySPDecompositionTree const &); -template bool - is_parallel_split(GenericBinarySPDecompositionTree const &); -template bool is_leaf(GenericBinarySPDecompositionTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc deleted file mode 100644 index 9a3fa879d4..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ /dev/null @@ -1,27 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using SeriesLabel = value_type<0>; -using ParallelLabel = value_type<1>; -using LeafLabel = value_type<2>; - -template - GenericBinarySeriesSplit - require_generic_binary_series_split( - GenericBinarySPDecompositionTree const &); -template - GenericBinaryParallelSplit - require_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &); -template - LeafLabel require_generic_binary_leaf( - GenericBinarySPDecompositionTree const &); -} // namespace FlexFlow diff 
--git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index bdb59887a1..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" - -namespace FlexFlow { - -// template GenericBinarySeriesSplit -// transform(GenericBinarySeriesSplit const &, -// GenericBinarySPDecompositionTreeTransformVisitor const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc deleted file mode 100644 index 79aac82e12..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" - -namespace FlexFlow { - -// template std::unordered_multiset -// get_leaves(LeafOnlyBinarySPDecompositionTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc deleted file mode 100644 index 9f5516fbb3..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc +++ 
/dev/null @@ -1,8 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" - -namespace FlexFlow { - -// template SPDecompositionTreeNodeType -// get_node_type(LeafOnlyBinarySPDecompositionTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc deleted file mode 100644 index 393189f092..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" - -namespace FlexFlow { - -// template bool is_binary_sp_tree_left_associative( -// LeafOnlyBinarySPDecompositionTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc deleted file mode 100644 index 0a7724b13a..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" - -namespace FlexFlow { - -// template bool is_binary_sp_tree_right_associative( -// LeafOnlyBinarySPDecompositionTree const &); - -} // namespace 
FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index 6d8ead4165..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h" - -namespace nlohmann { - -// template -// struct adl_serializer<::FlexFlow::LeafOnlyBinarySPDecompositionTree>; - -} // namespace nlohmann diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc deleted file mode 100644 index fe671a8e8f..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree_visitor.h" - -namespace FlexFlow { - -// template -// GenericBinarySPDecompositionTreeVisitor -// generic_visitor_from_leaf_only_visitor(LeafOnlyBinarySPDecompositionTreeVisitor const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc deleted file mode 100644 index cefb44d0c4..0000000000 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" - -namespace FlexFlow { - -// template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split( -// LeafOnlyBinarySPDecompositionTree const &); -// template LeafOnlyBinaryParallelSplit -// require_leaf_only_binary_parallel_split( -// LeafOnlyBinarySPDecompositionTree const &); -// template int require_leaf_only_binary_leaf( -// LeafOnlyBinarySPDecompositionTree const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index 2421ffdc43..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" - -namespace FlexFlow { - -// template LeafOnlyBinarySeriesSplit transform( LeafOnlyBinarySeriesSplit const &, -// LeafOnlyBinarySPDecompositionTreeVisitor const &); -// template LeafOnlyBinaryParallelSplit transform( -// LeafOnlyBinaryParallelSplit const &, -// LeafOnlyBinarySPDecompositionTreeVisitor const &); -// -// template LeafOnlyBinarySPDecompositionTree transform( -// LeafOnlyBinarySPDecompositionTree const &, -// LeafOnlyBinarySPDecompositionTreeVisitor const &); -// -} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 4e889ceab0..324008fdca 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1,4 +1,4 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include From a2b8832a0918de2eb1554f790c1bd6bf38cfc0a5 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 7 Oct 2024 21:26:13 -0700 Subject: [PATCH 29/29] Format --- .../src/export_model_arch.cc | 3 +- .../machine_mapping_problem_tree.h | 7 +- ...omputation_graph_binary_sp_decomposition.h | 16 +- .../pcg/pcg_binary_parallel_split.h | 3 +- .../pcg/pcg_binary_series_split.h | 3 +- .../pcg/pcg_binary_sp_decomposition.h | 15 +- ...racted_tensor_set_movement_across_split.cc | 5 +- .../get_optimal_machine_mapping.cc | 17 +- .../get_machine_mapping_problem_tree.cc | 51 +++-- .../machine_mapping_problem_tree.cc | 94 +++++--- .../machine_mapping/transitive_reduced_pcg.cc | 9 +- ...mputation_graph_binary_sp_decomposition.cc | 206 ++++++++++-------- .../pcg/pcg_binary_parallel_split.cc | 7 +- .../pcg/pcg_binary_series_split.cc | 7 +- .../pcg/pcg_binary_sp_decomposition.cc | 140 +++++++----- ...racted_tensor_set_movement_across_split.cc | 39 ++-- 
.../get_optimal_machine_mapping.cc | 35 +-- .../get_machine_mapping_problem_tree.cc | 84 +++---- .../v1/v1_binary_sp_decomposition/json.h | 2 +- .../v1/v1_binary_sp_decomposition/json.cc | 75 ++++--- .../v1/v1_binary_sp_decomposition/json.cc | 159 +++++++------- .../include/utils/archetypes/value_type.h | 32 ++- lib/utils/include/utils/fmt/json.h | 6 +- .../full_binary_tree/find_paths_to_leaf.h | 49 +++-- .../full_binary_tree/get_all_leaf_paths.h | 39 ++-- .../utils/full_binary_tree/get_child.h | 2 +- .../utils/full_binary_tree/get_leaves.h | 24 +- .../full_binary_tree/get_num_tree_nodes.h | 17 +- .../full_binary_tree/get_subtree_at_path.h | 26 +-- .../include/utils/full_binary_tree/visit.h | 4 +- .../binary_sp_decomposition_tree.h | 10 +- .../find_paths_to_leaf.h | 15 +- ...ary_sp_decomposition_tree_implementation.h | 96 ++++---- .../get_all_leaf_paths.h | 9 +- .../get_leaves.h | 13 +- .../get_num_tree_nodes.h | 12 +- .../get_subtree_at_path.h | 15 +- .../is_binary_sp_tree_left_associative.h | 41 ++-- .../is_binary_sp_tree_right_associative.h | 41 ++-- .../visit.h | 28 ++- .../series_parallel/series_parallel_splits.h | 8 +- .../utils/json/check_is_json_deserializable.h | 2 +- .../utils/json/check_is_json_serializable.h | 4 +- lib/utils/src/utils/archetypes/value_type.cc | 3 +- lib/utils/src/utils/fmt/json.cc | 3 +- .../full_binary_tree/find_paths_to_leaf.cc | 7 +- .../full_binary_tree/get_all_leaf_paths.cc | 9 +- .../src/utils/full_binary_tree/get_child.cc | 8 +- .../src/utils/full_binary_tree/get_leaves.cc | 7 +- .../full_binary_tree/get_num_tree_nodes.cc | 5 +- .../full_binary_tree/get_subtree_at_path.cc | 9 +- lib/utils/src/utils/full_binary_tree/visit.cc | 7 +- .../binary_sp_decomposition_tree.cc | 106 +++++---- .../find_paths_to_leaf.cc | 12 +- ...ry_sp_decomposition_tree_implementation.cc | 6 +- .../get_all_leaf_paths.cc | 10 +- .../get_leaves.cc | 10 +- .../get_num_tree_nodes.cc | 9 +- .../get_subtree_at_path.cc | 12 +- 
.../is_binary_sp_tree_left_associative.cc | 10 +- .../is_binary_sp_tree_right_associative.cc | 10 +- .../visit.cc | 16 +- ...ft_associative_binary_sp_tree_from_nary.cc | 30 +-- ...ht_associative_binary_sp_tree_from_nary.cc | 30 +-- .../get_series_parallel_decomposition.cc | 24 +- .../intermediate_sp_decomposition_tree.cc | 44 ++-- ...sitive_reduced_boundary_nodes_for_split.cc | 9 +- ...t_transitive_reduced_edges_across_split.cc | 12 +- ...transitive_reduced_outputs_across_split.cc | 9 +- .../get_leaves.cc | 70 +++--- .../get_num_tree_nodes.cc | 46 ++-- .../is_binary_sp_tree_left_associative.cc | 22 +- .../is_binary_sp_tree_right_associative.cc | 22 +- ...ft_associative_binary_sp_tree_from_nary.cc | 13 +- .../nary_sp_tree_from_binary.cc | 28 +-- ...ht_associative_binary_sp_tree_from_nary.cc | 13 +- 76 files changed, 1146 insertions(+), 955 deletions(-) diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 9da33023a0..64419acce4 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -104,7 +104,8 @@ tl::expected to_v1_including_node_numbering(computation_graph); V1ComputationGraph v1_cg = v1_result.first; bidict layer_numbering = v1_result.second; - V1BinarySPDecomposition v1_sp_decomposition = to_v1(sp_decomposition, layer_numbering); + V1BinarySPDecomposition v1_sp_decomposition = + to_v1(sp_decomposition, layer_numbering); return JsonSPModelExport{ v1_sp_decomposition, diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 2eccd36719..29e9e7c90b 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -10,8 +10,11 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation - generic_binary_sp_impl_for_mm_problem_tree(); +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree(); SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h index eb50ee365e..fdc80a1e37 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -12,16 +12,18 @@ namespace FlexFlow { GenericBinarySPDecompositionTreeImplementation< - ComputationGraphBinarySPDecomposition, - ComputationGraphBinarySeriesSplit, - ComputationGraphBinaryParallelSplit, - layer_guid_t> generic_impl_for_computation_graph_sp_tree(); + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree(); SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &); -ComputationGraphBinarySPDecomposition - computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &); +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &); std::optional get_computation_graph_left_assoc_binary_sp_decomposition( @@ -34,7 +36,7 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &); std::unordered_multiset get_layers(ComputationGraphBinarySPDecomposition const &); -V1BinarySPDecomposition 
to_v1(ComputationGraphBinarySPDecomposition const &, +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, bidict const &layer_numbering); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h index 05a1ae1169..f348b1a851 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -6,7 +6,8 @@ namespace FlexFlow { -BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &); +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h index 83e53e3d41..0842ffb48f 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -6,7 +6,8 @@ namespace FlexFlow { -BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); +BinarySeriesSplit + binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h index e8c02ebfb5..86fa1a59aa 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -13,13 +13,14 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - PCGBinarySPDecomposition, - PCGBinarySeriesSplit, - PCGBinaryParallelSplit, - 
parallel_layer_guid_t> generic_impl_for_pcg_sp_tree(); - -BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree(); + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); std::optional get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 53b8d5bdd6..0e0f60c891 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -16,9 +16,8 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( std::unordered_set edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) - -> AbstractedSingleTensorMovement - { + auto get_movement_for_tensor = + [&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement { std::unordered_set tensor_edges = filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index bf44ef0fd7..10abd7ff90 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -45,7 +45,8 @@ MachineMappingResult } } - MachineMappingResult result = problem_tree.visit(overload{ + 
MachineMappingResult result = + problem_tree.visit(overload{ [&](MMProblemTreeSeriesSplit const &series_split) { return get_optimal_machine_mapping( result_cache, @@ -86,8 +87,9 @@ MachineMappingResult [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedOpCostEstimateKey leaf = mm_problem_tree_get_subtree_at_path( - MachineMappingProblemTree{series_split}, l) - .value().get(); + MachineMappingProblemTree{series_split}, l) + .value() + .get(); return context.allowed_machine_views(leaf, resources); }); return transform( @@ -130,7 +132,8 @@ MachineMappingResult }; MachineMappingResult result = infeasible_machine_mapping_result(); - AbstractedTensorSetMovement tensor_movement = series_split.tensor_set_movement; + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : @@ -178,9 +181,9 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingResult series_result = [&] { MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ - /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), - /*left_child=*/lhs, - /*right_child=*/rhs, + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, }; return get_optimal_machine_mapping(result_cache, diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index ada271580f..367af3701e 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -20,32 +20,31 @@ MachineMappingProblemTree get_machine_mapping_problem_tree( to_problem_tree = [&](PCGBinarySPDecomposition const &sp) -> 
MachineMappingProblemTree { return sp.visit(overload{ - [&](PCGBinarySeriesSplit const &series) { - AbstractedTensorSetMovement tensor_movement = - get_abstracted_tensor_set_movement_across_split(tr_pcg, - series); - return MachineMappingProblemTree{ - MMProblemTreeSeriesSplit{ - /*tensor_set_movement=*/tensor_movement, - /*lhs=*/to_problem_tree(series.get_left_child()), - /*rhs=*/to_problem_tree(series.get_right_child()), - }, - }; - }, - [&](PCGBinaryParallelSplit const &parallel) { - return MachineMappingProblemTree{ - MMProblemTreeParallelSplit{ - to_problem_tree(parallel.get_left_child()), - to_problem_tree(parallel.get_right_child()), - }, - }; - }, - [&](parallel_layer_guid_t const &leaf) { - return MachineMappingProblemTree{ - get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), - }; - }, - }); + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(series.get_left_child()), + /*rhs=*/to_problem_tree(series.get_right_child()), + }, + }; + }, + [&](PCGBinaryParallelSplit const &parallel) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + to_problem_tree(parallel.get_left_child()), + to_problem_tree(parallel.get_right_child()), + }, + }; + }, + [&](parallel_layer_guid_t const &leaf) { + return MachineMappingProblemTree{ + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), + }; + }, + }); }; return to_problem_tree(sp_decomposition_tree); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index a5b3cab43e..1e39a7be19 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ 
b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -5,46 +5,69 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation - generic_binary_sp_impl_for_mm_problem_tree() { +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree() { return GenericBinarySPDecompositionTreeImplementation< - MachineMappingProblemTree, - MMProblemTreeSeriesSplit, - MMProblemTreeParallelSplit, - UnmappedOpCostEstimateKey>{ - /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeSeriesSplit const & { - return tree.get(); - }, - /*require_parallel=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeParallelSplit const & { - return tree.get(); - }, - /*require_leaf=*/[](MachineMappingProblemTree const &tree) -> UnmappedOpCostEstimateKey const & { - return tree.get(); - }, + MachineMappingProblemTree, + MMProblemTreeSeriesSplit, + MMProblemTreeParallelSplit, + UnmappedOpCostEstimateKey>{ + /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree 
const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeSeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](MachineMappingProblemTree const &tree) + -> UnmappedOpCostEstimateKey const & { + return tree.get(); + }, }; } SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { - return tree.visit(overload { - [](MMProblemTreeSeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](MMProblemTreeParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, - [](UnmappedOpCostEstimateKey const &) { return SPDecompositionTreeNodeType::NODE; }, + return tree.visit(overload{ + [](MMProblemTreeSeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](UnmappedOpCostEstimateKey const &) { + return SPDecompositionTreeNodeType::NODE; + }, }); } @@ -61,7 +84,8 @@ std::unordered_set std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, BinaryTreePath const &path) { - return get_subtree_at_path(tree, generic_binary_sp_impl_for_mm_problem_tree(), path); + return get_subtree_at_path( + tree, generic_binary_sp_impl_for_mm_problem_tree(), path); } } // namespace FlexFlow diff --git 
a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index 004aca6a81..96c8106cad 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -41,7 +41,8 @@ std::unordered_set TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); std::unordered_set raw_edges = get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); @@ -57,7 +58,8 @@ std::unordered_set TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); std::unordered_set raw_outputs = get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); @@ -72,7 +74,8 @@ PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); SplitBoundaryNodes raw_boundary = get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc index f26b899109..32fb53b58a 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc +++ 
b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -10,55 +10,68 @@ namespace FlexFlow { GenericBinarySPDecompositionTreeImplementation< - ComputationGraphBinarySPDecomposition, - ComputationGraphBinarySeriesSplit, - ComputationGraphBinaryParallelSplit, - layer_guid_t> generic_impl_for_computation_graph_sp_tree() { - - return GenericBinarySPDecompositionTreeImplementation< - ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySPDecomposition, ComputationGraphBinarySeriesSplit, ComputationGraphBinaryParallelSplit, - layer_guid_t>{ - /*series_get_left_child=*/[](ComputationGraphBinarySeriesSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](ComputationGraphBinaryParallelSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](ComputationGraphBinarySeriesSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](ComputationGraphBinaryParallelSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](ComputationGraphBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](ComputationGraphBinarySPDecomposition const &tree) -> ComputationGraphBinarySeriesSplit const & { - return tree.get(); - }, - /*require_parallel=*/[](ComputationGraphBinarySPDecomposition const &tree) -> ComputationGraphBinaryParallelSplit const & { - return tree.get(); - }, - /*require_leaf=*/[](ComputationGraphBinarySPDecomposition const &tree) -> layer_guid_t const & { - return tree.get(); - }, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + 
ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t>{ + /*series_get_left_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> SPDecompositionTreeNodeType { return get_node_type(tree); }, + /*require_series=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> layer_guid_t const & { return tree.get(); }, }; } SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &tree) { - return tree.visit(overload { - [](ComputationGraphBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](ComputationGraphBinaryParallelSplit const ¶llel) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](layer_guid_t const &leaf) { - return SPDecompositionTreeNodeType::NODE; - }, + return tree.visit(overload{ + [](ComputationGraphBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + 
[](ComputationGraphBinaryParallelSplit const ¶llel) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](layer_guid_t const &leaf) { + return SPDecompositionTreeNodeType::NODE; + }, }); } @@ -66,30 +79,35 @@ layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &tree) { return tree.get(); } -ComputationGraphBinarySPDecomposition - computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &bin) { - return bin.visit(overload { - [](BinarySeriesSplit const &series) { - return ComputationGraphBinarySPDecomposition{ - ComputationGraphBinarySeriesSplit{ - computation_graph_sp_decomp_from_binary_sp_decomp(series.get_left_child()), - computation_graph_sp_decomp_from_binary_sp_decomp(series.get_right_child()), - }, - }; - }, - [](BinaryParallelSplit const ¶llel) { - return ComputationGraphBinarySPDecomposition{ - ComputationGraphBinaryParallelSplit{ - computation_graph_sp_decomp_from_binary_sp_decomp(parallel.get_left_child()), - computation_graph_sp_decomp_from_binary_sp_decomp(parallel.get_right_child()), - }, - }; - }, - [](Node const &node) { - return ComputationGraphBinarySPDecomposition{ - layer_guid_t{node}, - }; - }, +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &bin) { + return bin.visit(overload{ + [](BinarySeriesSplit const &series) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinarySeriesSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinaryParallelSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) { + return 
ComputationGraphBinarySPDecomposition{ + layer_guid_t{node}, + }; + }, }); } @@ -130,11 +148,13 @@ std::optional } bool is_left_associative(ComputationGraphBinarySPDecomposition const &tree) { - return is_binary_sp_tree_left_associative(tree, generic_impl_for_computation_graph_sp_tree()); + return is_binary_sp_tree_left_associative( + tree, generic_impl_for_computation_graph_sp_tree()); } bool is_right_associative(ComputationGraphBinarySPDecomposition const &tree) { - return is_binary_sp_tree_right_associative(tree, generic_impl_for_computation_graph_sp_tree()); + return is_binary_sp_tree_right_associative( + tree, generic_impl_for_computation_graph_sp_tree()); } std::unordered_multiset @@ -142,31 +162,31 @@ std::unordered_multiset return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); } -V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &tree, - bidict const &layer_numbering) { - return tree.visit(overload { - [&](ComputationGraphBinarySeriesSplit const &series) { - return V1BinarySPDecomposition{ - V1BinarySeriesSplit{ - to_v1(series.get_left_child(), layer_numbering), - to_v1(series.get_right_child(), layer_numbering), - }, - }; - }, - [&](ComputationGraphBinaryParallelSplit const ¶llel) { - return V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - to_v1(parallel.get_left_child(), layer_numbering), - to_v1(parallel.get_right_child(), layer_numbering), - }, - }; - }, - [&](layer_guid_t const &layer) { - return V1BinarySPDecomposition{ - layer_numbering.at_r(layer), - }; - } - }); +V1BinarySPDecomposition + to_v1(ComputationGraphBinarySPDecomposition const &tree, + bidict const &layer_numbering) { + return tree.visit( + overload{[&](ComputationGraphBinarySeriesSplit const &series) { + return V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + to_v1(series.get_left_child(), layer_numbering), + to_v1(series.get_right_child(), layer_numbering), + }, + }; + }, + [&](ComputationGraphBinaryParallelSplit const ¶llel) { + return 
V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + to_v1(parallel.get_left_child(), layer_numbering), + to_v1(parallel.get_right_child(), layer_numbering), + }, + }; + }, + [&](layer_guid_t const &layer) { + return V1BinarySPDecomposition{ + layer_numbering.at_r(layer), + }; + }}); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc index 7e6327d06a..657a3c3166 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc @@ -3,10 +3,11 @@ namespace FlexFlow { -BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &pcg_split) { +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &pcg_split) { return BinaryParallelSplit{ - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), }; } diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc index b0fec5f6ce..304ad224b1 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc @@ -3,10 +3,11 @@ namespace FlexFlow { -BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &pcg_split) { +BinarySeriesSplit binary_series_split_from_pcg_series_split( + PCGBinarySeriesSplit const &pcg_split) { return BinarySeriesSplit{ - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + 
binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), }; } diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 0555c2a14d..5eb993c6ef 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -1,70 +1,83 @@ #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/overload.h" namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - PCGBinarySPDecomposition, - PCGBinarySeriesSplit, - PCGBinaryParallelSplit, - parallel_layer_guid_t> generic_impl_for_pcg_sp_tree() { +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree() { return GenericBinarySPDecompositionTreeImplementation< - PCGBinarySPDecomposition, - PCGBinarySeriesSplit, - PCGBinaryParallelSplit, - parallel_layer_guid_t>{ - /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](PCGBinaryParallelSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](PCGBinarySeriesSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](PCGBinaryParallelSplit const &split) 
-> PCGBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { - return tree.get(); - }, - /*require_parallel=*/[](PCGBinarySPDecomposition const &tree) -> PCGBinaryParallelSplit const & { - return tree.get(); - }, - /*require_leaf=*/[](PCGBinarySPDecomposition const &tree) -> parallel_layer_guid_t const & { - return tree.get(); - }, + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t>{ + /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](PCGBinarySPDecomposition const &tree) + -> PCGBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](PCGBinarySPDecomposition const &tree) + -> parallel_layer_guid_t const & { + return tree.get(); + }, }; } - -BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { - return pcg_tree.visit(overload { - [](PCGBinarySeriesSplit const &series) -> 
BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - binary_series_split_from_pcg_series_split(series), - }; - }, - [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinaryParallelSplit{ - binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), - }, - }; - }, - [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - layer.raw_graph_node, - }; - }, +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { + return pcg_tree.visit(overload{ + [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_series_split_from_pcg_series_split(series), + }; + }, + [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), + }, + }; + }, + [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + layer.raw_graph_node, + }; + }, }); } @@ -78,11 +91,18 @@ std::unordered_multiset return get_leaves(tree, generic_impl_for_pcg_sp_tree()); } -SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &tree) { - return tree.visit(overload { - [](PCGBinarySeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](PCGBinaryParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, - [](parallel_layer_guid_t const &) { return SPDecompositionTreeNodeType::NODE; }, +SPDecompositionTreeNodeType + get_node_type(PCGBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](PCGBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](PCGBinaryParallelSplit const &) { + return 
SPDecompositionTreeNodeType::PARALLEL; + }, + [](parallel_layer_guid_t const &) { + return SPDecompositionTreeNodeType::NODE; + }, }); } diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index b63ce95ae0..5c8ea1c0f1 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -9,11 +9,13 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_abstracted_tensor_set_movement_across_split") { - auto make_series_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + auto make_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; }; @@ -70,8 +72,8 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_leaf(input1.parallel_layer), - make_leaf(input2.parallel_layer), + make_leaf(input1.parallel_layer), + make_leaf(input2.parallel_layer), }; AbstractedTensorSetMovement result = @@ -94,9 +96,9 @@ TEST_SUITE(FF_TEST_SUITE) { pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_series_split(make_leaf(input.parallel_layer), - 
make_leaf(layer_1.parallel_layer)), - make_leaf(layer_2.parallel_layer), + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_leaf(layer_2.parallel_layer), }; AbstractedTensorSetMovement result = @@ -140,12 +142,11 @@ TEST_SUITE(FF_TEST_SUITE) { {relu_output_attrs}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_series_split( - make_leaf(input.parallel_layer), - make_series_split( - make_leaf(layer_1.parallel_layer), - make_leaf(layer_2.parallel_layer))), - make_leaf(layer_3.parallel_layer), + make_series_split( + make_leaf(input.parallel_layer), + make_series_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_leaf(layer_3.parallel_layer), }; AbstractedTensorSetMovement result = @@ -188,9 +189,9 @@ TEST_SUITE(FF_TEST_SUITE) { PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split(make_leaf(input.parallel_layer), - make_leaf(layer_1.parallel_layer)), + make_leaf(layer_1.parallel_layer)), make_parallel_split(make_leaf(layer_2.parallel_layer), - make_leaf(layer_3.parallel_layer)), + make_leaf(layer_3.parallel_layer)), }; AbstractedTensorSetMovement result = @@ -244,12 +245,10 @@ TEST_SUITE(FF_TEST_SUITE) { PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split( make_leaf(input.parallel_layer), - make_parallel_split( - make_leaf(layer_1.parallel_layer), - make_leaf(layer_2.parallel_layer))), + make_parallel_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), make_parallel_split(make_leaf(layer_3.parallel_layer), - make_leaf(layer_4.parallel_layer)) - }; + make_leaf(layer_4.parallel_layer))}; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 7194fc038c..0a874948e4 100644 --- 
a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -16,28 +16,29 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping") { auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { - return MachineMappingProblemTree{k}; + return MachineMappingProblemTree{k}; }; - auto make_series_split = [](AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - MMProblemTreeSeriesSplit{ - /*tensor_set_movement=*/tensor_set_movement, - /*left_child=*/lhs, - /*right_child=*/rhs, - }, - }; - }; + auto make_series_split = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; auto make_parallel_split = [](MachineMappingProblemTree const &lhs, MachineMappingProblemTree const &rhs) { return MachineMappingProblemTree{ - MMProblemTreeParallelSplit{ - /*left_child=*/lhs, - /*right_child=*/rhs, - }, + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, }; }; @@ -200,7 +201,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("pair of layers in parallel") { MachineMappingProblemTree problem_tree = - make_parallel_split(make_leaf(k1), make_leaf(k2)); + make_parallel_split(make_leaf(k1), make_leaf(k2)); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 09d4af7756..06ab1e5b8c 100644 
--- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -12,24 +12,23 @@ TEST_SUITE(FF_TEST_SUITE) { return PCGBinarySPDecomposition{l}; }; - auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, + auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - PCGBinarySeriesSplit{ - lhs, - rhs, - }, + PCGBinarySeriesSplit{ + lhs, + rhs, + }, }; }; auto pcg_make_parallel = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { - return PCGBinarySPDecomposition{ - PCGBinaryParallelSplit{ - lhs, - rhs, - }, + PCGBinaryParallelSplit{ + lhs, + rhs, + }, }; }; @@ -37,28 +36,29 @@ TEST_SUITE(FF_TEST_SUITE) { return MachineMappingProblemTree{k}; }; - auto mm_problem_tree_make_series = [](AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - MMProblemTreeSeriesSplit{ - tensor_set_movement, - lhs, - rhs, - }, - }; - }; - - auto mm_problem_tree_make_parallel = [](MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { + auto mm_problem_tree_make_series = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + tensor_set_movement, + lhs, + rhs, + }, + }; + }; - return MachineMappingProblemTree{ - MMProblemTreeParallelSplit{ - lhs, - rhs, - }, - }; - }; + auto mm_problem_tree_make_parallel = + [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + lhs, + rhs, + }, + }; + }; ParallelComputationGraph pcg = 
empty_parallel_computation_graph(); @@ -113,7 +113,8 @@ TEST_SUITE(FF_TEST_SUITE) { UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); - PCGBinarySPDecomposition sp_decomposition = PCGBinarySPDecomposition{input_layer}; + PCGBinarySPDecomposition sp_decomposition = + PCGBinarySPDecomposition{input_layer}; MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); @@ -198,9 +199,9 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_parallel( - mm_problem_tree_make_leaf(input1_key), - mm_problem_tree_make_leaf(input2_key)); + MachineMappingProblemTree correct = + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)); CHECK(result == correct); } @@ -240,10 +241,10 @@ TEST_SUITE(FF_TEST_SUITE) { /*output_shapes=*/{ew_op_output_shape}, }; - PCGBinarySPDecomposition sp_decomposition = pcg_make_series( - pcg_make_parallel(pcg_make_leaf(input1_layer), - pcg_make_leaf(input2_layer)), - pcg_make_leaf(ew_op_layer)); + PCGBinarySPDecomposition sp_decomposition = + pcg_make_series(pcg_make_parallel(pcg_make_leaf(input1_layer), + pcg_make_leaf(input2_layer)), + pcg_make_leaf(ew_op_layer)); MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); @@ -278,9 +279,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, }}, /*pre=*/ - mm_problem_tree_make_parallel( - mm_problem_tree_make_leaf(input1_key), - mm_problem_tree_make_leaf(input2_key)), + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), /*post=*/mm_problem_tree_make_leaf(ew_op_key)); CHECK(result == correct); diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h index 62cfd6ec62..a1ca0aceed 100644 --- 
a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H -#include #include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include namespace nlohmann { diff --git a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc index 3adb79eb8f..5341e03c0a 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -1,75 +1,84 @@ #include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" #include "utils/exception.h" -#include "utils/overload.h" #include "utils/fmt/json.h" +#include "utils/overload.h" using namespace ::FlexFlow; namespace nlohmann { -V1BinarySPDecomposition adl_serializer::from_json(json const &j) { +V1BinarySPDecomposition + adl_serializer::from_json(json const &j) { std::string type = j.at("type").get(); if (type == "series") { return V1BinarySPDecomposition{ - j.get(), + j.get(), }; } else if (type == "parallel") { return V1BinarySPDecomposition{ - j.get(), + j.get(), }; } else if (type == "leaf") { return V1BinarySPDecomposition{ - j.at("value").get(), + j.at("value").get(), }; } else { - throw mk_runtime_error(fmt::format("Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" in json object: {}", type, j)); + throw mk_runtime_error(fmt::format( + "Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" " + "in json object: {}", + type, + j)); } } -void adl_serializer::to_json(json &j, V1BinarySPDecomposition const &tree) { - tree.visit(overload { - [&](V1BinarySeriesSplit const &split) { - j = split; - j["type"] = "series"; - return 
std::monostate{}; - }, - [&](V1BinaryParallelSplit const &split) { - j = split; - j["type"] = "parallel"; - return std::monostate{}; - }, - [&](int leaf) { - j["value"] = leaf; - j["type"] = "leaf"; - return std::monostate{}; - }, +void adl_serializer::to_json( + json &j, V1BinarySPDecomposition const &tree) { + tree.visit(overload{ + [&](V1BinarySeriesSplit const &split) { + j = split; + j["type"] = "series"; + return std::monostate{}; + }, + [&](V1BinaryParallelSplit const &split) { + j = split; + j["type"] = "parallel"; + return std::monostate{}; + }, + [&](int leaf) { + j["value"] = leaf; + j["type"] = "leaf"; + return std::monostate{}; + }, }); } -V1BinarySeriesSplit adl_serializer::from_json(json const &j) { +V1BinarySeriesSplit + adl_serializer::from_json(json const &j) { return V1BinarySeriesSplit{ - /*lhs=*/j.at("left_child").get(), - /*rhs=*/j.at("right_child").get(), + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), }; } -void adl_serializer::to_json(json &j, V1BinarySeriesSplit const &series) { +void adl_serializer::to_json( + json &j, V1BinarySeriesSplit const &series) { j["left_child"] = series.get_left_child(); j["right_child"] = series.get_right_child(); } -V1BinaryParallelSplit adl_serializer::from_json(json const &j) { +V1BinaryParallelSplit + adl_serializer::from_json(json const &j) { return V1BinaryParallelSplit{ - /*lhs=*/j.at("left_child").get(), - /*rhs=*/j.at("right_child").get(), + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), }; } -void adl_serializer::to_json(json &j, V1BinaryParallelSplit const &series) { +void adl_serializer::to_json( + json &j, V1BinaryParallelSplit const &series) { j["left_child"] = series.get_left_child(); j["right_child"] = series.get_right_child(); } - -} // namespace FlexFlow +} // namespace nlohmann diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc index 
e9f2573914..9068e14517 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -6,46 +6,46 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("adl_serializer") { V1BinarySPDecomposition example_tree = V1BinarySPDecomposition{ - V1BinarySeriesSplit{ - V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - V1BinarySPDecomposition{2}, - V1BinarySPDecomposition{2}, - }, + V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, }, - V1BinarySPDecomposition{3}, - }, }; nlohmann::json example_json = { - {"type", "series"}, - { - "left_child", + {"type", "series"}, { - {"type", "parallel"}, - { "left_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, }, - }, - { + }, + { "right_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "leaf"}, + {"value", 3}, }, - }, }, - }, - { - "right_child", - { - {"type", "leaf"}, - {"value", 3}, - }, - }, }; SUBCASE("to_json") { @@ -56,7 +56,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("from_json") { - V1BinarySPDecomposition result = example_json.get(); + V1BinarySPDecomposition result = + example_json.get(); V1BinarySPDecomposition correct = example_tree; CHECK(result == correct); @@ -65,43 +66,43 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("adl_serializer") { V1BinarySeriesSplit example_split = V1BinarySeriesSplit{ - V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - V1BinarySPDecomposition{2}, - V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, }, - }, - V1BinarySPDecomposition{3}, + V1BinarySPDecomposition{3}, }; nlohmann::json example_json = { - { 
- "left_child", { - {"type", "parallel"}, - { "left_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, }, - }, - { + }, + { "right_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "leaf"}, + {"value", 3}, }, - }, - }, - }, - { - "right_child", - { - {"type", "leaf"}, - {"value", 3}, }, - }, }; SUBCASE("to_json") { @@ -121,43 +122,43 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("adl_serializer") { V1BinaryParallelSplit example_split = V1BinaryParallelSplit{ - V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - V1BinarySPDecomposition{2}, - V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, }, - }, - V1BinarySPDecomposition{3}, + V1BinarySPDecomposition{3}, }; nlohmann::json example_json = { - { - "left_child", { - {"type", "parallel"}, - { "left_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, }, - }, - { + }, + { "right_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "leaf"}, + {"value", 3}, }, - }, - }, - }, - { - "right_child", - { - {"type", "leaf"}, - {"value", 3}, }, - }, }; SUBCASE("to_json") { diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h index 4831afa408..1635747612 100644 --- a/lib/utils/include/utils/archetypes/value_type.h +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -10,14 +10,26 @@ template struct value_type { value_type() = delete; - value_type(value_type const &) { assert(false); } - value_type &operator=(value_type const &) { assert(false); } - - value_type(value_type &&) { assert(false); } - value_type &operator=(value_type &&) { assert(false); } - - 
bool operator==(value_type const &) const { assert(false); } - bool operator!=(value_type const &) const { assert(false); } + value_type(value_type const &) { + assert(false); + } + value_type &operator=(value_type const &) { + assert(false); + } + + value_type(value_type &&) { + assert(false); + } + value_type &operator=(value_type &&) { + assert(false); + } + + bool operator==(value_type const &) const { + assert(false); + } + bool operator!=(value_type const &) const { + assert(false); + } }; } // namespace FlexFlow @@ -27,10 +39,10 @@ namespace std { template struct hash<::FlexFlow::value_type> { size_t operator()(::FlexFlow::value_type const &) const { - assert (false); + assert(false); }; }; -} +} // namespace std #endif diff --git a/lib/utils/include/utils/fmt/json.h b/lib/utils/include/utils/fmt/json.h index 15ad0de4e0..c7aa87e3eb 100644 --- a/lib/utils/include/utils/fmt/json.h +++ b/lib/utils/include/utils/fmt/json.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H -#include #include +#include namespace fmt { template struct formatter<::nlohmann::json, Char> : formatter { - template + template auto format(::nlohmann::json const &j, FormatContext &ctx) { - std::ostringstream oss; + std::ostringstream oss; oss << j; return formatter::format(oss.str(), ctx); } diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 07928f7871..9cf5d63210 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -11,27 +11,34 @@ namespace FlexFlow { template -std::unordered_set - find_paths_to_leaf(Tree const &tree, FullBinaryTreeImplementation const &impl, Leaf const &needle) { - auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) -> std::unordered_set { - return set_union( - 
transform(find_paths_to_leaf(impl.get_left_child(parent), impl, needle), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(find_paths_to_leaf(impl.get_right_child(parent), impl, needle), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - }, - [&](Leaf const &leaf) -> std::unordered_set { - if (leaf == needle) { - return {binary_tree_root_path()}; - } else { - return {}; - } - }, +std::unordered_set find_paths_to_leaf( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + Leaf const &needle) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform( + find_paths_to_leaf(impl.get_left_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform( + find_paths_to_leaf(impl.get_right_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + if (leaf == needle) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 20c2eb8b62..822acfe9ee 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -12,24 +12,27 @@ namespace FlexFlow { template -std::unordered_set - get_all_leaf_paths(Tree const &tree, - FullBinaryTreeImplementation const &impl) { - auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) -> std::unordered_set { - return set_union( - transform(get_all_leaf_paths(impl.get_left_child(parent), impl), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - 
transform(get_all_leaf_paths(impl.get_right_child(parent), impl), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - }, - [&](Leaf const &leaf) -> std::unordered_set { - return {binary_tree_root_path()}; - }, +std::unordered_set get_all_leaf_paths( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(get_all_leaf_paths(impl.get_left_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(impl.get_right_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + return {binary_tree_root_path()}; + }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index 5c1e21014d..7517028ec0 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -9,7 +9,7 @@ namespace FlexFlow { template -Tree get_child(Parent const &parent, +Tree get_child(Parent const &parent, FullBinaryTreeImplementation const &impl, BinaryTreePathEntry const &e) { switch (e) { diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index 87633f29a9..8f9d8e919f 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -13,19 +13,17 @@ std::unordered_multiset get_leaves(Tree const &tree, FullBinaryTreeImplementation const &impl) { - auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) - -> std::unordered_multiset - { - return multiset_union(get_leaves(impl.get_left_child(parent), impl), - 
get_leaves(impl.get_right_child(parent), impl)); - }, - [](Leaf const &leaf) - -> std::unordered_multiset - { - return {leaf}; - }, - }; + auto visitor = + FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_multiset { + return multiset_union( + get_leaves(impl.get_left_child(parent), impl), + get_leaves(impl.get_right_child(parent), impl)); + }, + [](Leaf const &leaf) -> std::unordered_multiset { + return {leaf}; + }, + }; return visit(tree, impl, visitor); } diff --git a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h index 69d4e2ea49..922a42242c 100644 --- a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h @@ -7,20 +7,21 @@ namespace FlexFlow { template -int get_num_tree_nodes(Tree const &tree, - FullBinaryTreeImplementation const &impl) { - +int get_num_tree_nodes( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor{ - [&](Parent const &parent) -> int { - return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + get_num_tree_nodes(impl.get_right_child(parent), impl); - }, - [](Leaf const &) -> int { return 1; }, + [&](Parent const &parent) -> int { + return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + + get_num_tree_nodes(impl.get_right_child(parent), impl); + }, + [](Leaf const &) -> int { return 1; }, }; return visit(tree, impl, visitor); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index bbdc74850c..83ce1367b9 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -3,31 +3,29 @@ #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include 
"utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/visit.h" #include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" #include namespace FlexFlow { template -std::optional - get_subtree_at_path(Tree const &tree, - FullBinaryTreeImplementation const &impl, - BinaryTreePath const &p) { +std::optional get_subtree_at_path( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + BinaryTreePath const &p) { if (p == binary_tree_root_path()) { return tree; } auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) -> std::optional { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); - - return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); - }, - [](Leaf const &leaf) -> std::optional { - return std::nullopt; - }, + [&](Parent const &parent) -> std::optional { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); + }, + [](Leaf const &leaf) -> std::optional { return std::nullopt; }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 87aa115c8c..832d39bdff 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H #include "utils/exception.h" -#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" namespace FlexFlow { template -Result visit(Tree const &tree, +Result visit(Tree const &tree, FullBinaryTreeImplementation const &impl, FullBinaryTreeVisitor 
const &visitor) { if (impl.is_leaf(tree)) { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index 28e9beeebd..de48cd17e9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -10,11 +10,11 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> generic_impl_for_binary_sp_tree(); +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree(); bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h index 9eaf84149f..105f5490a4 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -7,12 +7,15 @@ namespace FlexFlow { template -std::unordered_set - find_paths_to_leaf(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl, - Leaf const &needle) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); +std::unordered_set find_paths_to_leaf( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + Leaf 
const &needle) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return find_paths_to_leaf(tree, full_binary_impl, needle); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h index fd29b69567..0bddbee81c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h @@ -1,59 +1,69 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/exception.h" #include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" -#include +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/overload.h" -#include "utils/exception.h" +#include namespace FlexFlow { template -FullBinaryTreeImplementation, Leaf> - get_full_binary_impl_from_generic_sp_impl(GenericBinarySPDecompositionTreeImplementation const &impl) { 
+FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &impl) { using Parent = std::variant; auto full_binary_impl = FullBinaryTreeImplementation{ - /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { - return std::visit(overload { - [&](Series const &series) -> Tree const & { - return impl.series_get_left_child(series); - }, - [&](Parallel const ¶llel) -> Tree const & { - return impl.parallel_get_left_child(parallel); - }, - }, parent); - }, - /*get_right_child=*/[impl](Parent const &parent) -> Tree const & { - return std::visit(overload { - [&](Series const &series) -> Tree const & { - return impl.series_get_right_child(series); - }, - [&](Parallel const ¶llel) -> Tree const & { - return impl.parallel_get_right_child(parallel); - }, - }, parent); - }, - /*is_leaf=*/[impl](Tree const &tree) -> bool { - return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; - }, - /*require_leaf=*/[impl](Tree const &tree) -> Leaf const & { - return impl.require_leaf(tree); - }, - /*require_parent=*/[impl](Tree const &tree) -> Parent { - SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); - switch (node_type) { - case SPDecompositionTreeNodeType::SERIES: - return Parent{impl.require_series(tree)}; - case SPDecompositionTreeNodeType::PARALLEL: - return Parent{impl.require_parallel(tree)}; - default: - throw mk_runtime_error(fmt::format("Unexpected SPDecompositionTreeNodeType: {}", node_type)); - } - } - }; + /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_left_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_left_child(parallel); + }, + }, + parent); + }, + /*get_right_child=*/ + [impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { 
+ return impl.series_get_right_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_right_child(parallel); + }, + }, + parent); + }, + /*is_leaf=*/ + [impl](Tree const &tree) -> bool { + return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; + }, + /*require_leaf=*/ + [impl](Tree const &tree) -> Leaf const & { + return impl.require_leaf(tree); + }, + /*require_parent=*/ + [impl](Tree const &tree) -> Parent { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return Parent{impl.require_series(tree)}; + case SPDecompositionTreeNodeType::PARALLEL: + return Parent{impl.require_parallel(tree)}; + default: + throw mk_runtime_error(fmt::format( + "Unexpected SPDecompositionTreeNodeType: {}", node_type)); + } + }}; return full_binary_impl; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h index 4637cbd81c..b0bb8355db 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -9,10 +9,13 @@ namespace FlexFlow { template std::unordered_set get_all_leaf_paths( Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { + GenericBinarySPDecompositionTreeImplementation const &impl) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_all_leaf_paths(tree, full_binary_impl); } diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index 7bbc5cf603..c543375148 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -7,12 +7,15 @@ namespace FlexFlow { template -std::unordered_multiset - get_leaves(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { +std::unordered_multiset get_leaves( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_leaves(tree, full_binary_impl); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index b5fe0d4131..4678e0c0f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -7,11 +7,15 @@ namespace FlexFlow { template -int get_num_tree_nodes(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { +int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { - FullBinaryTreeImplementation, Leaf> - 
full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_num_tree_nodes(tree, full_binary_impl); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h index 8a687d9702..c48185fb7f 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -8,12 +8,15 @@ namespace FlexFlow { template -std::optional - get_subtree_at_path(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl, - BinaryTreePath const &path) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); +std::optional get_subtree_at_path( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + BinaryTreePath const &path) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_subtree_at_path(tree, full_binary_impl, path); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 17ff9c5dd1..68e0a3af32 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ 
b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -9,22 +9,33 @@ namespace FlexFlow { template bool is_binary_sp_tree_left_associative( Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { + GenericBinarySPDecompositionTreeImplementation const &impl) { - auto visitor = GenericBinarySPDecompositionTreeVisitor{ - [&](Series const &split) { - return impl.get_node_type(impl.series_get_right_child(split)) != SPDecompositionTreeNodeType::SERIES && - is_binary_sp_tree_left_associative(impl.series_get_left_child(split), impl) && - is_binary_sp_tree_left_associative(impl.series_get_right_child(split), impl); - }, - [&](Parallel const &split) { - return impl.get_node_type(impl.parallel_get_right_child(split)) != SPDecompositionTreeNodeType::PARALLEL && - is_binary_sp_tree_left_associative(impl.parallel_get_left_child(split), impl) && - is_binary_sp_tree_left_associative(impl.parallel_get_right_child(split), impl); - }, - [&](Leaf const &leaf) { - return true; - }, + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_right_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_left_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_right_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_left_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, }; return visit(tree, impl, visitor); diff --git 
a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index b284ce763e..7042765203 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -9,21 +9,32 @@ namespace FlexFlow { template bool is_binary_sp_tree_right_associative( Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { - auto visitor = GenericBinarySPDecompositionTreeVisitor{ - [&](Series const &split) { - return impl.get_node_type(impl.series_get_left_child(split)) != SPDecompositionTreeNodeType::SERIES && - is_binary_sp_tree_right_associative(impl.series_get_left_child(split), impl) && - is_binary_sp_tree_right_associative(impl.series_get_right_child(split), impl); - }, - [&](Parallel const &split) { - return impl.get_node_type(impl.parallel_get_left_child(split)) != SPDecompositionTreeNodeType::PARALLEL && - is_binary_sp_tree_right_associative(impl.parallel_get_left_child(split), impl) && - is_binary_sp_tree_right_associative(impl.parallel_get_right_child(split), impl); - }, - [&](Leaf const &leaf) { - return true; - }, + GenericBinarySPDecompositionTreeImplementation const &impl) { + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_left_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_right_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const 
&split) { + return impl.get_node_type(impl.parallel_get_left_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_right_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 89bb45f0fb..c06db135b2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -1,17 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" namespace FlexFlow { -template -ReturnType visit( - Tree const &tree, - 
GenericBinarySPDecompositionTreeImplementation const &impl, - GenericBinarySPDecompositionTreeVisitor const &visitor) { +template +ReturnType + visit(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + GenericBinarySPDecompositionTreeVisitor const &visitor) { SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); switch (node_type) { case SPDecompositionTreeNodeType::SERIES: { @@ -27,7 +38,8 @@ ReturnType visit( return result; } default: - throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType value: {}", node_type)); + throw mk_runtime_error(fmt::format( + "Unknown SPDecompositionTreeNodeType value: {}", node_type)); } } diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 98eb913aeb..7374b45a60 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#include "utils/graph/series_parallel/series_split.dtg.h" #include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" namespace FlexFlow { // struct SeriesSplit { // public: // SeriesSplit() = delete; -// explicit SeriesSplit(std::vector> const &); -// explicit SeriesSplit( +// explicit SeriesSplit(std::vector> const +// &); explicit SeriesSplit( // std::initializer_list> const &); // // bool operator==(SeriesSplit const &) const; @@ -71,6 +71,6 @@ namespace FlexFlow { // size_t operator()(::FlexFlow::ParallelSplit const &) const; // }; -} // namespace std +} // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h 
b/lib/utils/include/utils/json/check_is_json_deserializable.h index f72485dcbd..dd5f397c19 100644 --- a/lib/utils/include/utils/json/check_is_json_deserializable.h +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -6,7 +6,7 @@ namespace FlexFlow { #define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ - static_assert(::FlexFlow::is_json_deserializable::value, \ + static_assert(::FlexFlow::is_json_deserializable::value, \ #TYPENAME " should be json deserializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h index f3d1a058f8..dfcb26081d 100644 --- a/lib/utils/include/utils/json/check_is_json_serializable.h +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -5,8 +5,8 @@ namespace FlexFlow { -#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ - static_assert(::FlexFlow::is_json_serializable::value, \ +#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_serializable::value, \ #TYPENAME " should be json serializeable") } // namespace FlexFlow diff --git a/lib/utils/src/utils/archetypes/value_type.cc b/lib/utils/src/utils/archetypes/value_type.cc index 9c197112a1..f7da47d8f9 100644 --- a/lib/utils/src/utils/archetypes/value_type.cc +++ b/lib/utils/src/utils/archetypes/value_type.cc @@ -2,7 +2,6 @@ namespace FlexFlow { -template - struct value_type<0>; +template struct value_type<0>; } // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/json.cc b/lib/utils/src/utils/fmt/json.cc index 783b75973c..49ad57fba7 100644 --- a/lib/utils/src/utils/fmt/json.cc +++ b/lib/utils/src/utils/fmt/json.cc @@ -2,7 +2,6 @@ namespace fmt { -template - struct formatter<::nlohmann::json, char>; +template struct formatter<::nlohmann::json, char>; } diff --git a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc index b3ddab6cbc..47845720ed 100644 --- 
a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc +++ b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc @@ -7,8 +7,9 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - std::unordered_set - find_paths_to_leaf(Tree const &, FullBinaryTreeImplementation const &, Leaf const &); +template std::unordered_set + find_paths_to_leaf(Tree const &, + FullBinaryTreeImplementation const &, + Leaf const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc index cbbffb0b4a..b4d8aa1011 100644 --- a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc +++ b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc @@ -3,9 +3,10 @@ namespace FlexFlow { -template - std::unordered_set - get_all_leaf_paths(value_type<0> const &, - FullBinaryTreeImplementation, value_type<1>, value_type<2>> const &); +template std::unordered_set + get_all_leaf_paths(value_type<0> const &, + FullBinaryTreeImplementation, + value_type<1>, + value_type<2>> const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_child.cc b/lib/utils/src/utils/full_binary_tree/get_child.cc index 3283db398b..19362ae510 100644 --- a/lib/utils/src/utils/full_binary_tree/get_child.cc +++ b/lib/utils/src/utils/full_binary_tree/get_child.cc @@ -7,9 +7,9 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - Tree get_child(Parent const &, - FullBinaryTreeImplementation const &, - BinaryTreePathEntry const &); +template Tree + get_child(Parent const &, + FullBinaryTreeImplementation const &, + BinaryTreePathEntry const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_leaves.cc b/lib/utils/src/utils/full_binary_tree/get_leaves.cc index 18221cd98a..0d7e9106f6 100644 --- a/lib/utils/src/utils/full_binary_tree/get_leaves.cc +++ 
b/lib/utils/src/utils/full_binary_tree/get_leaves.cc @@ -7,9 +7,8 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - std::unordered_multiset - get_leaves(Tree const &, - FullBinaryTreeImplementation const &); +template std::unordered_multiset + get_leaves(Tree const &, + FullBinaryTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc index b651309c32..7a99dd60fa 100644 --- a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc @@ -7,8 +7,7 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - int get_num_tree_nodes(Tree const &, - FullBinaryTreeImplementation const &); +template int get_num_tree_nodes( + Tree const &, FullBinaryTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc index 689237752a..1eea13fedd 100644 --- a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc +++ b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc @@ -7,10 +7,9 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - std::optional - get_subtree_at_path(Tree const &, - FullBinaryTreeImplementation const &, - BinaryTreePath const &); +template std::optional get_subtree_at_path( + Tree const &, + FullBinaryTreeImplementation const &, + BinaryTreePath const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc index c8a36dff66..4a4f7c9302 100644 --- a/lib/utils/src/utils/full_binary_tree/visit.cc +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -2,9 +2,8 @@ namespace FlexFlow { -template - int visit(std::string const 
&, - FullBinaryTreeImplementation const &, - FullBinaryTreeVisitor const &); +template int visit(std::string const &, + FullBinaryTreeImplementation const &, + FullBinaryTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 56718fa71f..62489ff75f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,66 +1,84 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> generic_impl_for_binary_sp_tree() { +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree() { return GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> - { - /*series_get_left_child=*/[](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](BinaryParallelSplit const &split) -> BinarySPDecompositionTree const & { - return 
split.get_left_child(); - }, - /*series_get_right_child=*/[](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](BinaryParallelSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { - return tree.require_series(); - }, - /*require_parallel=*/[](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { - return tree.require_parallel(); - }, - /*require_leaf=*/[](BinarySPDecompositionTree const &tree) -> Node const & { - return tree.require_node(); - }, + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node>{ + /*series_get_left_child=*/[](BinarySeriesSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { + return tree.require_series(); + }, + /*require_parallel=*/ + [](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { + return tree.require_parallel(); + }, + /*require_leaf=*/ + [](BinarySPDecompositionTree const &tree) -> Node const & { + return 
tree.require_node(); + }, }; } bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { - return is_binary_sp_tree_left_associative(tree, generic_impl_for_binary_sp_tree()); + return is_binary_sp_tree_left_associative(tree, + generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tree) { - return is_binary_sp_tree_right_associative(tree, generic_impl_for_binary_sp_tree()); +bool is_binary_sp_tree_right_associative( + BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_right_associative(tree, + generic_impl_for_binary_sp_tree()); } -std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tree) { +std::unordered_multiset + get_leaves(BinarySPDecompositionTree const &tree) { return get_leaves(tree, generic_impl_for_binary_sp_tree()); } -SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tree) { - return tree.visit(overload { - [](BinarySeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](BinaryParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, - [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, +SPDecompositionTreeNodeType + get_node_type(BinarySPDecompositionTree const &tree) { + return tree.visit(overload{ + [](BinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](BinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, }); } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc index e30b9f97a6..07e2c3e3e3 100644 --- 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -8,10 +8,12 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::unordered_set - find_paths_to_leaf(Tree const &, - GenericBinarySPDecompositionTreeImplementation const &, - Leaf const &); +template std::unordered_set find_paths_to_leaf( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + Leaf const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc index bc6b4b1ccf..56a6d0cc85 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc @@ -9,6 +9,10 @@ using Parallel = value_type<2>; using Leaf = value_type<3>; FullBinaryTreeImplementation, Leaf> - get_full_binary_impl_from_generic_sp_impl(GenericBinarySPDecompositionTreeImplementation const &); + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc index 7bc9c4bfe4..71d3f6ac31 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -8,9 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::unordered_set get_all_leaf_paths( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 6c80f4ba9b..3bb90bfa32 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -8,9 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::unordered_multiset - get_leaves(Tree const &, - GenericBinarySPDecompositionTreeImplementation const &); +template std::unordered_multiset + get_leaves(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 89e8deb437..3d166145c1 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -8,8 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - int get_num_tree_nodes(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc index e95284fa5e..d1d8079c0b 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -8,10 +8,12 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::optional - get_subtree_at_path(Tree const &, - GenericBinarySPDecompositionTreeImplementation const &, - BinaryTreePath const &); +template std::optional get_subtree_at_path( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + BinaryTreePath const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc 
b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 2b478edb20..69cbb28582 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -8,9 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - bool is_binary_sp_tree_left_associative( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index e50a861219..584099e33e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -8,9 +8,11 @@ using Series = value_type<2>; using Parallel = value_type<3>; using Leaf = value_type<4>; -template - bool is_binary_sp_tree_right_associative( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git 
a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc index b7175e0e1b..056ae2a8d4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -9,10 +9,16 @@ using Series = value_type<2>; using Parallel = value_type<3>; using Leaf = value_type<4>; -template - ReturnType visit( - Tree const &, - GenericBinarySPDecompositionTreeImplementation const &, - GenericBinarySPDecompositionTreeVisitor const &); +template ReturnType + visit(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + GenericBinarySPDecompositionTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 33ac5f00e9..69b2ebea8e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -23,26 +23,28 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { std::vector children = transform(s.children, from_series_child); - return foldl1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinarySeriesSplit{accum, x}, - }; - }); + return foldl1( + children, + [](BinarySPDecompositionTree 
const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{accum, x}, + }; + }); }; auto from_parallel = [&](ParallelSplit const &s) -> BinarySPDecompositionTree { std::vector children = transform(vector_of(s.get_children()), from_parallel_child); - return foldl1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinaryParallelSplit{accum, x}, - }; - }); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{accum, x}, + }; + }); }; from_parallel_child = [&](std::variant const &v) diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 2477140d71..478d90e0c3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -21,25 +21,27 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( auto from_series = [&](SeriesSplit const &s) { std::vector children = transform(s.children, from_series_child); - return foldr1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinarySeriesSplit{x, accum}, - }; - }); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{x, accum}, + }; + }); }; 
auto from_parallel = [&](ParallelSplit const &s) { std::vector children = transform(vector_of(s.get_children()), from_parallel_child); - return foldr1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinaryParallelSplit{x, accum}, - }; - }); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{x, accum}, + }; + }); }; from_parallel_child = [&](std::variant const &v) diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 84ef2fc106..cd29af59a0 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -33,10 +33,10 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); std::unordered_map - ttsp_edge_to_sp_tree = - map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { return BinarySPDecompositionTree{n}; }); + ttsp_edge_to_sp_tree = map_values( + inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return BinarySPDecompositionTree{n}; }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -47,10 +47,10 @@ std::optional auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinaryParallelSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, 
}; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); @@ -67,10 +67,10 @@ std::optional MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinarySeriesSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 07df693ae1..410a40236d 100644 --- a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -49,27 +49,29 @@ std::variant flatten_ast( std::variant from_binary_sp_tree(BinarySPDecompositionTree const &binary) { - return binary.template visit>(overload{ - [](Node const &n) { return n; }, - [](BinarySeriesSplit const &s) { - return IntermediateSpDecompositionTree{ - SplitType::SERIES, - { - from_binary_sp_tree(s.get_left_child()), - from_binary_sp_tree(s.get_right_child()), - }, - }; - }, - [](BinaryParallelSplit const &p) { - return IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - { - from_binary_sp_tree(p.get_left_child()), - from_binary_sp_tree(p.get_right_child()), - }, - }; - }, - }); + return binary + .template visit>( + overload{ + [](Node const &n) { return n; }, + [](BinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(s.get_left_child()), + from_binary_sp_tree(s.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(p.get_left_child()), + from_binary_sp_tree(p.get_right_child()), 
+ }, + }; + }, + }); } } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc index 9364e02afc..c35789044d 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -10,14 +10,13 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; - + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc index 1b49c7218d..1f8f66b932 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -12,17 +12,17 @@ 
TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_edges_across_split") { DataflowGraph g = DataflowGraph::create(); - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("multiple nodes with edges across") { NodeAddedResult n1_added = g.add_node({}, 1); @@ -82,7 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { get_dataflow_graph_transitive_reduction(g); BinarySeriesSplit split = BinarySeriesSplit{ - make_leaf(n1), + make_leaf(n1), make_leaf(n2), }; diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc index 222e9b20bb..0e77739434 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -10,14 +10,13 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_outputs_across_split") { - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { 
+ auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; - + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 8981312c4b..9ca869b2b0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1,6 +1,6 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include @@ -12,11 +12,11 @@ TEST_SUITE(FF_TEST_SUITE) { Node n2 = Node{2}; Node n3 = Node{3}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); auto generic_get_leaves = [&](BinarySPDecompositionTree const &tree) { return get_leaves(tree, impl); @@ -33,13 +33,12 @@ 
TEST_SUITE(FF_TEST_SUITE) { SUBCASE("series split") { SUBCASE("children are not the same") { - BinarySPDecompositionTree input = - BinarySPDecompositionTree{ + BinarySPDecompositionTree input = BinarySPDecompositionTree{ BinarySeriesSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n2}, + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, }, - }; + }; std::unordered_multiset result = generic_get_leaves(input); std::unordered_multiset correct = {n1, n2}; @@ -48,13 +47,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("children are the same") { - BinarySPDecompositionTree input = - BinarySPDecompositionTree{ + BinarySPDecompositionTree input = BinarySPDecompositionTree{ BinarySeriesSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, }, - }; + }; std::unordered_multiset result = generic_get_leaves(input); std::unordered_multiset correct = {n1, n1}; @@ -66,10 +64,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("parallel split") { SUBCASE("children are not the same") { BinarySPDecompositionTree input = BinarySPDecompositionTree{ - BinaryParallelSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n2}, - }, + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, }; std::unordered_multiset result = generic_get_leaves(input); @@ -80,10 +78,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("children are the same") { BinarySPDecompositionTree input = BinarySPDecompositionTree{ - BinaryParallelSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n1}, - }, + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, }; std::unordered_multiset result = generic_get_leaves(input); @@ -93,29 +91,23 @@ TEST_SUITE(FF_TEST_SUITE) { } } - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = 
[](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("nested") { - BinarySPDecompositionTree input = - make_parallel_split( - make_series_split( - make_leaf(n1), - make_series_split( - make_leaf(n2), - make_leaf(n3))), - make_parallel_split( - make_leaf(n2), - make_leaf(n1))); + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); std::unordered_multiset result = generic_get_leaves(input); std::unordered_multiset correct = {n1, n1, n2, n2, n3}; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index f61ff83bf9..ad7e1c2609 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -11,31 +11,31 @@ TEST_SUITE(FF_TEST_SUITE) { Node n2 = Node{2}; Node n3 = Node{3}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = 
generic_impl_for_binary_sp_tree(); - - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; - auto generic_get_num_tree_nodes = [&](BinarySPDecompositionTree const &tree) { - return get_num_tree_nodes(tree, impl); - }; + auto generic_get_num_tree_nodes = + [&](BinarySPDecompositionTree const &tree) { + return get_num_tree_nodes(tree, impl); + }; SUBCASE("leaf") { - BinarySPDecompositionTree input = - make_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); int result = generic_get_num_tree_nodes(input); int correct = 1; @@ -88,16 +88,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - BinarySPDecompositionTree input = - make_parallel_split( - make_series_split( - make_leaf(n1), - make_series_split( - make_leaf(n2), - make_leaf(n3))), - make_parallel_split( - make_leaf(n2), - make_leaf(n1))); + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); int result = generic_get_num_tree_nodes(input); int correct = 9; diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 05ff0b4aaa..3fae155280 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -12,23 +12,23 @@ TEST_SUITE(FF_TEST_SUITE) { Node n3 = Node{3}; Node n4 = Node{4}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); - - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually left associative") { SUBCASE("just node") { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc 
b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 324008fdca..5b4e26107e 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -12,23 +12,23 @@ TEST_SUITE(FF_TEST_SUITE) { Node n3 = Node{3}; Node n4 = Node{4}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); - - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually right associative") { SUBCASE("just node") { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 20f939a8f0..fee971e5e0 100644 --- 
a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -18,17 +18,17 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; @@ -49,8 +49,7 @@ TEST_SUITE(FF_TEST_SUITE) { left_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_series_split(make_leaf(n1), make_leaf(n2)), - make_leaf(n3)); + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc index 5db50ab2ef..fd540f853f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -14,17 +14,17 @@ 
TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("leaf") { BinarySPDecompositionTree input = make_leaf(n1); @@ -37,8 +37,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative series") { BinarySPDecompositionTree input = make_series_split( - make_series_split(make_leaf(n2), make_leaf(n1)), - make_leaf(n3)); + make_series_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -49,8 +48,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("right associative series") { BinarySPDecompositionTree input = make_series_split( - make_leaf(n2), - make_series_split(make_leaf(n1), make_leaf(n3))); + make_leaf(n2), make_series_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -73,8 +71,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_parallel_split(make_leaf(n2), make_leaf(n1)), - make_leaf(n3)); + make_parallel_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); 
SeriesParallelDecomposition correct = @@ -85,8 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("right associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_leaf(n2), - make_parallel_split(make_leaf(n1), make_leaf(n3))); + make_leaf(n2), make_parallel_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -113,9 +109,9 @@ TEST_SUITE(FF_TEST_SUITE) { make_parallel_split( make_leaf(n1), make_series_split( - make_series_split(make_series_split(make_leaf(n2), - make_leaf(n3)), - make_leaf(n3)), + make_series_split( + make_series_split(make_leaf(n2), make_leaf(n3)), + make_leaf(n3)), make_leaf(n5))), make_series_split(make_leaf(n6), make_leaf(n4))), make_leaf(n5)); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 19b9cfd944..532ff86c90 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,17 +16,17 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; 
}; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; @@ -47,8 +47,7 @@ TEST_SUITE(FF_TEST_SUITE) { right_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_leaf(n1), - make_series_split(make_leaf(n2), make_leaf(n3))); + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); CHECK(result == correct); }