diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index f976d369d5..b54ef25819 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -123,9 +123,9 @@ jobs: run: | test_target.sh substitutions - # - name: Test compiler - # run: | - # test_target.sh compiler + - name: Test compiler + run: | + test_target.sh compiler - name: Test substitution-generator run: | diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml index efaf368bc8..efaf10c255 100644 --- a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -3,24 +3,22 @@ name = "JsonSPModelExport" features = [ "eq", "hash", - "json", "fmt", + "json", ] includes = [ "pcg/file_format/v1/v1_computation_graph.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", ] src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/json.h", ] [[fields]] name = "sp_decomposition" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +type = "::FlexFlow::V1BinarySPDecomposition" [[fields]] name = "computation_graph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 1c2dfd6ea3..64419acce4 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,5 +1,5 @@ -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" @@ -13,7 +13,6 @@ #include "utils/cli/cli_parse.h" #include "utils/cli/cli_parse_result.h" #include "utils/cli/cli_spec.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/get_series_parallel_decomposition.h" @@ -105,9 +104,8 @@ tl::expected to_v1_including_node_numbering(computation_graph); V1ComputationGraph v1_cg = v1_result.first; bidict layer_numbering = v1_result.second; - GenericBinarySPDecompositionTree v1_sp_decomposition = - transform(sp_decomposition.raw_tree, - [&](layer_guid_t const &l) { return layer_numbering.at_r(l); }); + V1BinarySPDecomposition v1_sp_decomposition = + to_v1(sp_decomposition, layer_numbering); return JsonSPModelExport{ v1_sp_decomposition, diff --git a/flake.lock b/flake.lock index 1aad68ae29..87fae7f446 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1722923482, - "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=", + "lastModified": 1728341842, + "narHash": "sha256-XMS52KBSS6z3k2VaiVcHyZQD6b2QUm1wIvTClel4xwg=", "owner": "lockshaw", "repo": "proj", - "rev": "c650b0e52337652ea7190131988c0370e0ee7f25", + "rev": "830fb5b1a0c7087752693990e90bbbf021168dfe", "type": "github" }, "original": { diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h deleted file mode 100644 index 2e4ff8448b..0000000000 --- a/lib/compiler/include/compiler/cost_estimate.h +++ /dev/null @@ -1,61 +0,0 @@ - -#ifndef _FLEXFLOW_COMPILER_COST_ESTIMATE_H -#define _FLEXFLOW_COMPILER_COST_ESTIMATE_H - -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" - -namespace FlexFlow { - -struct ICostEstimator { - virtual float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const = 0; - virtual float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const = 0; - - ICostEstimator() = default; - ICostEstimator(ICostEstimator const &) = delete; - ICostEstimator &operator=(ICostEstimator const &) = delete; - - virtual ~ICostEstimator() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); - -struct CostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const { - return this->implementation_ptr->estimate_cost( - op, inputs, weights, outputs, mv); - } - - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const { - return this->implementation_ptr->estimate_cost(tensor_shape, src, dst); - } - - template - static typename std::enable_if::value, - CostEstimator>::type - create(Args &&...args) { - return CostEstimator(std::make_shared(std::forward(args)...)); - } - -private: - CostEstimator(std::shared_ptr implementation_ptr) - : implementation_ptr(implementation_ptr) {} - std::shared_ptr implementation_ptr; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h new file mode 100644 index 0000000000..65bae0c76a --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/machine_view.dtg.h" +#include + +namespace FlexFlow { + +struct ICostEstimator { + virtual float estimate_cost(OpCostEstimateKey const &) const = 0; + virtual float estimate_cost(TensorSetMovement const &) const = 0; + + ICostEstimator() = default; + ICostEstimator(ICostEstimator const &) = delete; + ICostEstimator &operator=(ICostEstimator const &) = delete; + + virtual ~ICostEstimator() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); + +struct CostEstimator { + float estimate_cost(OpCostEstimateKey const &k) const; + float estimate_cost(TensorSetMovement const &m) const; + + template + static typename std::enable_if::value, + CostEstimator>::type + create(Args &&...args) { + return CostEstimator(std::make_shared(std::forward(args)...)); + } + +private: + CostEstimator(std::shared_ptr implementation_ptr); + +private: + std::shared_ptr implementation_ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..8fd860d00d --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml @@ -0,0 +1,40 @@ +namespace = "FlexFlow" +name = "OpCostEstimateKey" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml new file mode 100644 index 0000000000..70f73ebe51 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "SingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/machine_view.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml new file mode 100644 index 0000000000..3625605239 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::SingleTensorMovement>" diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml new file mode 100644 index 0000000000..22f29cbd59 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_result.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "GraphOptimizeResult" +features = [ ] + +includes = [ + "compiler/machine_mapping/machine_mapping.dtg.h", + "pcg/parallel_computation_graph/parallel_computation_graph.h" +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h new file mode 100644 index 0000000000..2de2321ba6 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H +#define _FLEXFLOW_COMPILER_MCMC_STATE_H + +#include "compiler/graph_optimize_result.dtg.h" + +namespace FlexFlow { + +struct GraphOptimizeState { + GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, + float runtime); + + GraphOptimizeResult graph_optimize_result; + float runtime; + + bool operator==(GraphOptimizeState const &other) const; + bool operator!=(GraphOptimizeState const &other) const; + bool operator<(GraphOptimizeState const &other) const; +}; + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::GraphOptimizeState> { + size_t operator()(::FlexFlow::GraphOptimizeState const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h deleted file mode 100644 index 75fd369434..0000000000 --- a/lib/compiler/include/compiler/graph_utils.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_GRAPH_UTILS_H -#define _FLEXFLOW_COMPILER_GRAPH_UTILS_H - -#include "compiler/unity_algorithm.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -SeriesParallelDecomposition - get_series_parallel_decomposition(ParallelComputationGraph const &pcg); - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); - -// NOTE(@wmdi): I think we should have the following interfaces in the graph -// library eventually. - -template -void minimize(T &t, T const &v) { - if (v < t) { - t = v; - } -} - -template -void minimize(T &t, T const &v, Compare comp) { - if (comp(v, t)) { - t = v; - } -} - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h deleted file mode 100644 index 3774f2cd52..0000000000 --- a/lib/compiler/include/compiler/machine_mapping.h +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H - -#include "compiler/machine_mapping.dtg.h" -#include "compiler/optimal_cost_state.dtg.h" -#include "cost_estimate.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &, MachineMapping const &); - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - -struct OptimalCostResult { - static OptimalCostResult sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult infinity(); - - float runtime; - req machine_mapping; -}; -FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); - -struct OptimalCostRuntimeCmp { - bool operator()(OptimalCostResult const &, OptimalCostResult const &); -}; - -class OptimalCostCache { -public: - OptimalCostCache() = default; - - std::optional load(OptimalCostState const &) const; - void save(OptimalCostState const &, OptimalCostResult const &); - -private: - std::unordered_map cache; -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs); - -} // namespace FlexFlow - -// namespace std { -// -// template <> -// struct hash> { -// size_t operator()( -// std::unordered_map const &g) -// const; -// }; - -// }; // namespace std - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml new file mode 100644 index 0000000000..449a448706 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "AbstractedSingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h new file mode 100644 index 0000000000..5b7e2f3613 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); + +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &); + +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &, + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml new file mode 100644 index 0000000000..4cf184706b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "AbstractedTensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..8567a7a3e6 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H + +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml new file mode 100644 index 0000000000..e71cfc540f --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "FeasibleMachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", +] + +[[fields]] +name = "runtime" +type = "float" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h new file mode 100644 index 0000000000..990c1c8205 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H + +#include "pcg/machine_specification.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h new file mode 100644 index 0000000000..62da90bfcb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..2aed9a20e4 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml new file mode 100644 index 0000000000..b9a7f9ac59 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "IncludeUnconstrained" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [] + +[[fields]] +name = "raw_bool" +type = "bool" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h new file mode 100644 index 0000000000..06cbbf942d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping.dtg.h" + +namespace FlexFlow { + +MachineMapping combine_disjoint_mappings(MachineMapping const &, + MachineMapping const &); + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml similarity index 50% rename from lib/compiler/include/compiler/machine_mapping.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml index 4c4912a3fd..92517c1110 100644 --- a/lib/compiler/include/compiler/machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml @@ -9,13 +9,16 @@ features = [ "fmt", ] -includes = [ - "utils/graph/node/node.dtg.h", +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", "pcg/machine_view.dtg.h", +] + +src_includes = [ "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/fmt/unordered_map.h", ] [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h new file mode 100644 index 0000000000..3a0fcf0e15 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" + +namespace FlexFlow { + +MachineMappingCache empty_machine_mapping_cache(); +std::optional + machine_mapping_cache_load(MachineMappingCache const &, + MachineMappingState const &); +void machine_mapping_cache_save(MachineMappingCache &, + MachineMappingState const &, + MachineMappingResult const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml new file mode 100644 index 0000000000..a76ff26eb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "MachineMappingCache" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/machine_mapping_result.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h new file mode 100644 index 0000000000..d314ab493b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H + +#include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &); + +std::unordered_set + get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); + +std::optional + get_machine_view_for_layer(MachineMappingConstraints const &, + BinaryTreePath const &); + +MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &, + BinaryTreePathEntry const &); +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &); +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &); + +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &, + ParallelLayerGuidObliviousMachineMapping const &); + +std::optional require_only_root(MachineMappingConstraints const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml new file mode 100644 index 0000000000..8e13abedb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineMappingConstraints" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "machine_views" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, std::optional<::FlexFlow::MachineView>>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml new file mode 100644 index 0000000000..81e26f491d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingContext" +features = [] + +includes = [ + "compiler/cost_estimator/cost_estimator.h", + "pcg/machine_view.dtg.h", + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "cost_estimator" +type = "::FlexFlow::CostEstimator" + +[[fields]] +name = "allowed_machine_views" +type = "std::function(::FlexFlow::UnmappedOpCostEstimateKey const &, ::FlexFlow::MachineSpecification const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h new file mode 100644 index 0000000000..68d02aaa54 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree + get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h new file mode 100644 index 0000000000..29e9e7c90b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree(); + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); + +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &); +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &); + +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml new file mode 100644 index 0000000000..1949f143cb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::MMProblemTreeSeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::MMProblemTreeParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::UnmappedOpCostEstimateKey" +key = "leaf" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml new file mode 100644 index 0000000000..5247b2006a --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml new file mode 100644 index 0000000000..d4f61bb3f5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", +] + +[[fields]] +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h new file mode 100644 index 0000000000..9fbad4a1d0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &, parallel_layer_guid_t const &); + +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..fe76683eb7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "UnmappedOpCostEstimateKey" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h new file mode 100644 index 0000000000..b21fea5f24 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H + +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" + +namespace FlexFlow { + +[[nodiscard]] MachineMappingResult infeasible_machine_mapping_result(); +[[nodiscard]] bool is_infeasible(MachineMappingResult const &); +FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); + +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &); + +[[nodiscard]] MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &pre_result, + MachineMappingResult const &post_result, + std::optional const + ¶llel_split_transformation); +[[nodiscard]] MachineMappingResult + parallel_combine(MachineMappingResult const &lhs_result, + MachineMappingResult const &rhs_result); + +[[nodiscard]] MachineMappingResult + minimize_runtime(MachineMappingResult const &m1, + MachineMappingResult const &m2); + +[[nodiscard]] MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml new file mode 100644 index 0000000000..92a2873af5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "MachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/feasible_machine_mapping_result.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "raw_result" +type = "std::optional<::FlexFlow::FeasibleMachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml new file mode 100644 index 0000000000..1346f6ebe7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingState" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +[[fields]] +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" + +[[fields]] +name = "resources" +type = "::FlexFlow::MachineSpecification" + +[[fields]] +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h new file mode 100644 index 0000000000..accd96af4c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +ParallelLayerGuidObliviousMachineMapping + restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &); +ParallelLayerGuidObliviousMachineMapping + restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &); + +std::optional + get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml new file mode 100644 index 0000000000..f00fcc8490 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ParallelLayerGuidObliviousMachineMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_mapping" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml new file mode 100644 index 0000000000..8247c0cbdc --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParallelSplitTransformation" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LthenR" + +[[values]] +name = "RthenL" diff --git a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml new file mode 100644 index 0000000000..155e526672 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "PCGSplitBoundaryLayers" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h new file mode 100644 index 0000000000..2b2bc9bf84 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H + +#include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &); + +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &); + +std::unordered_set + pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, + PCGBinarySeriesSplit const &); + +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); + +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml new file mode 100644 index 0000000000..bb76ec2ff7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "TransitiveReducedPCG" +features = [] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h", +] + +[[fields]] +name = "full_pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml deleted file mode 100644 index 036647c0b1..0000000000 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "OptimalCostState" -features = [ - "eq", - # "ord", - "hash", - # "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/series_parallel_decomposition.dtg.h", - "pcg/machine_specification.dtg.h", - "pcg/machine_view.dtg.h", - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "subgraph" -type = "::FlexFlow::SeriesParallelDecomposition" - -[[fields]] -name = "resource" -type = "::FlexFlow::MachineSpecification" - -[[fields]] -name = "given_machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" - -[[fields]] -name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/optimizer_config.struct.toml b/lib/compiler/include/compiler/optimizer_config.struct.toml new file mode 100644 index 0000000000..b7f4f71e9c --- /dev/null +++ b/lib/compiler/include/compiler/optimizer_config.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OptimizerConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "alpha" +type = "float" + +[[fields]] +name = "budget" +type = "int" + +[[fields]] +name = "threshold" +type = "float" + +[[fields]] +name = "max_num_ops" +type = "int" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..9654a2546e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml new file mode 100644 index 0000000000..aa66c80b43 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h similarity index 52% rename from lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h rename to lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h index 3032e3efe9..fdc80a1e37 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -1,19 +1,30 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h" #include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" namespace FlexFlow { +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree(); + SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &); + ComputationGraphBinarySPDecomposition - get_left_child(ComputationGraphBinarySPDecomposition const &); -ComputationGraphBinarySPDecomposition - get_right_child(ComputationGraphBinarySPDecomposition const &); -layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &); + std::optional get_computation_graph_left_assoc_binary_sp_decomposition( ComputationGraph const &); @@ -25,6 +36,9 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &); std::unordered_multiset get_layers(ComputationGraphBinarySPDecomposition const &); +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, + bidict const &layer_numbering); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..452470620b --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::ComputationGraphBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::ComputationGraphBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h similarity index 100% rename from lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h rename to lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml deleted file mode 100644 index 147b1e3acf..0000000000 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphBinarySPDecomposition" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "pcg/layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h new file mode 100644 index 0000000000..d43edaa79d --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h @@ -0,0 +1,11 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h new file mode 100644 index 0000000000..d4ae77541a --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h new file mode 100644 index 0000000000..f348b1a851 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..f7f7026716 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h new file mode 100644 index 0000000000..0842ffb48f --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +BinarySeriesSplit + binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml new file mode 100644 index 0000000000..af2c8c4dae --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h new file mode 100644 index 0000000000..86fa1a59aa --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree(); + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &); + +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &, + parallel_layer_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..52372fb270 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "PCGBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::PCGBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::PCGBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::parallel_layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index abddef37ed..232f2b9563 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -1,39 +1,17 @@ #ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H -#include "compiler/machine_mapping.h" -#include "cost_estimate.h" -#include "machine_mapping.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/graph_optimize_result.dtg.h" +#include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" -namespace FlexFlow { - -struct Strategy { - ParallelComputationGraph pcg; - MachineMapping machine_mapping; - req runtime; - friend bool operator!=(Strategy const &lhs, Strategy const &rhs) { - return (lhs.machine_mapping != rhs.machine_mapping) || - (lhs.runtime != rhs.runtime); - } -}; - -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); -struct StrategyRuntimeCmp { - bool operator()(Strategy const &, Strategy const &); -}; - -struct OptimizerConfig { - float alpha; - int budget; - float threshold; - int max_num_ops; -}; +namespace FlexFlow { -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( diff --git a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc new file mode 100644 index 0000000000..051ffcd190 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc @@ -0,0 +1,16 @@ +#include "compiler/cost_estimator/cost_estimator.h" + +namespace FlexFlow { + +CostEstimator::CostEstimator(std::shared_ptr implementation_ptr) + : implementation_ptr(implementation_ptr) {} + +float CostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->implementation_ptr->estimate_cost(k); +} + +float CostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->implementation_ptr->estimate_cost(m); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc new file mode 100644 index 0000000000..4b4f323ea4 --- /dev/null +++ b/lib/compiler/src/compiler/graph_optimize_state.cc @@ -0,0 +1,85 @@ +#include "compiler/graph_optimize_state.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +GraphOptimizeState::GraphOptimizeState( + GraphOptimizeResult const &graph_optimize_result, float runtime) + : graph_optimize_result(graph_optimize_result), runtime(runtime) {} + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + // Note(@wmdi): This is a hack to implement a partially correct homomorphism + // check. Switch to the homomorphism check used in substitutions right after + // https://github.com/flexflow/FlexFlow/pull/1471 is merged. + auto layers1 = topological_ordering(graph_optimize_result.pcg); + auto layers2 = topological_ordering(other.graph_optimize_result.pcg); + if (layers1.size() != layers2.size()) { + return false; + } + std::unordered_map mapping; + for (size_t i = 0; i < layers1.size(); ++i) { + if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != + get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { + return false; + } + auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); + auto inputs2 = + get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); + if (inputs1.size() != inputs2.size()) { + return false; + } + for (size_t j = 0; j < inputs1.size(); ++j) { + if (inputs1[j] != mapping.at(inputs2[j])) { + return false; + } + } + auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); + auto outputs2 = + get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); + if (outputs1.size() != outputs2.size()) { + return false; + } + for (size_t j = 0; j < outputs1.size(); ++j) { + mapping.emplace(outputs2[j], outputs1[j]); + } + } + return true; +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime < other.runtime; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph hash like + // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash + size_t seed = 0; + auto layers = topological_ordering(state.graph_optimize_result.pcg); + ::FlexFlow::hash_combine(seed, layers.size()); + for (auto layer : layers) { + ::FlexFlow::hash_combine( + seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); + auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); + ::FlexFlow::hash_combine(seed, inputs.size()); + for (auto input : inputs) { + for (size_t i = 0; i < layers.size(); ++i) { + if (get_source_layer(input) == layers[i]) { + ::FlexFlow::hash_combine(seed, i); + break; + } + } + } + } + return seed; +} + +} // namespace std diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc new file mode 100644 index 0000000000..6f3deca138 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -0,0 +1,62 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { + return AbstractedTensorSetMovement{{}}; +} + +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.src_machine_views; + }); +} + +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.dst_machine_views; + }); +} + +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &abstracted, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping mapping = + binary_combine_mappings(/*lhs=*/pre_mapping, + /*rhs=*/post_mapping); + + auto concretize_tensor_movement = + [&](AbstractedSingleTensorMovement const &a) { + return SingleTensorMovement{ + /*parallel_tensor_shape=*/a.parallel_tensor_shape, + /*src_machine_views=*/ + transform( + a.src_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(pre_mapping, path).value(); + }), + /*dst_machine_views=*/ + transform( + a.dst_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(post_mapping, path).value(); + }), + }; + }; + + return TensorSetMovement{ + /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, + concretize_tensor_movement), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..0e0f60c891 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,63 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + std::unordered_set edges_across_split = + pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); + + auto get_movement_for_tensor = + [&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement { + std::unordered_set tensor_edges = + filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { + return get_parallel_tensor(e) == t; + }); + + std::unordered_set src_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_src_layer(e); + }); + + std::unordered_set dst_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_dst_layer(e); + }); + + return AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + /*src_machine_views=*/ + transform(src_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(split.get_left_child(), l)); + }), + /*dst_machine_views=*/ + transform(dst_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(split.get_right_child(), l)); + }), + }; + }; + + std::unordered_map + single_tensor_movements = generate_map( + pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); + + return AbstractedTensorSetMovement{ + values(single_tensor_movements), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..5126d9687e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,32 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "utils/hash/pair.h" + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource) { + std::unordered_set> + result; + + for (int i = 1; i < resource.num_nodes; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_nodes = i; + sub_resource2.num_nodes = resource.num_nodes - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + + for (int i = 1; i < resource.num_gpus_per_node; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_gpus_per_node = i; + sub_resource2.num_gpus_per_node = resource.num_gpus_per_node - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..10abd7ff90 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,254 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/contains.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingState state = MachineMappingState{ + problem_tree, + resources, + constraints, + }; + + { + std::optional cached_result = + machine_mapping_cache_load(result_cache, state); + if (cached_result) { + return cached_result.value(); + } + } + + MachineMappingResult result = + problem_tree.visit(overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + return get_optimal_machine_mapping( + result_cache, + context, + series_split, + resources, + constraints, + /*parallel_split_transformation=*/std::nullopt); + }, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping(result_cache, + context, + decomp_tree_node, + resources, + constraints); + }, + }); + + machine_mapping_cache_save(result_cache, state, result); + return result; +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation) { + + auto get_boundary_machine_view_assignments = + [&](std::unordered_set const &boundary_layers) + -> std::unordered_set { + std::unordered_map> + allowed = generate_map( + boundary_layers, + [&](BinaryTreePath const &l) -> std::unordered_set { + UnmappedOpCostEstimateKey leaf = + mm_problem_tree_get_subtree_at_path( + MachineMappingProblemTree{series_split}, l) + .value() + .get(); + return context.allowed_machine_views(leaf, resources); + }); + return transform( + get_all_assignments(allowed), + [](std::unordered_map const &m) { + return ParallelLayerGuidObliviousMachineMapping{m}; + }); + }; + + auto eval_pre_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views) { + MachineMappingConstraints pre_candidate = with_additional_constraints( + restrict_to_left_child(constraints), assigned_pre_machine_views); + + MachineMappingResult pre_result = + get_optimal_machine_mapping(result_cache, + context, + series_split.get_left_child(), + resources, + pre_candidate); + + return pre_result; + }; + + auto eval_post_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views) { + MachineMappingConstraints post_candidate = with_additional_constraints( + restrict_to_right_child(constraints), assigned_post_machine_views); + + MachineMappingResult post_result = + get_optimal_machine_mapping(result_cache, + context, + series_split.get_right_child(), + resources, + post_candidate); + + return post_result; + }; + + MachineMappingResult result = infeasible_machine_mapping_result(); + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views : + get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + + MachineMappingResult pre_result = + eval_pre_boundary_mapping(assigned_pre_machine_views); + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views : + get_boundary_machine_view_assignments( + get_dst_layers(tensor_movement))) { + + MachineMappingResult post_result = + eval_post_boundary_mapping(assigned_post_machine_views); + + TensorSetMovement comm_across_split = + concretize_abstracted_tensor_set_movement( + tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); + float cost_across_split = + context.cost_estimator.estimate_cost(comm_across_split); + + result = minimize_runtime(result, + series_combine(cost_across_split, + pre_result, + post_result, + parallel_split_transformation)); + } + } + + return result; +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingProblemTree lhs = parallel_split.get_left_child(); + MachineMappingProblemTree rhs = parallel_split.get_right_child(); + + MachineMappingResult series_result = [&] { + MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, + }; + + return get_optimal_machine_mapping(result_cache, + context, + series_split, + resources, + constraints, + ParallelSplitTransformation::LthenR); + }(); + + MachineMappingConstraints left_constraints = + restrict_to_left_child(constraints); + MachineMappingConstraints right_constraints = + restrict_to_right_child(constraints); + + auto evaluate_resource_split = + [&](std::pair const + &resource_split) { + MachineMappingResult left_result = get_optimal_machine_mapping( + result_cache, context, lhs, resource_split.first, left_constraints); + MachineMappingResult right_result = + get_optimal_machine_mapping(result_cache, + context, + rhs, + resource_split.second, + right_constraints); + + return parallel_combine(left_result, right_result); + }; + + std::unordered_set parallel_results = transform( + get_machine_resource_splits(resources), evaluate_resource_split); + + return minimize_runtime(series_result, + get_mapping_with_minimal_runtime(parallel_results)); +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resource, + MachineMappingConstraints const &constraints) { + + std::unordered_set candidates = [&] { + std::optional machine_view = require_only_root(constraints); + if (machine_view.has_value()) { + return std::unordered_set{machine_view.value()}; + } else { + return context.allowed_machine_views(leaf, resource); + } + }(); + + auto get_mapping_result = [&](MachineView const &machine_view) { + OpCostEstimateKey mapped = + map_unmapped_op_cost_estimate_key(leaf, machine_view); + float cost = context.cost_estimator.estimate_cost(mapped); + + return make_singleton_machine_mapping_result(cost, machine_view); + }; + + std::unordered_set candidate_results = + transform(candidates, get_mapping_result); + + return get_mapping_with_minimal_runtime(candidate_results); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..6cc3f4329c --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,26 @@ +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/sum.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + AbstractedTensorSetMovement abstracted = + get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + return concretize_abstracted_tensor_set_movement( + abstracted, pre_mapping, post_mapping); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..6f350d8773 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,18 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "utils/containers.h" +#include "utils/containers/are_disjoint.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" + +namespace FlexFlow { + +MachineMapping combine_disjoint_mappings(MachineMapping const &s1, + MachineMapping const &s2) { + return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; +} + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc new file mode 100644 index 0000000000..fbfccf737f --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -0,0 +1,30 @@ +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/try_at.h" + +namespace FlexFlow { + +MachineMappingCache empty_machine_mapping_cache() { + return MachineMappingCache{{}}; +} + +std::optional + machine_mapping_cache_load(MachineMappingCache const &cache, + MachineMappingState const &k) { + return try_at(cache.raw_map, k); +} + +void machine_mapping_cache_save(MachineMappingCache &cache, + MachineMappingState const &k, + MachineMappingResult const &v) { + if (contains_key(cache.raw_map, k)) { + throw mk_runtime_error( + fmt::format("machine_mapping_cache_save expected key to not already " + "exist, but received existing key {}", + k)); + } + + cache.raw_map.emplace(k, v); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc new file mode 100644 index 0000000000..2cee866a01 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -0,0 +1,112 @@ +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "utils/containers/filter.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/restrict_keys.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &layers) { + return MachineMappingConstraints{ + generate_map(layers, + [](BinaryTreePath const &) -> std::optional { + return std::nullopt; + }), + }; +} + +std::unordered_set + get_all_layers(MachineMappingConstraints const &partial_solution, + IncludeUnconstrained const &include_unconstrained) { + std::unordered_set with_unconstrained = + keys(partial_solution.machine_views); + + if (include_unconstrained.raw_bool) { + return with_unconstrained; + } else { + return filter(with_unconstrained, [&](BinaryTreePath const &l) { + return partial_solution.machine_views.at(l).has_value(); + }); + } +} + +std::optional get_machine_view_for_layer( + MachineMappingConstraints const &partial_solution, + BinaryTreePath const &layer) { + return partial_solution.machine_views.at(layer); +} + +MachineMappingConstraints + restrict_to_child(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { + return MachineMappingConstraints{filtermap_keys( + constraints.machine_views, + [&](BinaryTreePath const &path) -> std::optional { + BinaryTreePathEntry head = binary_tree_path_get_top_level(path); + + if (head == prefix) { + BinaryTreePath rest = binary_tree_path_get_non_top_level(path); + return rest; + } else { + return std::nullopt; + } + })}; +} + +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::LEFT_CHILD); +} + +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::RIGHT_CHILD); +} + +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &constraints, + ParallelLayerGuidObliviousMachineMapping const &additional) { + MachineMappingConstraints result = constraints; + + for (auto const &[layer, machine_view] : additional.raw_mapping) { + std::optional current_machine_view = + result.machine_views.at(layer); + + if (!current_machine_view.has_value()) { + result.machine_views.at(layer) = machine_view; + } else { + if (current_machine_view.value() != machine_view) { + throw mk_runtime_error( + fmt::format("with_additional_layer_machine_views received machine " + "view assignment for layer {} " + "to machine view {}, but that layer is already " + "assigned to machine view {}.", + layer, + machine_view, + current_machine_view.value())); + } + } + } + + return result; +} + +std::optional + require_only_root(MachineMappingConstraints const &constraints) { + if (keys(constraints.machine_views) != + std::unordered_set{binary_tree_root_path()}) { + throw mk_runtime_error( + fmt::format("require_only_root expected constraints to have only a " + "single key (the root path), but received {}", + constraints)); + } + + return constraints.machine_views.at(binary_tree_root_path()); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..367af3701e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -0,0 +1,53 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree( + ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp_decomposition_tree) { + TransitiveReducedPCG tr_pcg = pcg_get_transitive_reduction(pcg); + + std::function + to_problem_tree; + + to_problem_tree = + [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { + return sp.visit(overload{ + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(series.get_left_child()), + /*rhs=*/to_problem_tree(series.get_right_child()), + }, + }; + }, + [&](PCGBinaryParallelSplit const ¶llel) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + to_problem_tree(parallel.get_left_child()), + to_problem_tree(parallel.get_right_child()), + }, + }; + }, + [&](parallel_layer_guid_t const &leaf) { + return MachineMappingProblemTree{ + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), + }; + }, + }); + }; + + return to_problem_tree(sp_decomposition_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..1e39a7be19 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -0,0 +1,91 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree() { + return GenericBinarySPDecompositionTreeImplementation< + MachineMappingProblemTree, + MMProblemTreeSeriesSplit, + MMProblemTreeParallelSplit, + UnmappedOpCostEstimateKey>{ + /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeSeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](MachineMappingProblemTree const &tree) + -> UnmappedOpCostEstimateKey const & { + return tree.get(); + }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(MachineMappingProblemTree const &tree) { + return tree.visit(overload{ + [](MMProblemTreeSeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](UnmappedOpCostEstimateKey const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &tree) { + return get_leaves(tree, generic_binary_sp_impl_for_mm_problem_tree()); +} + +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &tree) { + return get_all_leaf_paths(tree, generic_binary_sp_impl_for_mm_problem_tree()); +} + +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, + BinaryTreePath const &path) { + return get_subtree_at_path( + tree, generic_binary_sp_impl_for_mm_problem_tree(), path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc new file mode 100644 index 0000000000..990b287f8b --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -0,0 +1,36 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { + auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { + return get_parallel_tensor_shape(pcg, t); + }; + + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/pcg_get_op_attrs(pcg, layer), + /*input_shapes=*/ + transform(get_incoming_inputs(pcg, layer), get_tensor_shape), + /*weight_shapes=*/ + transform(get_incoming_weights(pcg, layer), get_tensor_shape), + /*output_shapes=*/ + transform(get_layer_outputs(pcg, layer), get_tensor_shape), + }; +} + +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view) { + return OpCostEstimateKey{ + /*op_attrs=*/unmapped.op_attrs, + /*input_shapes=*/unmapped.input_shapes, + /*weight_shapes=*/unmapped.weight_shapes, + /*output_shapes=*/unmapped.output_shapes, + /*machine_view=*/machine_view, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..3409f7f871 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,138 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingResult infeasible_machine_mapping_result() { + return MachineMappingResult{std::nullopt}; +} + +bool is_infeasible(MachineMappingResult const &result) { + return !result.raw_result.has_value(); +} + +FeasibleMachineMappingResult + require_feasible(MachineMappingResult const &result) { + return result.raw_result.value(); +} + +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &candidates) { + MachineMappingResult result = infeasible_machine_mapping_result(); + + for (MachineMappingResult const &candidate : candidates) { + result = minimize_runtime(result, candidate); + } + + return result; +} + +MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &maybe_pre_result, + MachineMappingResult const &maybe_post_result, + std::optional const + ¶llel_split_transformation) { + FeasibleMachineMappingResult pre_result = ({ + if (is_infeasible(maybe_pre_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_pre_result); + }); + + FeasibleMachineMappingResult post_result = ({ + if (is_infeasible(maybe_post_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_post_result); + }); + + ParallelLayerGuidObliviousMachineMapping mapping = [&] { + if (parallel_split_transformation.has_value() && + parallel_split_transformation.value() == + ParallelSplitTransformation::RthenL) { + return binary_combine_mappings(/*lhs=*/post_result.machine_mapping, + /*rhs=*/pre_result.machine_mapping); + } else { + return binary_combine_mappings(/*lhs=*/pre_result.machine_mapping, + /*rhs=*/post_result.machine_mapping); + } + }(); + + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_result.runtime + comm_cost + post_result.runtime, + /*machine_mapping=*/mapping, + }, + }; +} + +MachineMappingResult + parallel_combine(MachineMappingResult const &maybe_lhs_result, + MachineMappingResult const &maybe_rhs_result) { + FeasibleMachineMappingResult lhs_result = ({ + if (is_infeasible(maybe_lhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_lhs_result); + }); + + FeasibleMachineMappingResult rhs_result = ({ + if (is_infeasible(maybe_rhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_rhs_result); + }); + + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), + /*machine_mapping=*/ + binary_combine_mappings(/*lhs=*/lhs_result.machine_mapping, + /*rhs=*/rhs_result.machine_mapping), + }, + }; +} + +MachineMappingResult minimize_runtime(MachineMappingResult const &maybe_m1, + MachineMappingResult const &maybe_m2) { + FeasibleMachineMappingResult m1 = ({ + if (is_infeasible(maybe_m1)) { + return maybe_m2; + } + require_feasible(maybe_m1); + }); + + FeasibleMachineMappingResult m2 = ({ + if (is_infeasible(maybe_m2)) { + return maybe_m1; + } + require_feasible(maybe_m2); + }); + + if (m2.runtime < m1.runtime) { + return maybe_m2; + } else { + return maybe_m1; + } +} + +MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view) { + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/runtime, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), machine_view}, + }}, + }, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc new file mode 100644 index 0000000000..715a4c2e3d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -0,0 +1,24 @@ +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/try_at.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &lhs, + ParallelLayerGuidObliviousMachineMapping const &rhs) { + return ParallelLayerGuidObliviousMachineMapping{ + merge_maps(map_keys(lhs.raw_mapping, nest_inside_left_child), + map_keys(rhs.raw_mapping, nest_inside_right_child)), + }; +} + +std::optional get_machine_view_for_path( + ParallelLayerGuidObliviousMachineMapping const &mapping, + BinaryTreePath const &path) { + return try_at(mapping.raw_mapping, path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc new file mode 100644 index 0000000000..96c8106cad --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -0,0 +1,93 @@ +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &tr_pcg) { + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, + /*transitive_reduction=*/tr_pcg.transitive_reduction, + }; +} + +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &pcg) { + DiGraphView raw_digraph = pcg.raw_graph; + DiGraphView transitive_reduced = transitive_reduction(raw_digraph); + + return TransitiveReducedPCG{ + /*pcg=*/pcg, + /*transitive_reduction=*/transitive_reduced, + }; +} + +std::unordered_set + pcg_get_transitive_reduced_edges_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + std::unordered_set raw_edges = + get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); + + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + std::unordered_set raw_outputs = + get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + SplitBoundaryNodes raw_boundary = + get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); + + return PCGSplitBoundaryLayers{ + /*pre_split_boundary=*/transform( + raw_boundary.pre_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + /*post_split_boundary=*/ + transform(raw_boundary.post_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc new file mode 100644 index 0000000000..32fb53b58a --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,192 @@ +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t>{ + /*series_get_left_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> SPDecompositionTreeNodeType { return get_node_type(tree); }, + /*require_series=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> layer_guid_t const & { return tree.get(); }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](ComputationGraphBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](ComputationGraphBinaryParallelSplit const ¶llel) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](layer_guid_t const &leaf) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &tree) { + return tree.get(); +} + +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &bin) { + return bin.visit(overload{ + [](BinarySeriesSplit const &series) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinarySeriesSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinaryParallelSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) { + return ComputationGraphBinarySPDecomposition{ + layer_guid_t{node}, + }; + }, + }); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_left_associative( + tree, generic_impl_for_computation_graph_sp_tree()); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_right_associative( + tree, generic_impl_for_computation_graph_sp_tree()); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); +} + +V1BinarySPDecomposition + to_v1(ComputationGraphBinarySPDecomposition const &tree, + bidict const &layer_numbering) { + return tree.visit( + overload{[&](ComputationGraphBinarySeriesSplit const &series) { + return V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + to_v1(series.get_left_child(), layer_numbering), + to_v1(series.get_right_child(), layer_numbering), + }, + }; + }, + [&](ComputationGraphBinaryParallelSplit const ¶llel) { + return V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + to_v1(parallel.get_left_child(), layer_numbering), + to_v1(parallel.get_right_child(), layer_numbering), + }, + }; + }, + [&](layer_guid_t const &layer) { + return V1BinarySPDecomposition{ + layer_numbering.at_r(layer), + }; + }}); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc similarity index 97% rename from lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc rename to lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 184ad93f4d..8f78d423b3 100644 --- a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "op-attrs/computation_graph_op_attrs.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph/computation_graph_edge.h" diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc deleted file mode 100644 index 63054385ac..0000000000 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ /dev/null @@ -1,90 +0,0 @@ -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" - -namespace FlexFlow { - -SPDecompositionTreeNodeType - get_node_type(ComputationGraphBinarySPDecomposition const &d) { - return get_node_type(d.raw_tree); -} - -ComputationGraphBinarySPDecomposition - get_left_child(ComputationGraphBinarySPDecomposition const &d) { - return ComputationGraphBinarySPDecomposition{ - get_left_child(d.raw_tree), - }; -} - -ComputationGraphBinarySPDecomposition - get_right_child(ComputationGraphBinarySPDecomposition const &d) { - return ComputationGraphBinarySPDecomposition{ - get_right_child(d.raw_tree), - }; -} - -layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { - return require_node(d.raw_tree); -} - -std::optional - get_computation_graph_left_assoc_binary_sp_decomposition( - ComputationGraph const &cg) { - SeriesParallelDecomposition sp_decomposition = ({ - std::optional result = - get_computation_graph_series_parallel_decomposition(cg); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - BinarySPDecompositionTree raw_binary_tree = - left_associative_binary_sp_tree_from_nary(sp_decomposition); - - return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; -} - -std::optional - get_computation_graph_right_assoc_binary_sp_decomposition( - ComputationGraph const &cg) { - SeriesParallelDecomposition sp_decomposition = ({ - std::optional result = - get_computation_graph_series_parallel_decomposition(cg); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - BinarySPDecompositionTree raw_binary_tree = - right_associative_binary_sp_tree_from_nary(sp_decomposition); - - return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; -} - -bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { - return is_binary_sp_tree_left_associative(d.raw_tree); -} - -bool is_right_associative(ComputationGraphBinarySPDecomposition const &d) { - return is_binary_sp_tree_right_associative(d.raw_tree); -} - -std::unordered_multiset - get_layers(ComputationGraphBinarySPDecomposition const &d) { - return get_leaves(d.raw_tree); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc new file mode 100644 index 0000000000..220614bb8b --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc @@ -0,0 +1,10 @@ +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc new file mode 100644 index 0000000000..657a3c3166 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc @@ -0,0 +1,14 @@ +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &pcg_split) { + return BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc new file mode 100644 index 0000000000..304ad224b1 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc @@ -0,0 +1,14 @@ +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinarySeriesSplit binary_series_split_from_pcg_series_split( + PCGBinarySeriesSplit const &pcg_split) { + return BinarySeriesSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc new file mode 100644 index 0000000000..5eb993c6ef --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -0,0 +1,115 @@ +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t>{ + /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](PCGBinarySPDecomposition const &tree) + -> PCGBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](PCGBinarySPDecomposition const &tree) + -> parallel_layer_guid_t const & { + return tree.get(); + }, + }; +} + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { + return pcg_tree.visit(overload{ + [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_series_split_from_pcg_series_split(series), + }; + }, + [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), + }, + }; + }, + [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + layer.raw_graph_node, + }; + }, + }); +} + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_pcg_sp_tree()); +} + +SPDecompositionTreeNodeType + get_node_type(PCGBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](PCGBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](PCGBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](parallel_layer_guid_t const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &tree, + parallel_layer_guid_t const &leaf) { + return find_paths_to_leaf(tree, generic_impl_for_pcg_sp_tree(), leaf); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc deleted file mode 100644 index a19c5e8597..0000000000 --- a/lib/compiler/src/graph_utils.cc +++ /dev/null @@ -1,153 +0,0 @@ -#include "compiler/graph_utils.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/containers/without_order.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" -namespace FlexFlow { - -SeriesParallelDecomposition - get_series_parallel_decomposition(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return get_series_parallel_decomposition(pcg.raw_graph); -} - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { - NOT_IMPLEMENTED(); -} - -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return view_output_labelled_as_output_labelled_open(pcg.raw_graph); -} - -// std::vector -// get_sorted_node_input_edges(ParallelComputationGraph const &pcg, -// Node const &n) { -// std::unordered_map> -// incoming_edges = -// get_incoming_edges_by_idx(pcg, n); - -// std::vector result; -// for (auto const &p_id_edge_set : incoming_edges) { -// result.push_back(get_only(p_id_edge_set.second)); -// } - -// return result; -// } - -// std::unordered_map -// infer_tensor_shapes(ParallelComputationGraph const &pcg) { -// std::unordered_map result; -// for (Node const &n : get_topological_ordering(pcg)) { -// PCGOperatorAttrs op = pcg.raw_graph.at(n); - -// std::vector input_tensor_shapes = -// vector_transform([&](MultiDiEdge const &e) { return result.at(e); }, -// get_sorted_node_input_edges(pcg, n)); - -// std::vector output_tensor_shapes = -// get_output_shapes(op, input_tensor_shapes); - -// auto outgoing_edges = get_outgoing_edges_by_idx(pcg, n); - -// int i = 0; - -// for (auto const &[node_port, edges] : outgoing_edges) { -// for (MultiDiEdge const &e : edges) { -// result.insert({e, output_tensor_shapes[i++]}); -// } -// } -// } - -// assert(result.size() == get_edges(pcg.raw_graph).size()); - -// return result; -// } - -/* template */ -/* LabelledOpenMultiDiGraph */ -/* get_subgraph(LabelledOpenMultiDiGraph const &g, */ -/* std::unordered_set const &nodes, */ -/* InputSettings input_settings, */ -/* OutputSettings output_settings) { */ - -/* auto iview = LabelledOpenMultiDiGraphView(g) */ -/* .unsafe(); */ - -/* if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledOpenMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view(subgraph_view); */ -/* } else if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::EXCLUDE_OUTPUTS) { */ -/* LabelledUpwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else if (input_settings == InputSettings::EXCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledDownwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else { */ -/* LabelledMultiDiSubgraphView subgraph_view(*iview, - */ -/* nodes); - */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); - */ -/* } */ -/* } */ - -// struct GetNodes { -// template -// std::unordered_set operator()(T const &t) { -// return get_nodes(t); -// } -// }; - -// std::unordered_set get_nodes(SeriesParallelDecomposition const &sp) { -// return std::visit(GetNodes{}, sp.raw_variant); -// } - -// std::unordered_set get_nodes(SeriesSplit const &serial) { -// return set_union( -// transform(serial.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(ParallelSplit const ¶llel) { -// return set_union( -// transform(parallel.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(Node const &node) { -// return {node}; -// } - -} // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc deleted file mode 100644 index fddd825109..0000000000 --- a/lib/compiler/src/machine_mapping.cc +++ /dev/null @@ -1,366 +0,0 @@ -#include "compiler/machine_mapping.h" -#include "compiler/cost_estimate.h" -#include "compiler/graph_utils.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers.h" -#include "utils/containers/are_disjoint.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/get_only.h" -#include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" -#include "utils/containers/require_no_duplicates.h" -#include "utils/containers/vector_of.h" -#include "utils/exception.h" -#include "utils/graph/graph_split.dtg.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.h" -#include "utils/graph/series_parallel/series_parallel_splits.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &s1, MachineMapping const &s2) { - return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; -} - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); -} - -OptimalCostResult - OptimalCostResult::sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{s1.runtime + s2.runtime, - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult - OptimalCostResult::parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{std::max(s1.runtime, s2.runtime), - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult OptimalCostResult::infinity() { - return {std::numeric_limits::infinity(), - MachineMapping{std::unordered_map{}}}; -} - -bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, - OptimalCostResult const &rhs) { - return lhs.runtime < rhs.runtime; -} - -std::optional - OptimalCostCache::load(OptimalCostState const &state) const { - if (contains_key(cache, state)) { - OptimalCostResult result = cache.at(state); - return std::make_optional(result); - } - return std::nullopt; -} - -void OptimalCostCache::save(OptimalCostState const &state, - OptimalCostResult const &result) { - assert(!contains_key(cache, state)); - cache.emplace(state, result); -} - -std::vector> - get_resource_split(MachineSpecification const &resource) { - std::vector> result; - for (int i = 1; i < resource.num_nodes; ++i) { - MachineSpecification sub_resource1 = resource, sub_resource2 = resource; - sub_resource1.num_nodes = i; - sub_resource2.num_nodes = resource.num_nodes - i; - result.push_back(std::make_pair(sub_resource1, sub_resource2)); - } - return result; -} - -// We may replace this by having unflattened AST -std::pair - decompose(SeriesSplit const &serial) { - if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; - } - SeriesSplit decompn1 = serial; - decompn1.children.pop_back(); - return {SeriesParallelDecomposition(decompn1), - widen(serial.children.back())}; -} - -std::pair - decompose(ParallelSplit const ¶llel) { - if (parallel.children.size() == 2) { - std::vector children = - transform(vector_of(parallel.children), [&](auto const &child) { - return widen(child); - }); - return {children[0], children[1]}; - } - ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); - decompn1.children.erase(child); - return {SeriesParallelDecomposition(decompn1), - widen(child)}; -} - -GraphSplit - get_graph_split(SeriesParallelDecomposition const &pre_decomposition, - SeriesParallelDecomposition const &post_decomposition) { - std::unordered_set pre_nodes = - require_no_duplicates(get_nodes(pre_decomposition)); - std::unordered_set post_nodes = - require_no_duplicates(get_nodes(post_decomposition)); - assert(are_disjoint(pre_nodes, post_nodes)); - return GraphSplit{pre_nodes, post_nodes}; -} - -float estimate_cost(SubParallelComputationGraph const &g, - CostEstimator const &estimator, - MachineMapping const &device_mapping, - std::unordered_map const - &frontier_machine_views) { - // TODO: Consider parallelism - float cost = 0; - // for (Node const &node : get_nodes(g.raw_graph)) { - // std::vector incoming_edges = - // get_incoming_edges(g.raw_graph, node); - // std::vector inputs = - // transform(incoming_edges, - // [&](OpenDataflowEdge const &input_edge) { - // return g.raw_graph.at(input_edge).get_shape(); - // }); - // cost += estimator.estimate_cost( - // g.raw_graph.at(node).op_attrs, inputs, - // device_mapping.machine_views.at(node)); - // } - return cost; -} - -void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { - minimize(m1, m2, OptimalCostRuntimeCmp{}); -} - -struct MachineMappingSearcher { - MachineMappingSearcher( - CostEstimator cost_estimator, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) - : cost_estimator(cost_estimator), - allowed_machine_views(allowed_machine_views), - cached_subgraph_costs(cached_subgraph_costs) {} - - CostEstimator cost_estimator; - std::function(ParallelLayerAttrs const &, - MachineSpecification const &)> - allowed_machine_views; - OptimalCostCache &cached_subgraph_costs; - - struct OptimalCostFunctor { - OptimalCostFunctor( - MachineMappingSearcher *searcher, - SubParallelComputationGraph const &g, - MachineSpecification resource, - std::unordered_map given_machine_views, - std::unordered_map - frontier_machine_views) - : searcher(searcher), g(g), resource(resource), - given_machine_views(given_machine_views), - frontier_machine_views(frontier_machine_views) {} - - MachineMappingSearcher *searcher; - SubParallelComputationGraph const &g; - MachineSpecification resource; - std::unordered_map given_machine_views; - std::unordered_map frontier_machine_views; - - template - OptimalCostResult operator()(T const &t) { - OptimalCostState state{SeriesParallelDecomposition{t}, - resource, - given_machine_views, - frontier_machine_views}; - std::optional cached_result = - searcher->cached_subgraph_costs.load(state); - - if (cached_result) { - return cached_result.value(); - } - OptimalCostResult result = searcher->optimal_cost( - t, g, resource, given_machine_views, frontier_machine_views); - - searcher->cached_subgraph_costs.save(state, result); - return result; - } - }; - - OptimalCostResult - optimal_cost(SubParallelComputationGraph const &g, - MachineSpecification resource, - SeriesParallelDecomposition const &sp_decomposition) { - return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), - sp_decomposition.raw_variant); - } - - OptimalCostResult optimal_cost( - SeriesSplit const &serial, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - NOT_IMPLEMENTED(); - // OptimalCostResult optimal_result = OptimalCostResult::infinity(); - - // auto decomposed = decompose(serial); - // SeriesParallelDecomposition pre_decompn = decomposed.first; - // SeriesParallelDecomposition post_decompn = decomposed.second; - - // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); - // SubParallelComputationGraph pre_graph = - // get_subgraph(g, graph_split.first); - // SubParallelComputationGraph post_graph = - // get_subgraph(g, graph_split.second); - - // std::unordered_set post_graph_sources = - // get_closed_sources(post_graph); - - // assert(post_graph_sources.size() == 1); // assume perfect SP - - // Node split_point = get_only(post_graph_sources); - // OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); - - // for (MachineView const &mv : - // allowed_machine_views(g.raw_graph.at(split_point), resource)) { - // std::unordered_map new_given_machine_views = - // given_machine_views; - // new_given_machine_views.emplace(split_point, mv); - // std::unordered_map - // new_frontier_machine_views = frontier_machine_views; - // new_frontier_machine_views.emplace(split_edge, mv); - // minimize_runtime( - // optimal_result, - // OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // pre_graph, - // resource, - // given_machine_views, - // new_frontier_machine_views), - // pre_decompn.raw_variant), - // std::visit(OptimalCostFunctor(this, - // post_graph, - // resource, - // new_given_machine_views, - // frontier_machine_views), - // post_decompn.raw_variant))); - // } - - // return optimal_result; - } - - OptimalCostResult optimal_cost( - ParallelSplit const ¶llel, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - - NOT_IMPLEMENTED(); - // auto decomposed = decompose(parallel); - // SeriesParallelDecomposition decompn1 = decomposed.first; - // SeriesParallelDecomposition decompn2 = decomposed.second; - - // GraphSplit graph_split = get_graph_split(decompn1, decompn2); - // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), - // g2 = get_subgraph(g, graph_split.second); - - // OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant)); - - // for (auto const &resource_split : get_resource_split(resource)) { - // minimize_runtime( - // optimal_result, - // OptimalCostResult::parallel_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource_split.first, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource_split.second, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant))); - // } - - // return optimal_result; - } - - OptimalCostResult optimal_cost( - Node const &node, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - if (contains_key(given_machine_views, node)) { - assert(contains(allowed_machine_views(g.raw_graph.at(node), resource), - given_machine_views.at(node))); - MachineMapping mv_map{given_machine_views}; - return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}; - } else { - OptimalCostResult optimal_result = OptimalCostResult::infinity(); - for (auto mv : allowed_machine_views(g.raw_graph.at(node), resource)) { - MachineMapping mv_map{{{node, mv}}}; - minimize_runtime( - optimal_result, - {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}); - } - return optimal_result; - } - } -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs) { - SeriesParallelDecomposition sp_decomposition = - get_series_parallel_decomposition(g); - SubParallelComputationGraph subpcg = pcg_to_subpcg(g); - MachineMappingSearcher searcher( - cost_estimator, allowed_machine_views, cached_subgraph_costs); - return searcher.optimal_cost(subpcg, resources, sp_decomposition); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ba6ef28daa..86a211c535 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,20 +1,16 @@ #include "compiler/unity_algorithm.h" -#include "compiler/graph_utils.h" -#include "compiler/machine_mapping.h" +#include "compiler/graph_optimize_state.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { -bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { - return lhs.runtime < rhs.runtime; -} - /* * Gets all substitutions applicable to a PCG */ -std::unordered_set +std::vector get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } @@ -22,14 +18,14 @@ std::unordered_set /* * Applies a substitution to all possible positions in PCG */ -std::unordered_set +std::vector apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { NOT_IMPLEMENTED(); } -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( @@ -37,58 +33,61 @@ Strategy graph_optimize( &allowed_machine_views, OptimizerConfig const &opt_config) { NOT_IMPLEMENTED(); - // ParallelComputationGraph pcg = cg_to_pcg(cg); - - // std::unordered_set subs = - // get_all_applicable_substitutions(pcg); - - // OptimalCostCache cached_subgraph_costs; - // DeduplicatedPriorityQueue, - // StrategyRuntimeCmp> - // candidates; - - // OptimalCostResult initial_pcg_result = optimal_cost(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy initial_result{ - // pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; - - // Strategy best_result = initial_result; - // candidates.push(initial_result); + // std::vector substitutions = + // get_all_applicable_substitutions(pcg); + // + // MachineMappingCache cached_subgraph_costs; + // DeduplicatedPriorityQueue candidates; + // + // MachineMappingResult original_pcg_cost = + // get_optimal_machine_mapping(pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // + // GraphOptimizeState initial_state = { + // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), + // original_pcg_cost.runtime}; + // + // GraphOptimizeState best_state = initial_state; + // candidates.push(initial_state); + // // for (int iteration = 0; !candidates.empty() && iteration < // opt_config.budget; // ++iteration) { - // Strategy const ¤t_result = candidates.top(); + // GraphOptimizeState current_state = candidates.top(); // candidates.pop(); - - // if (current_result.runtime < best_result.runtime) { - // best_result = current_result; - // } else if (current_result.runtime > - // best_result.runtime * opt_config.alpha) { + // + // if (current_state.runtime < best_state.runtime) { + // best_state = current_state; + // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) + // { // continue; // } - - // for (auto const &sub : subs) { - // for (auto const &new_pcg : apply_substitution(current_result.pcg, sub)) - // { - // OptimalCostResult c = optimal_cost(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; - // if (new_result.runtime <= opt_config.threshold && + // + // for (Substitution const &substitution : substitutions) { + // for (ParallelComputationGraph const &new_pcg : apply_substitution( + // current_state.graph_optimize_result.pcg, substitution)) { + // MachineMappingResult new_pcg_cost = + // get_optimal_machine_mapping(new_pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // GraphOptimizeState new_state{ + // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), + // new_pcg_cost.runtime}; + // if (new_pcg_cost.runtime <= opt_config.threshold && // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_result); + // candidates.push(new_state); // } // } // } // } - // return best_result; + // return best_state.graph_optimize_result; } } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..5c8ea1c0f1 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,300 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_abstracted_tensor_set_movement_across_split") { + auto make_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAttrs ew_add_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + SUBCASE("no edges across split") { + ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_leaf(input1.parallel_layer), + make_leaf(input2.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{}, + }; + + CHECK(result == correct); + } + + SUBCASE("single edge across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_leaf(layer_2.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_series_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_leaf(layer_3.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_parallel_split(make_leaf(layer_2.parallel_layer), + make_leaf(layer_3.parallel_layer)), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_4 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_parallel_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_parallel_split(make_leaf(layer_3.parallel_layer), + make_leaf(layer_4.parallel_layer))}; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc new file mode 100644 index 0000000000..9ee596af3e --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -0,0 +1,41 @@ +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" + +namespace FlexFlow { + +TestCostEstimator::TestCostEstimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost) + : get_operator_cost(get_operator_cost), + get_communication_cost(get_communication_cost) {} + +float TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->get_operator_cost(k); +} + +float TestCostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->get_communication_cost(m); +} + +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost) { + + return CostEstimator::create(get_operator_cost, + get_communication_cost); +} + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map) { + return make_fake_cost_estimator( + [op_cost_map](OpCostEstimateKey const &k) { return op_cost_map.at(k); }, + [comm_cost_map](TensorSetMovement const &m) { + return comm_cost_map.at(m); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h new file mode 100644 index 0000000000..7c1d06207a --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H +#define _FLEXFLOW_TEST_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +struct TestCostEstimator : public ICostEstimator { + std::function get_operator_cost; + std::function get_communication_cost; + + TestCostEstimator() = delete; + TestCostEstimator(decltype(get_operator_cost) const &get_operator_cost, + decltype(get_communication_cost) + const &get_communication_cost); + + float estimate_cost(OpCostEstimateKey const &) const override; + + float estimate_cost(TensorSetMovement const &) const override; +}; + +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost); + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..499b111f8f --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,235 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_resource_splits") { + auto make_machine_spec = [](int num_nodes, int num_gpus_per_node) { + return MachineSpecification{ + /*num_nodes=*/num_nodes, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/num_gpus_per_node, + /*inter_node_bandwidth=*/1.0, + /*intra_node_bandwidth=*/1.0, + }; + }; + + SUBCASE("returns no splits if no splits are possible") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1); + + std::unordered_set> + result = get_machine_resource_splits(input); + std::unordered_set> + correct = {}; + + CHECK(result == correct); + } + + SUBCASE( + "returns splits in gpu and node dimensions, but not at the same time") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/2); + + std::unordered_set> + result = get_machine_resource_splits(input); + + std::unordered_set> + correct = { + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + + }; + + CHECK(result == correct); + } + + SUBCASE("returns splits in node dimension in powers of two") { + SUBCASE("num_nodes is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/8, + /*num_gpus_per_node=*/1); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_nodes is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + } + + SUBCASE("returns splits in gpu dimension in powers of two") { + SUBCASE("num_gpus_per_node is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/8); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_gpus_per_node is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + } + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..0a874948e4 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,236 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_optimal_machine_mapping") { + auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto make_series_split = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + auto make_parallel_split = [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); + MachineView mv2 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(3)); + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + MachineSpecification split_machine_spec = MachineSpecification{ + /*num_nodes=*/1, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, + MachineSpecification const &resources) { + if (resources == full_machine_spec) { + return std::unordered_set{mv1, mv2}; + } else { + return std::unordered_set{mv2}; + } + }; + + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + ParallelTensorShape tensor_shape1 = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/tensor_shape1, + /*src_machine_views=*/{}, + /*dst_machine_views=*/{}, + }, + }}; + + ParallelLayerGuidObliviousMachineMapping mm1 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + {map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, + {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, + {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, + {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, + }}, + std::unordered_map{{ + {TensorSetMovement{{}}, 0.0}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), + 0.1}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), + 0.2}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), + 0.3}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), + 0.4}, + }}); + + MachineMappingContext context = MachineMappingContext{ + cost_estimator, + allowed_machine_views1, + }; + + MachineMappingCache cache = empty_machine_mapping_cache(); + + SUBCASE("single layer") { + MachineMappingProblemTree problem_tree = make_leaf(k1); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in sequence") { + MachineMappingProblemTree problem_tree = + make_series_split(movement1, make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0 + 2.0 + 0.1, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in parallel") { + MachineMappingProblemTree problem_tree = + make_parallel_split(make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.5, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..82210a138b --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,239 @@ +// #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +// #include "compiler/machine_mapping/transitive_reduced_pcg.h" +// #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +// #include "pcg/machine_view.h" +// #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +// #include +// "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +// #include "utils/containers/get_only.h" +// #include +// #include "./cost_estimator_for_test.h" +// +// using namespace ::FlexFlow; +// +// TEST_SUITE(FF_TEST_SUITE) { +// TEST_CASE("get_tensor_set_movement_across_split") { +// ParallelComputationGraph pcg = empty_parallel_computation_graph(); +// +// ParallelTensorShape input_shape = +// ParallelTensorShape{ +// ParallelTensorDims{ +// FFOrdered{ +// ShardParallelDim{10, 2}, +// ShardParallelDim{12, 1}, +// }, +// ReplicaParallelDimSet{ +// SumDegree{1}, +// DiscardCopyDegree{1}, +// }, +// }, +// DataType::FLOAT, +// }; +// ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); +// +// ParallelLayerAttrs relu_attrs +// = ParallelLayerAttrs{ +// /*op_attrs=*/PCGOperatorAttrs{ +// ElementUnaryAttrs{ +// /*op_type=*/OperatorType::RELU, +// /*scalar=*/std::nullopt, +// }, +// }, +// /*name=*/std::nullopt, +// }; +// +// ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ +// /*shape=*/input_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// /*create_gradients=*/CreateGrad::YES, +// }; +// +// ParallelLayerAddedResult relu_1 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(input.outputs)}, +// {relu_output_attrs}); +// ParallelLayerAddedResult relu_2 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(relu_1.outputs)}, +// {relu_output_attrs}); +// +// MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); +// MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); +// MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); +// MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); +// +// SUBCASE("single edge across split") { +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_leaf_node(relu_1.parallel_layer)), +// make_pcg_leaf_node(relu_2.parallel_layer))); +// +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// }}; +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("does not include edges removed by transitive reduction") { +// +// } +// +// SUBCASE("single tensor, multiple consumers across split") { +// ParallelLayerAddedResult relu_3 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(relu_1.outputs)}, +// {relu_output_attrs}); +// +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_leaf_node(relu_1.parallel_layer)), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_2.parallel_layer), +// make_pcg_leaf_node(relu_3.parallel_layer)))); +// +// SUBCASE("consumers have same view") { +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_3.parallel_layer, post_mv1}, +// }}; +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("consumers have different views") { +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_3.parallel_layer, post_mv2}, +// }}; +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1, post_mv2}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// } +// +// SUBCASE("multiple tensors, multiple consumers across split") { +// ParallelLayerAddedResult relu_3 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(input.outputs)}, +// {relu_output_attrs}); +// +// ParallelLayerAddedResult relu_4 +// = add_parallel_layer(pcg, +// relu_attrs, +// // relu's don't have two inputs, but for the +// purposes of this test it's fine. +// {get_only(relu_1.outputs), +// get_only(relu_3.outputs)}, {relu_output_attrs}); +// +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// {relu_3.parallel_layer, pre_mv2}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_4.parallel_layer, post_mv2}, +// }}; +// +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_1.parallel_layer), +// make_pcg_leaf_node(relu_3.parallel_layer))), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_2.parallel_layer), +// make_pcg_leaf_node(relu_4.parallel_layer)))); +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1, post_mv2}, +// }, +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv2}, +// /*dst_machine_views=*/{post_mv2}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// } +// } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..6b16a54c1f --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,55 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "cost_estimator_for_test.h" +#include "doctest/doctest.h" +#include "pcg/machine_view.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("combine_disjoint_mappings(MachineMapping, MachineMappping)") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_0 = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + }); + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + MachineMapping correct = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + MachineMapping result = + combine_disjoint_mappings(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + TEST_CASE("nodes_are_disjoint(MachineMapping, MachineMappping)") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_0 = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + }); + + SUBCASE("nodes are disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + + bool correct = true; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + SUBCASE("nodes are not disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + bool correct = false; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..06ab1e5b8c --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -0,0 +1,289 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_mapping_problem_tree") { + auto pcg_make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinarySeriesSplit{ + lhs, + rhs, + }, + }; + }; + + auto pcg_make_parallel = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinaryParallelSplit{ + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto mm_problem_tree_make_series = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + tensor_set_movement, + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_parallel = + [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + lhs, + rhs, + }, + }; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + auto make_output_attrs = [](ParallelTensorShape const &shape) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + }; + + auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }; + }; + + PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; + + auto make_input_key = + [&](ParallelTensorShape const ¶llel_tensor_shape) { + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/input_attrs, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{parallel_tensor_shape}, + }; + }; + + SUBCASE("single layer") { + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + + PCGBinarySPDecomposition sp_decomposition = + PCGBinarySPDecomposition{input_layer}; + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree correct = MachineMappingProblemTree{input_key}; + + CHECK(result == correct); + } + + SUBCASE("two layers in series") { + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + parallel_tensor_guid_t input = get_only(input_added.outputs); + + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + + PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }; + ParallelTensorShape relu_output_shape = input_shape; + ParallelLayerAddedResult relu_added = + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {input}, + {make_output_attrs(relu_output_shape)}); + parallel_layer_guid_t relu_layer = relu_added.parallel_layer; + parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + + UnmappedOpCostEstimateKey relu_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/relu_attrs, + /*input_shapes=*/{input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{relu_output_shape}, + }; + + PCGBinarySPDecomposition sp_decomposition = pcg_make_series( + pcg_make_leaf(input_layer), pcg_make_leaf(relu_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{}}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }}, + mm_problem_tree_make_leaf(input_key), + mm_problem_tree_make_leaf(relu_key)); + + CHECK(result == correct); + } + + SUBCASE("two layers in parallel") { + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); + + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); + + PCGBinarySPDecomposition sp_decomposition = pcg_make_parallel( + pcg_make_leaf(input1_layer), pcg_make_leaf(input2_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)); + + CHECK(result == correct); + } + + SUBCASE("multiple tensors across split") { + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); + + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); + + PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }; + ParallelTensorShape ew_op_output_shape = input_shape; + ParallelLayerAddedResult ew_op_added = + add_parallel_layer(pcg, + make_layer_attrs(ew_op_attrs), + {input1_tensor, input2_tensor}, + {make_output_attrs(ew_op_output_shape)}); + parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + UnmappedOpCostEstimateKey ew_op_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/ew_op_attrs, + /*input_shapes=*/{input_shape, input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{ew_op_output_shape}, + }; + + PCGBinarySPDecomposition sp_decomposition = + pcg_make_series(pcg_make_parallel(pcg_make_leaf(input1_layer), + pcg_make_leaf(input2_layer)), + pcg_make_leaf(ew_op_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }}, + /*pre=*/ + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), + /*post=*/mm_problem_tree_make_leaf(ew_op_key)); + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..254d6b2784 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,342 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "pcg/machine_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("series_combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + float pre_cost = 2.0; + MachineMappingResult pre = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + float post_cost = 4.0; + MachineMappingResult post = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + float comm_cost = 3.0; + + SUBCASE("pre is infeasbile") { + MachineMappingResult result = series_combine( + comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("post is infeasbile") { + MachineMappingResult result = series_combine( + comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = + series_combine(comm_cost, + infeasible, + infeasible, + ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult no_parallel_split_transform = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + SUBCASE("parallel_split_transformation = std::nullopt") { + MachineMappingResult result = + series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = LthenR") { + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = RthenL") { + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel_combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + MachineMappingResult lhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult rhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasbile") { + MachineMappingResult result = parallel_combine(infeasible, rhs); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasbile") { + MachineMappingResult result = parallel_combine(lhs, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = parallel_combine(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult result = parallel_combine(lhs, rhs); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("minimize_runtime") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + MachineMappingResult faster = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult slower = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasbile") { + MachineMappingResult result = minimize_runtime(infeasible, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasible") { + MachineMappingResult result = minimize_runtime(slower, infeasible); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = minimize_runtime(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + SUBCASE("lhs is faster") { + MachineMappingResult result = minimize_runtime(faster, slower); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("rhs is faster") { + MachineMappingResult result = minimize_runtime(slower, faster); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("lhs and rhs have the same speed") { + MachineMappingResult result = minimize_runtime(slower, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc similarity index 96% rename from lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc rename to lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 564cffaebe..2b59669aad 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" @@ -89,14 +89,14 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ - ParallelSplit{ + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ input_layer.raw_node, projection_weights_layer.raw_node, bias_weights_layer.raw_node, - }, + }}, operator_layer.raw_node, - }}; + }}}; CHECK(result == correct); } @@ -159,17 +159,17 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ - ParallelSplit{ + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ w1.raw_node, input.raw_node, w2.raw_node, - }, - ParallelSplit{ + }}, + ParallelSplit{{ op1.raw_node, op2.raw_node, - }, - }}; + }}, + }}}; } SUBCASE("SP with or without preprocessing, but preprocessing would SP " @@ -214,16 +214,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{ParallelSplit{ - SeriesSplit{ + SeriesParallelDecomposition{ParallelSplit{{ + SeriesSplit{{ input1.raw_node, op1.raw_node, - }, - SeriesSplit{ + }}, + SeriesSplit{{ input2.raw_node, op2.raw_node, - }, - }}; + }}, + }}}; } SUBCASE("not SP with or without weight nodes") { diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc new file mode 100644 index 0000000000..46177ad420 --- /dev/null +++ b/lib/compiler/test/src/graph_optimize_state.cc @@ -0,0 +1,80 @@ +#include "compiler/graph_optimize_state.h" +#include "doctest/doctest.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("GraphOptimizeState::operator==") { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); + parallel_tensor_guid_t dense0 = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + parallel_tensor_guid_t dense1 = builder.dense(dense0, + 4, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense1"); + + ParallelComputationGraph pcg = builder.pcg; + + // `machine_mapping` is determined by the PCG and the device mapping + // algorithm, and `runtime` is determined by the PCG and the device mapping, + // so their values here do not matter. + std::unordered_map empty_machine_views; + MachineMapping empty_machine_mapping(empty_machine_views); + bool result1 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0); + bool correct1 = true; + CHECK(result1 == correct1); + + ParallelComputationGraphBuilder builder_; + + parallel_tensor_guid_t input0_ = + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); + parallel_tensor_guid_t dense0_ = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + ParallelComputationGraph pcg_ = builder.pcg; + + bool result2 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), 0); + bool correct2 = false; + CHECK(result2 == correct2); + } +} diff --git a/lib/compiler/test/src/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h deleted file mode 100644 index 9417b863e4..0000000000 --- a/lib/compiler/test/src/test_cost_estimator.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H -#define _FLEXFLOW_TEST_COST_ESTIMATOR_H - -#include "compiler/cost_estimate.h" - -namespace FlexFlow { - -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const override { - return 0.1; - } - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { - return 0.1; - } -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h deleted file mode 100644 index 9f5a768b27..0000000000 --- a/lib/compiler/test/src/test_generator.h +++ /dev/null @@ -1,174 +0,0 @@ -#ifndef _FLEXFLOW_TEST_GENERATOR_H -#define _FLEXFLOW_TEST_GENERATOR_H - -#include "compiler/machine_mapping.h" -#include "pcg/computation_graph.h" -#include "rapidcheck.h" -#include "substitutions/sub_parallel_computation_graph.h" - -using namespace FlexFlow; - -// Rapidcheck does not work for now -// /* -// Generates computation graphs with trivial layers and tensors, which are -// used for tests focusing on graph structures. -// */ -// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, -// [](Tensor(MultiDiOutput const &)) { -// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; -// })); -// } - -// /* -// Generates parallel computation graphs with trivial layers and tensors, -// which are used for tests focusing on graph structures. -// */ -// ParallelComputationGraph -// test_parallel_computation_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, -// [](Operator(MultiDiOutput const &)) { -// return ParallelTensor(ParallelTensorDims(TensorDims({})), -// DataType::FLOAT); -// })); -// } - -// rc::Gen small_integer_generator() { -// return rc::gen::inRange(1, 4); -// } - -// namespace rc { - -// Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), -// multidigraph_from_sp_decomposition); -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_computataion_graph); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_parallel_computation_graph); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>(gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>( -// gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Serial::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Parallel::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? gen::construct( -// gen::arbitrary()) -// : gen::construct( -// gen::arbitrary()); -// }); -// } -// }; - -// template -// struct Arbitrary { -// static Gen< -// std::enable_if, -// Tag>::value>::type> arbitrary() { -// return gen::construct(gen::arbitrary()); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::apply(make_1d_machine_view, -// gen::arbitrary, -// gen::arbitrary, -// small_integer_generator()); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineMapping::machine_views, -// gen::container>( -// gen::arbitrary(), -// gen::arbitrary()))); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, -// 64)), gen::set(&MachineSpecification::num_gpus_per_node, -// gen::inRange(1, 16)), -// gen::set(&MachineSpecification::inter_node_bandwidth, -// gen::nonZero()), -// gen::set(&MachineSpecification::intra_node_bandwidth, -// gen::nonZero())); -// } -// } - -// } // namespace rc - -#endif diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc deleted file mode 100644 index 59fa0f1e5e..0000000000 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ /dev/null @@ -1,132 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// // #include "rapidcheck.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { -// auto g = OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); -// Node n2 = g.add_node(); -// Node n3 = g.add_node(); -// Node n4 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); -// NodePort p2 = g.add_node_port(); -// NodePort p3 = g.add_node_port(); -// NodePort p4 = g.add_node_port(); -// NodePort p5 = g.add_node_port(); -// NodePort p6 = g.add_node_port(); -// NodePort p7 = g.add_node_port(); -// NodePort p8 = g.add_node_port(); -// NodePort p9 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; -// MultiDiEdge e1{n2, p2, n0, p0}; -// MultiDiEdge e2{n3, p5, n1, p3}; -// MultiDiEdge e3{n3, p6, n2, p4}; -// MultiDiEdge e4{n4, p8, n3, p7}; -// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// std::unordered_set node_set0{n3, n4}; - -// auto subgraph0 = get_subgraph(g, node_set0); -// auto subgraph1 = get_subgraph(g, -// node_set0); auto subgraph2 = -// get_subgraph(g, node_set0); -// auto subgraph3 = get_subgraph(g, node_set0); - -// CHECK(bool(get_nodes(subgraph0) == node_set0)); -// CHECK(bool(get_nodes(subgraph1) == node_set0)); -// CHECK(bool(get_nodes(subgraph2) == node_set0)); -// CHECK(bool(get_nodes(subgraph3) == node_set0)); - -// std::unordered_set input_set{split_edge(e2).second, -// split_edge(e3).second}; -// std::unordered_set output_set{e5}; - -// CHECK(bool(get_open_inputs(subgraph0) == input_set)); -// CHECK(bool(get_open_inputs(subgraph1) == input_set)); -// CHECK(bool(get_open_inputs(subgraph2).empty())); -// CHECK(bool(get_open_inputs(subgraph3).empty())); - -// CHECK(bool(get_open_outputs(subgraph0) == output_set)); -// CHECK(bool(get_open_outputs(subgraph1).empty())); -// CHECK(bool(get_open_outputs(subgraph2) == output_set)); -// CHECK(bool(get_open_outputs(subgraph3).empty())); - -// CHECK(bool(get_edges(subgraph0) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4, e5})); -// CHECK(bool(get_edges(subgraph1) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4})); -// CHECK(bool(get_edges(subgraph2) == -// std::unordered_set{e4, e5})); -// CHECK( -// bool(get_edges(subgraph3) == -// std::unordered_set{e4})); - -// CHECK(bool(get_closed_sources(subgraph2) == -// std::unordered_set{n3})); -// } - -// TEST_CASE("view OutputLabelledMultiDiGraph as open") { -// OutputLabelledMultiDiGraph g = -// OutputLabelledMultiDiGraph::create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_output(e0, 2); - -// CHECK(bool(get_edges(g).size() == 1)); - -// OutputLabelledOpenMultiDiGraphView open_graph = -// view_output_labelled_as_output_labelled_open(g); - -// CHECK(bool(open_graph.at(n0) == 0)); -// CHECK(bool(open_graph.at(n1) == 1)); -// CHECK(bool(open_graph.at(e0) == 2)); - -// CHECK(get_edges(open_graph).size() == 1); -// } - -// TEST_CASE("OutputLabelledOpenMultiDiGraph") { -// OutputLabelledOpenMultiDiGraph g = -// OutputLabelledOpenMultiDiGraph::create< -// UnorderedOutputLabelledOpenMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_label(e0, 2); - -// CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); -// CHECK(bool(get_edges(g).size() == 1)); -// } -// } diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc deleted file mode 100644 index 4f9b879574..0000000000 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "doctest/doctest.h" -#include "test_generator.h" - -TEST_SUITE(FF_TEST_SUITE) { - // TEST_CASE("MachineMapping::combine") { - // RC_SUBCASE([](MachineMapping const &m0, MachineMapping const &m1) { - // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - - // MachineMapping comb = MachineMapping::combine(m0, m1); - - // RC_ASSERT(comb.machine_views.size() == - // m0.machine_views.size() + m1.machine_views.size()); - // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); - // }); - // } - - // TEST_CASE("OptimalCostResult::infinity") { - // RC_SUBCASE([](OptimalCostResult const &c) { - // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - // }); - // } -} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc deleted file mode 100644 index e3426aa293..0000000000 --- a/lib/compiler/test/src/test_open_graph.cc +++ /dev/null @@ -1,81 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "utils/graph/algorithms.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_source_sink_open_graph") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// NodePort p0 = g.add_node_port(); -// InputMultiDiEdge e0{ -// n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; -// g.add_edge(e0); - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{})); -// } - -// TEST_CASE("get_source_sink_open_graph:unconnected") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; -// OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; -// g.add_edge(e0); -// g.add_edge(e1); - -// /* -// g: ->n0 -// n1-> -// */ - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); -// } - -// TEST_CASE("get_cut") { -// auto g = OpenMultiDiGraph::create(); - -// std::vector ns = add_nodes(g, 5); - -// MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; -// MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; -// MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; -// OutputMultiDiEdge e5{ -// ns[4], g.add_node_port(), std::make_pair(ns[4].value(), -// ns[4].value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; -// CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, -// e2})); - -// GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; -// CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, -// e4})); -// } -// } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc deleted file mode 100644 index 133558f83a..0000000000 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ /dev/null @@ -1,72 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "test_cost_estimator.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// // Rapidcheck infrastructures for graphs does not work for now -// /* -// Tests whether optimal_cost can give a valid result given random PCG, -// trivial allowed machine views, trivial cost estimator and random machine -// specification. -// */ -// // TEST_CASE("optimal_cost") { -// // auto test_allowed_machine_views = [](Operator const &, -// // MachineSpecification const &) { -// // return std::unordered_set{make_1d_machine_view(0, 1, -// 1)}; -// // }; -// // RC_SUBCASE([](ParallelComputationGraph const &g, -// // MachineSpecification const &machine_spec) { -// // OptimalCostCache cached_subgraph_costs; -// // OptimalCostResult result = optimal_cost(g, -// // test_allowed_machine_views, -// // TestCostEstimator{}, -// // machine_spec, -// // cached_subgraph_costs); -// // RC_ASSERT(result.runtime > 0); -// // RC_ASSERT(keys(result.machine_mapping.machine_views) == -// get_nodes(g)); -// // }); -// // } - -// TEST_CASE("optimal_cost_0") { -// auto pcg = -// OutputLabelledMultiDiGraph::template -// create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); -// Node n1 = pcg.add_node(Operator{ -// LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, -// std::nullopt}, "linear"}); - -// MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; -// pcg.add_edge(e); -// ParallelDim dim = {2, 1, false}; -// ParallelTensorDims dims = {FFOrdered{dim}}; -// pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, -// CreateGrad::YES)); - -// auto test_allowed_machine_views = [](Operator const &, -// MachineSpecification const &) { -// return std::unordered_set{ -// make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; -// }; - -// CostEstimator estimator = CostEstimator::create(); - -// MachineSpecification machine_spec{1, 1, 1, 1, 1}; - -// OptimalCostCache cached_results; - -// OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), -// test_allowed_machine_views, -// estimator, -// machine_spec, -// cached_results); - -// CHECK(bool(result.runtime > 0)); -// } -// } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc similarity index 93% rename from lib/compiler/test/src/test_unity_algorithm.cc rename to lib/compiler/test/src/unity_algorithm.cc index ed5e895a75..8ff0978ea5 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/unity_algorithm.cc @@ -1,7 +1,5 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" -#include "test_cost_estimator.h" -#include "test_generator.h" TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index f523520f9f..5fbcd91a06 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -18,8 +18,8 @@ class GenericTensorAccessorW { if (this->data_type == DT) { return static_cast *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -49,8 +49,8 @@ class GenericTensorAccessorR { if (this->data_type == DT) { return static_cast const *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -97,7 +97,7 @@ typename data_type_enum_to_class
::type * return static_cast *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } @@ -118,7 +118,7 @@ typename data_type_enum_to_class
::type const * return static_cast const *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } diff --git a/lib/kernels/include/kernels/datatype_dispatch.h b/lib/kernels/include/kernels/datatype_dispatch.h index e6ab9fa8cc..e83fc3325d 100644 --- a/lib/kernels/include/kernels/datatype_dispatch.h +++ b/lib/kernels/include/kernels/datatype_dispatch.h @@ -22,7 +22,7 @@ Out dispatch(DataType dt, Args &&...args) { case DataType::BOOL: return F{}(std::forward(args)...); default: - throw mk_runtime_error("Unknown datatype {}", dt); + throw mk_runtime_error(fmt::format("Unknown datatype {}", dt)); } } diff --git a/lib/kernels/include/kernels/linear_kernels.h b/lib/kernels/include/kernels/linear_kernels.h index c761eaf1d9..3128e39fd0 100644 --- a/lib/kernels/include/kernels/linear_kernels.h +++ b/lib/kernels/include/kernels/linear_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h index 3a36e02327..4035aaf7cf 100644 --- a/lib/local-execution/include/local-execution/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -28,10 +28,11 @@ struct DeviceSpecific { T const *get(size_t curr_device_idx) const { if (curr_device_idx != this->device_idx) { - throw mk_runtime_error("Invalid access to DeviceSpecific: attempted " - "device_idx {} != correct device_idx {})", - curr_device_idx, - this->device_idx); + throw mk_runtime_error( + fmt::format("Invalid access to DeviceSpecific: attempted " + "device_idx {} != correct device_idx {})", + curr_device_idx, + this->device_idx)); } return (T const *)this->ptr.get(); } diff --git a/lib/local-execution/include/local-execution/permissions.h b/lib/local-execution/include/local-execution/permissions.h index ce19e38e7e..f34969f233 100644 --- a/lib/local-execution/include/local-execution/permissions.h +++ b/lib/local-execution/include/local-execution/permissions.h @@ -42,8 +42,8 @@ struct formatter<::FlexFlow::Permissions> : formatter { name = "READ_WRITE"; break; default: - throw ::FlexFlow::mk_runtime_error("Unknown permission {}", - static_cast(p)); + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } return formatter::format(name, ctx); } diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 5d0156201e..54eca7e514 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -30,7 +30,7 @@ GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( @@ -49,7 +49,7 @@ VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return variadic_tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } diff --git a/lib/local-execution/src/ops/batch_matmul.h b/lib/local-execution/src/ops/batch_matmul.h index c082dec020..a7e29b1931 100644 --- a/lib/local-execution/src/ops/batch_matmul.h +++ b/lib/local-execution/src/ops/batch_matmul.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/op_task_signature.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_matmul.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/batch_norm.h b/lib/local-execution/src/ops/batch_norm.h index 1f6cceec19..36aa8ffa4e 100644 --- a/lib/local-execution/src/ops/batch_norm.h +++ b/lib/local-execution/src/ops/batch_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/cast.h b/lib/local-execution/src/ops/cast.h index b4a1e91c91..e7af6aca6b 100644 --- a/lib/local-execution/src/ops/cast.h +++ b/lib/local-execution/src/ops/cast.h @@ -17,7 +17,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/combine.h b/lib/local-execution/src/ops/combine.h index c6157a2955..e85e8fba39 100644 --- a/lib/local-execution/src/ops/combine.h +++ b/lib/local-execution/src/ops/combine.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/combine_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/concat.h b/lib/local-execution/src/ops/concat.h index 1f1443f25d..eab70d621c 100644 --- a/lib/local-execution/src/ops/concat.h +++ b/lib/local-execution/src/ops/concat.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/conv_2d.h b/lib/local-execution/src/ops/conv_2d.h index f70d36d514..0358d71eea 100644 --- a/lib/local-execution/src/ops/conv_2d.h +++ b/lib/local-execution/src/ops/conv_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h index 84b67a29c2..a3dc5ff8af 100644 --- a/lib/local-execution/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" #include "local-execution/task_id_t.dtg.h" -#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_binary.h b/lib/local-execution/src/ops/element_binary.h index 05273e34b4..72c0976df8 100644 --- a/lib/local-execution/src/ops/element_binary.h +++ b/lib/local-execution/src/ops/element_binary.h @@ -3,7 +3,7 @@ #include "local-execution/sim_environment.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_unary.h b/lib/local-execution/src/ops/element_unary.h index 4d1783f1f6..04a72e2e12 100644 --- a/lib/local-execution/src/ops/element_unary.h +++ b/lib/local-execution/src/ops/element_unary.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/embedding.h b/lib/local-execution/src/ops/embedding.h index 0463984122..995d2296e1 100644 --- a/lib/local-execution/src/ops/embedding.h +++ b/lib/local-execution/src/ops/embedding.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/flat.h b/lib/local-execution/src/ops/flat.h index d1501f85ca..e019bfc654 100644 --- a/lib/local-execution/src/ops/flat.h +++ b/lib/local-execution/src/ops/flat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_FLAT_H #include "local-execution/sim_environment.h" -#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/flat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/gather.h b/lib/local-execution/src/ops/gather.h index 74db276e35..e339683381 100644 --- a/lib/local-execution/src/ops/gather.h +++ b/lib/local-execution/src/ops/gather.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/gather_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/src/ops/input.h index 97985585e1..baad25b798 100644 --- a/lib/local-execution/src/ops/input.h +++ b/lib/local-execution/src/ops/input.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_INPUT_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" +#include "op-attrs/ops/input_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/layer_norm.h b/lib/local-execution/src/ops/layer_norm.h index 4f8d87153b..8e034ac519 100644 --- a/lib/local-execution/src/ops/layer_norm.h +++ b/lib/local-execution/src/ops/layer_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/linear.h b/lib/local-execution/src/ops/linear.h index 2c76483df4..2aaf13a95a 100644 --- a/lib/local-execution/src/ops/linear.h +++ b/lib/local-execution/src/ops/linear.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/src/ops/noop.h index 959f7dc054..1097adeb5e 100644 --- a/lib/local-execution/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -2,9 +2,7 @@ #define _FLEXFLOW_NOOP_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" -#include "op-attrs/ops/noop.h" -#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/pool_2d.h b/lib/local-execution/src/ops/pool_2d.h index e8624185ac..908fd5462f 100644 --- a/lib/local-execution/src/ops/pool_2d.h +++ b/lib/local-execution/src/ops/pool_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/pool_2d.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduce.h b/lib/local-execution/src/ops/reduce.h index 92f0578757..7900c28159 100644 --- a/lib/local-execution/src/ops/reduce.h +++ b/lib/local-execution/src/ops/reduce.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduce.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduction.h b/lib/local-execution/src/ops/reduction.h index a0af4f3aea..56833602e6 100644 --- a/lib/local-execution/src/ops/reduction.h +++ b/lib/local-execution/src/ops/reduction.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/repartition.h b/lib/local-execution/src/ops/repartition.h index b38a93f8b1..5187d04ca0 100644 --- a/lib/local-execution/src/ops/repartition.h +++ b/lib/local-execution/src/ops/repartition.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/replicate.h b/lib/local-execution/src/ops/replicate.h index 77bda411c1..85d1dff41a 100644 --- a/lib/local-execution/src/ops/replicate.h +++ b/lib/local-execution/src/ops/replicate.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reshape.h b/lib/local-execution/src/ops/reshape.h index 06a6b32597..37f07534ee 100644 --- a/lib/local-execution/src/ops/reshape.h +++ b/lib/local-execution/src/ops/reshape.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reshape.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reverse.h b/lib/local-execution/src/ops/reverse.h index 10072860b0..7c16073be7 100644 --- a/lib/local-execution/src/ops/reverse.h +++ b/lib/local-execution/src/ops/reverse.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reverse.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/softmax.h b/lib/local-execution/src/ops/softmax.h index b5756d92ff..d440fe7239 100644 --- a/lib/local-execution/src/ops/softmax.h +++ b/lib/local-execution/src/ops/softmax.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/softmax.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/split.h b/lib/local-execution/src/ops/split.h index c82152b06a..dde46c20bf 100644 --- a/lib/local-execution/src/ops/split.h +++ b/lib/local-execution/src/ops/split.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/split.h" +#include "op-attrs/ops/split_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/topk.h b/lib/local-execution/src/ops/topk.h index b04d807400..c8f3175ebd 100644 --- a/lib/local-execution/src/ops/topk.h +++ b/lib/local-execution/src/ops/topk.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/topk.h" +#include "op-attrs/ops/topk_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/transpose.h b/lib/local-execution/src/ops/transpose.h index 3feffc7d86..0f3a2e80a0 100644 --- a/lib/local-execution/src/ops/transpose.h +++ b/lib/local-execution/src/ops/transpose.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/transpose.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/permissions.cc b/lib/local-execution/src/permissions.cc index e5c46b42f8..2286215987 100644 --- a/lib/local-execution/src/permissions.cc +++ b/lib/local-execution/src/permissions.cc @@ -33,7 +33,8 @@ static int as_int(Permissions p) { case Permissions::RW: return 2; default: - throw mk_runtime_error("Unknown permission {}", static_cast(p)); + throw mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } } diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h deleted file mode 100644 index 268554b5be..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _OPERATOR_PARAMS_H -#define _OPERATOR_PARAMS_H - -#include "op-attrs/ops/core.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "ops/attention.h" -#include "ops/batch_matmul.h" -#include "ops/batch_norm.h" -#include "ops/broadcast.h" -#include "ops/cast.h" -#include "ops/combine.h" -#include "ops/concat.h" -#include "ops/conv_2d.h" -#include "ops/dropout.h" -#include "ops/element_binary.h" -#include "ops/element_unary.h" -#include "ops/embedding.h" -#include "ops/flat.h" -#include "ops/gather.h" -#include "ops/input.h" -#include "ops/layer_norm.h" -#include "ops/linear.h" -#include "ops/noop.h" -#include "ops/pool_2d.h" -#include "ops/reduce.h" -#include "ops/reduction.h" -#include "ops/repartition.h" -#include "ops/replicate.h" -#include "ops/reshape.h" -#include "ops/reverse.h" -#include "ops/softmax.h" -#include "ops/split.h" -#include "ops/topk.h" -#include "ops/transpose.h" -#include "utils/record_formatter.h" -#include "utils/variant.h" -#include - -namespace FlexFlow { - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes); - -bool is_valid(PCGOperatorAttrs const &, - std::vector const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc deleted file mode 100644 index e6459c6819..0000000000 --- a/lib/op-attrs/src/operator_attrs.cc +++ /dev/null @@ -1,287 +0,0 @@ -#include "op-attrs/operator_attrs.h" -#include "utils/fmt.h" -#include "utils/record_formatter.h" -#include "utils/type_traits.h" - -namespace FlexFlow { - -/* OperatorType GetOpType::operator()(BatchMatmulAttrs const &p) const { return - * OP_BATCHMATMUL; } */ -/* OperatorType GetOpType::operator()(Conv2DAttrs const &p) const { return - * OP_CONV2D; } */ -/* OperatorType GetOpType::operator()(ConcatAttrs const &p) const { return - * OP_CONCAT; } */ -/* OperatorType GetOpType::operator()(CastAttrs const &p) const { return - * OP_CAST; } */ -/* OperatorType GetOpType::operator()(ElementBinaryAttrs const &p) const { - * return p.type; } */ -/* OperatorType GetOpType::operator()(ElementUnaryAttrs const &p) const { return - * p.op_type; } */ -/* OperatorType GetOpType::operator()(DropoutAttrs const &p) const { return - * OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(EmbeddingAttrs const &p) const { return - * OP_EMBEDDING; } */ -/* OperatorType GetOpType::operator()(FlatAttrs const &p) const { return - * OP_FLAT; } */ -/* OperatorType GetOpType::operator()(LayerNormAttrs const &p) const { return - * OP_LAYERNORM; } */ -/* OperatorType GetOpType::operator()(LinearAttrs const &p) const { return - * OP_LINEAR; } */ -/* OperatorType GetOpType::operator()(MultiHeadAttentionAttrs const &p) const { - * return OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(Pool2DAttrs const &p) const { return - * OP_POOL2D; } */ -/* OperatorType GetOpType::operator()(ReshapeAttrs const &p) const { return - * OP_RESHAPE; } */ -/* OperatorType GetOpType::operator()(SplitAttrs const &p) const { return - * OP_SPLIT; } */ -/* OperatorType GetOpType::operator()(SoftmaxAttrs const &p) const { return - * OP_SOFTMAX; } */ -/* OperatorType GetOpType::operator()(TransposeAttrs const &p) const { return - * OP_TRANSPOSE; } */ -/* OperatorType GetOpType::operator()(RepartitionAttrs const &p) const { return - * OP_REPARTITION; } */ -/* OperatorType GetOpType::operator()(ReplicateAttrs const &p) const { return - * OP_REPLICATE; } */ -/* OperatorType GetOpType::operator()(ReductionAttrs const &p) const { return - * OP_REDUCTION; } */ -/* OperatorType GetOpType::operator()(CombineAttrs const &p) const { return - * OP_COMBINE; } */ -/* OperatorType GetOpType::operator()(FusedParallelOpAttrs const &p) const { - * return OP_FUSED_PARALLEL; } */ - -/* struct AsOpAttrs { */ -/* template */ -/* OpAttrsInterface const &operator()(T const &p) { */ -/* return p; */ -/* } */ -/* }; */ - -/* OperatorType get_op_type(OpAttrsInterface const &o) { */ -/* return o.op_type(); */ -/* } */ -/* // */ -/* OperatorType get_op_type(CompGraphOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* OperatorType get_op_type(PCGOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* std::vector get_output_shapes(PCGOperatorAttrs const - * &op_params, std::vector const &input_tensor_shapes) { */ -/* return mpark::visit(AsOpAttrs{}, - * op_params).output_shapes(input_tensor_shapes); */ -/* } */ - -/* bool is_parallel_op(PCGOperatorAttrs const &o) { */ -/* return is_parallel_op(get_op_type(o)); */ -/* } */ -template -typename std::enable_if<(is_streamable::value && - !is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - std::ostringstream oss; - oss << t; - r << oss; -} - -template -typename std::enable_if<(is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - r << fmt::to_string(t); -} -void as_dot(int x, RecordFormatter &r) { - r << std::to_string(x); -} - -void as_dot(std::string const &s, RecordFormatter &r) { - r << s; -} - -template -void as_dot(std::vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -template -void as_dot(stack_vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -struct as_dot_visitor { - as_dot_visitor() = delete; - as_dot_visitor(RecordFormatter &result) : result(result) {} - - RecordFormatter &result; - - template - void operator()(char const *name, T const &t) { - RecordFormatter kv; - kv << name; - as_dot(t, result); - result << kv; - } - - template - void operator()(T const &t) { - as_dot(t, result); - } - - /* template */ - /* void operator()(const char *name, std::vector const &t) { */ - /* RecordFormatter kv; */ - /* kv << name; */ - /* RecordFormatter v; */ - /* for (V const &vv : t) { */ - /* v << as_dot_str(vv); */ - /* } */ - /* kv << v; */ - /* } */ -}; - -template -typename std::enable_if::value>::type - as_dot(T const &t, RecordFormatter &r) { - as_dot_visitor vis(r); - visit_struct::for_each(t, vis); -} - -struct AsDot { - template - RecordFormatter operator()(T const &t) { - return as_dot(t); - } -}; - -template -RecordFormatter as_dot(std::variant const &o) { - return std::visit(AsDot{}, o); -} - -struct IsValidFunctor { - IsValidFunctor(std::vector const &_input_shapes) - : input_shapes(_input_shapes) {} - - std::vector const &input_shapes; - - // bool operator()(AggregateAttrs const &attrs) { - // return is_valid(attrs, - // input_shapes.at(0), - // input_shapes.at(1), - // input_shapes.at(2), - // input_shapes.at(3), - // subvec(input_shapes, 4, nullopt)); - // } - - template - bool operator()(T const &) { - return true; // TODO FIXME @lockshaw - } -}; - -bool is_valid(PCGOperatorAttrs const &attrs, - std::vector const &input_shapes) { - NOT_IMPLEMENTED(); -} - -/* int num_outputs(OperatorParameters const &o) { */ -/* switch (get_op_type(o)) { */ -/* case OP_SPLIT: */ -/* } */ -/* } */ - -// tl::optional get_op_parameters(Op const *op) { -// switch (op->op_type) { -// case OP_LINEAR: -// return ((Linear *)op)->get_params(); -// case OP_CONV2D: -// return ((Conv2D *)op)->get_params(); -// case OP_EW_ADD: -// case OP_EW_SUB: -// case OP_EW_MUL: -// case OP_EW_DIV: -// return ((ElementBinary *)op)->get_params(); -// case OP_EXP: -// case OP_SIN: -// case OP_COS: -// case OP_SCALAR_MULTIPLY: -// case OP_SCALAR_ADD: -// case OP_SCALAR_SUB: -// case OP_SCALAR_TRUE_DIV: -// case OP_RELU: -// case OP_SIGMOID: -// case OP_TANH: -// case OP_IDENTITY: -// case OP_GELU: -// case OP_ELU: -// return ((ElementUnary *)op)->get_params(); -// case OP_CONCAT: -// return ((Concat *)op)->get_params(); -// case OP_POOL2D: -// return ((Pool2D *)op)->get_params(); -// case OP_CAST: -// return ((Cast *)op)->get_params(); -// case OP_DROPOUT: -// return ((Dropout *)op)->get_params(); -// case OP_EMBEDDING: -// return ((Embedding *)op)->get_params(); -// case OP_FLAT: -// return ((Flat *)op)->get_params(); -// case OP_MULTIHEAD_ATTENTION: -// return ((MultiHeadAttention *)op)->get_params(); -// case OP_LAYERNORM: -// return ((LayerNorm *)op)->get_params(); -// case OP_RESHAPE: -// return ((Reshape *)op)->get_params(); -// case OP_SOFTMAX: -// return ((Softmax *)op)->get_params(); -// case OP_REPARTITION: -// return ((Repartition *)op)->get_params(); -// case OP_REPLICATE: -// return ((Replicate *)op)->get_params(); -// case OP_REDUCTION: -// return ((Reduction *)op)->get_params(); -// case OP_COMBINE: -// return ((Combine *)op)->get_params(); -// case OP_FUSED_PARALLEL: -// return ((FusedParallelOp *)op)->get_params(); -// case OP_TRANSPOSE: -// return ((Transpose *)op)->get_params(); -// case OP_BATCHMATMUL: -// return ((BatchMatmul *)op)->get_params(); -// case OP_SPLIT: -// return ((Split *)op)->get_params(); -// -// // TODO: implement the get_params() function for the operators below -// and -// // uncomment the lines below -// -// // case OP_NOOP: -// // return ((NoOp *)op)->get_params(); -// // case OP_TOPK: -// // return ((TopK *)op)->get_params(); -// // case OP_MEAN: -// // return ((Mean *)op)->get_params(); -// // case OP_CACHE: -// // return ((Cache *)op)->get_params(); -// // case OP_REVERSE: -// // return ((Reverse *)op)->get_params(); -// // case OP_BATCHNORM: -// // return ((BatchNorm *)op)->get_params(); -// -// default: -// return tl::nullopt; -// } -// } - -} // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index f70d9f7404..b29d683edb 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -52,6 +52,11 @@ LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +ComputationGraph without_layer_names(ComputationGraph const &); + +bool computation_graphs_are_isomorphic(ComputationGraph const &, + ComputationGraph const &); + std::string as_dot(ComputationGraph const &); void debug_print_dot(ComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h new file mode 100644 index 0000000000..a1ca0aceed --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H + +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::V1BinarySPDecomposition> { + static ::FlexFlow::V1BinarySPDecomposition from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySPDecomposition const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinarySeriesSplit> { + static ::FlexFlow::V1BinarySeriesSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySeriesSplit const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinaryParallelSplit> { + static ::FlexFlow::V1BinaryParallelSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinaryParallelSplit const &); +}; + +} // namespace nlohmann + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..d2d0c3bc77 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml new file mode 100644 index 0000000000..317fa8b6ce --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..0fe0b1761f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "V1BinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::V1BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::V1BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "int" +key = "leaf" diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 2e878c5c53..c67999ff32 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -41,7 +41,7 @@ key = "normal" [[values]] type = "::FlexFlow::TruncatedNormalInitializerAttrs" -key = "normal" +key = "truncated_normal" [[values]] type = "::FlexFlow::ConstantInitializerAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index d7248afde4..c740e1ffd2 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" @@ -21,6 +22,15 @@ ParallelLayerAddedResult std::vector const &inputs, std::vector const &output_labels); +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape); + +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, + parallel_layer_guid_t const &, + parallel_layer_guid_t const &); + std::vector get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -37,8 +47,12 @@ std::vector ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); +ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); std::vector topological_ordering(ParallelComputationGraph const &); @@ -47,6 +61,11 @@ parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name); +ParallelComputationGraph without_layer_names(ParallelComputationGraph const &); + +bool pcgs_are_isomorphic(ParallelComputationGraph const &, + ParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 4d61f24d37..027b9f6c80 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/operator_attrs.h", + "op-attrs/pcg_operator_attrs.dtg.h", "utils/stack_string.h", "", ] diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index a69e54fd93..3d1bc629e4 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -11,6 +11,8 @@ #include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" @@ -175,6 +177,26 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +ComputationGraph without_layer_names(ComputationGraph const &cg) { + return ComputationGraph{ + LabelledDataflowGraph::create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels(cg.raw_graph, + [](Node const &n, LayerAttrs const &old_attrs) { + LayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, + ComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + std::string as_dot(ComputationGraph const &cg) { std::function get_node_label = [](LayerAttrs const &a) -> std::string { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 4a565476bd..dff647f5a1 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -489,11 +489,12 @@ tensor_guid_t ComputationGraphBuilder::gather( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { - throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " - "{} (should be {} or {})", - this->get_shape(input).data_type, - DataType::INT32, - DataType::INT64); + throw mk_runtime_error( + fmt::format("Invalid data type for input tensor 2 for Gather: " + "{} (should be {} or {})", + this->get_shape(input).data_type, + DataType::INT32, + DataType::INT64)); } TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); diff --git a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..5341e03c0a --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,84 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include "utils/exception.h" +#include "utils/fmt/json.h" +#include "utils/overload.h" + +using namespace ::FlexFlow; + +namespace nlohmann { + +V1BinarySPDecomposition + adl_serializer::from_json(json const &j) { + std::string type = j.at("type").get(); + + if (type == "series") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "parallel") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "leaf") { + return V1BinarySPDecomposition{ + j.at("value").get(), + }; + } else { + throw mk_runtime_error(fmt::format( + "Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" " + "in json object: {}", + type, + j)); + } +} + +void adl_serializer::to_json( + json &j, V1BinarySPDecomposition const &tree) { + tree.visit(overload{ + [&](V1BinarySeriesSplit const &split) { + j = split; + j["type"] = "series"; + return std::monostate{}; + }, + [&](V1BinaryParallelSplit const &split) { + j = split; + j["type"] = "parallel"; + return std::monostate{}; + }, + [&](int leaf) { + j["value"] = leaf; + j["type"] = "leaf"; + return std::monostate{}; + }, + }); +} + +V1BinarySeriesSplit + adl_serializer::from_json(json const &j) { + return V1BinarySeriesSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json( + json &j, V1BinarySeriesSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +V1BinaryParallelSplit + adl_serializer::from_json(json const &j) { + return V1BinaryParallelSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json( + json &j, V1BinaryParallelSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +} // namespace nlohmann diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index b04d9d37b3..781c44640c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -4,8 +4,11 @@ #include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -42,6 +45,39 @@ ParallelLayerAddedResult }; } +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape) { + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + return add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); +} + +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &src, + parallel_layer_guid_t const &dst) { + std::unordered_set raw_edges = + get_dataflow_edges_from_node_to_node( + pcg.raw_graph, src.raw_graph_node, dst.raw_graph_node); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + std::vector get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -110,12 +146,23 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, return pcg.raw_graph.at(l.raw_graph_node); } +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(pcg, l).op_attrs; +} + ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, parallel_tensor_guid_t const &t) { return pcg.raw_graph.at(t.raw_graph_output); } +ParallelTensorShape + get_parallel_tensor_shape(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + return get_parallel_tensor_attrs(pcg, t).shape; +} + std::vector topological_ordering(ParallelComputationGraph const &pcg) { return transform(get_topological_ordering(pcg.raw_graph), @@ -132,4 +179,28 @@ parallel_layer_guid_t return get_only(found); } +ParallelComputationGraph + without_layer_names(ParallelComputationGraph const &pcg) { + return ParallelComputationGraph{ + LabelledDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels( + pcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool pcgs_are_isomorphic(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index ce00ea62f4..f33b4dcd17 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,5 +1,18 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" @@ -182,7 +195,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, std::optional const &maybe_name) { LinearAttrs attrs = LinearAttrs{ @@ -205,9 +218,10 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::vector weights; { - ParallelTensorShape kernel_shape = + ParallelTensorShape projection_shape = throw_if_unexpected(get_projection_shape(attrs, input_shape)); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + weights.push_back( + make_weight_attrs(projection_shape, projection_initializer)); } if (use_bias) { diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..9068e14517 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,178 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer") { + V1BinarySPDecomposition example_tree = V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }, + }; + + nlohmann::json example_json = { + {"type", "series"}, + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_tree; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySPDecomposition result = + example_json.get(); + V1BinarySPDecomposition correct = example_tree; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinarySeriesSplit example_split = V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySeriesSplit result = example_json.get(); + V1BinarySeriesSplit correct = example_split; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinaryParallelSplit example_split = V1BinaryParallelSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinaryParallelSplit result = example_json.get(); + V1BinaryParallelSplit correct = example_split; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 77d938e08a..fc07edf5b3 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/rapidcheck.h" #include "utils/containers/get_only.h" @@ -262,4 +265,51 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("pcg_add_input_layer") { + ParallelTensorShape tensor_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelComputationGraph result = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + pcg_add_input_layer(pcg, tensor_shape); + return pcg; + }(); + + ParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); + + return pcg; + }(); + + CHECK(pcgs_are_isomorphic(result, correct)); + } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index c445085635..20bd0ac92d 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/ops/conv_2d.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index e46a481a1a..471f2a2709 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "legion/legion_utilities.h" #include "op-attrs/ffconst.h" -#include "op-attrs/operator_attrs.h" #include "serialization.h" #include diff --git a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h index de9d1cd78a..b7ce13db0e 100644 --- a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h +++ b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 0bbe0e97a7..0c673f0a8a 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -4,7 +4,7 @@ #include "utils/containers/values.h" #include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" @@ -54,12 +54,6 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( UnorderedSetLabelledOpenDataflowGraph>( sub_pcg.raw_graph)}; - // return ParallelComputationGraph{ - // make_lazy_copy_of< - // UnorderedSetLabelledOpenDataflowGraph - // >(sub_pcg.raw_graph) - // }; } parallel_layer_guid_t diff --git a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc index 0bde326bd1..9fa91d75b7 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,5 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/get_output_shapes.h" #include "utils/containers/map_keys.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc index 6c3d53d3b9..4d4e557fb8 100644 --- a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,7 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" diff --git a/lib/utils/include/utils/any_value_type/any_value_type.h b/lib/utils/include/utils/any_value_type/any_value_type.h new file mode 100644 index 0000000000..a99ce5c8f0 --- /dev/null +++ b/lib/utils/include/utils/any_value_type/any_value_type.h @@ -0,0 +1,66 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H + +#include +#include +#include +#include + +namespace FlexFlow { + +struct any_value_type { +public: + any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string); + + bool operator==(any_value_type const &other) const; + bool operator!=(any_value_type const &other) const; + + template + T get() const { + return std::any_cast(value); + } + + friend std::string format_as(any_value_type const &); + +private: + std::any value; + std::function eq; + std::function neq; + std::function hash; + std::function to_string; + + friend std::hash; +}; + +template +any_value_type make_any_value_type(T const &t) { + return any_value_type{ + std::make_any(t), + [](std::any const &l, std::any const &r) { + return std::any_cast(l) == std::any_cast(r); + }, + [](std::any const &l, std::any const &r) { + return std::any_cast(l) != std::any_cast(r); + }, + [](std::any const &v) { return std::hash{}(std::any_cast(v)); }, + [](std::any const &v) { return fmt::to_string(std::any_cast(v)); }, + }; +} + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::any_value_type> { + size_t operator()(::FlexFlow::any_value_type const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h new file mode 100644 index 0000000000..1635747612 --- /dev/null +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H + +#include +#include + +namespace FlexFlow { + +template +struct value_type { + value_type() = delete; + + value_type(value_type const &) { + assert(false); + } + value_type &operator=(value_type const &) { + assert(false); + } + + value_type(value_type &&) { + assert(false); + } + value_type &operator=(value_type &&) { + assert(false); + } + + bool operator==(value_type const &) const { + assert(false); + } + bool operator!=(value_type const &) const { + assert(false); + } +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::value_type> { + size_t operator()(::FlexFlow::value_type const &) const { + assert(false); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 6ac9eb10b0..0e3b1fc0bd 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -126,34 +126,6 @@ std::optional optional_all_of(Container const &container, return true; } -template -std::vector flatmap(std::vector const &v, F const &f) { - std::vector result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap(std::unordered_set const &v, F const &f) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap_v2(std::unordered_set const &v, - std::unordered_set (*f)(In const &)) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - template std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; diff --git a/lib/utils/include/utils/containers/cartesian_product.h b/lib/utils/include/utils/containers/cartesian_product.h new file mode 100644 index 0000000000..bcba52113e --- /dev/null +++ b/lib/utils/include/utils/containers/cartesian_product.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H + +#include "utils/containers/vector_of.h" +#include "utils/hash/vector.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set> + cartesian_product(std::vector> const &containers) { + std::unordered_set> result; + + std::function &, size_t)> recurse = + [&](std::vector ¤t, size_t depth) { + if (depth == containers.size()) { + result.insert(current); + return; + } + + for (E const &item : containers.at(depth)) { + current.push_back(item); + recurse(current, depth + 1); + current.pop_back(); + } + }; + + std::vector current; + recurse(current, 0); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 0f8906f34a..b016a1e03d 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -3,7 +3,9 @@ #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/merge_maps.h" #include +#include namespace FlexFlow { @@ -39,6 +41,23 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } +template < + typename InK, + typename InV, + typename F, + typename OutK = typename std::invoke_result_t::key_type, + typename OutV = typename std::invoke_result_t::mapped_type> +std::unordered_map flatmap(std::unordered_map const &m, + F &&f) { + std::unordered_map result; + + for (auto const &[k, v] : m) { + result = merge_maps(result, f(k, v)); + } + + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h new file mode 100644 index 0000000000..b7b30cbae4 --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H + +#include "utils/containers/cartesian_product.h" +#include "utils/containers/keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_map_from_pairs.h" +#include "utils/containers/vector_of.h" +#include "utils/containers/zip.h" +#include "utils/hash/unordered_map.h" +#include +#include +#include + +namespace FlexFlow { + +/** + * @note If \p options_per_key is empty, an set containing a single empty + * assignment is returned + */ +template +std::unordered_set> get_all_assignments( + std::unordered_map> const &options_per_key) { + if (options_per_key.empty()) { + return {{}}; + } + + std::vector ordered_keys = vector_of(keys(options_per_key)); + std::vector> ordered_value_option_sets = transform( + ordered_keys, [&](K const &k) { return options_per_key.at(k); }); + + std::unordered_set> result = transform( + cartesian_product(ordered_value_option_sets), + [&](std::vector const &chosen_values) { + return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); + }); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index fedb87413d..201095c47d 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -10,8 +10,8 @@ namespace FlexFlow { template typename C::value_type get_only(C const &c) { return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error("Encountered container with size {} in get_only", - c.size()); + throw mk_runtime_error(fmt::format( + "Encountered container with size {} in get_only", c.size())); }); } diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index 653c9d24f1..dd886ab8aa 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -3,6 +3,8 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" +#include "utils/exception.h" +#include "utils/fmt/unordered_map.h" #include namespace FlexFlow { @@ -10,7 +12,12 @@ namespace FlexFlow { template std::unordered_map merge_maps(std::unordered_map const &lhs, std::unordered_map const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); + if (!are_disjoint(keys(lhs), keys(rhs))) { + throw mk_runtime_error(fmt::format("Key sets of merge_maps parameters are " + "non-disjoint: lhs = {}, rhs = {}", + lhs, + rhs)); + } std::unordered_map result; for (auto const &kv : lhs) { diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index ec3d5f5612..ef6a26c79a 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -22,9 +22,7 @@ auto transform(req const &c, F const &f) return transform(static_cast(c), f); } -template ()(std::declval()))> +template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; for (auto const &e : v) { @@ -33,9 +31,17 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } -template ()(std::declval()))> +template > +std::unordered_multiset transform(std::unordered_multiset const &v, + F const &f) { + std::unordered_multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + +template > std::set transform(std::set const &v, F const &f) { std::set result; for (auto const &e : v) { @@ -44,6 +50,15 @@ std::set transform(std::set const &v, F const &f) { return result; } +template > +std::multiset transform(std::multiset const &v, F const &f) { + std::multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + template std::string transform(std::string const &s, F const &f) { std::string result; diff --git a/lib/utils/include/utils/containers/try_at.h b/lib/utils/include/utils/containers/try_at.h new file mode 100644 index 0000000000..45e50fca27 --- /dev/null +++ b/lib/utils/include/utils/containers/try_at.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H + +#include "utils/containers/contains_key.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::optional try_at(std::unordered_map const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +template +std::optional try_at(std::map const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_map_from_pairs.h b/lib/utils/include/utils/containers/unordered_map_from_pairs.h new file mode 100644 index 0000000000..660c57c5e7 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_map_from_pairs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H + +#include + +namespace FlexFlow { + +template +std::unordered_map unordered_map_from_pairs(C const &c) { + return std::unordered_map(c.cbegin(), c.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/values.h b/lib/utils/include/utils/containers/values.h index 7c487d1d43..2a730ccc42 100644 --- a/lib/utils/include/utils/containers/values.h +++ b/lib/utils/include/utils/containers/values.h @@ -1,15 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H -#include +#include namespace FlexFlow { template -std::vector values(C const &c) { - std::vector result; +std::unordered_multiset values(C const &c) { + std::unordered_multiset result; for (auto const &kv : c) { - result.push_back(kv.second); + result.insert(kv.second); } return result; } diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index 20a8098040..080cbb3611 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -34,12 +34,7 @@ T throw_if_unexpected(tl::expected const &r) { } } -template -std::runtime_error mk_runtime_error(fmt::format_string fmt_str, - T &&...args) { - return std::runtime_error( - fmt::vformat(fmt_str, fmt::make_format_args(args...))); -} +std::runtime_error mk_runtime_error(std::string const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/fmt/json.h b/lib/utils/include/utils/fmt/json.h new file mode 100644 index 0000000000..c7aa87e3eb --- /dev/null +++ b/lib/utils/include/utils/fmt/json.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H + +#include +#include + +namespace fmt { + +template +struct formatter<::nlohmann::json, Char> : formatter { + template + auto format(::nlohmann::json const &j, FormatContext &ctx) { + std::ostringstream oss; + oss << j; + return formatter::format(oss.str(), ctx); + } +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/fmt/monostate.h b/lib/utils/include/utils/fmt/monostate.h new file mode 100644 index 0000000000..884f4d389e --- /dev/null +++ b/lib/utils/include/utils/fmt/monostate.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H + +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::monostate, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::std::monostate const &, FormatContext &ctx) + -> decltype(ctx.out()) { + std::string result = ""; + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &, std::monostate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.h b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h new file mode 100644 index 0000000000..e3ed967a23 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path(); +BinaryTreePath nest_inside_left_child(BinaryTreePath const &); +BinaryTreePath nest_inside_right_child(BinaryTreePath const &); + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &); +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml new file mode 100644 index 0000000000..08955c2d75 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "BinaryTreePath" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "utils/full_binary_tree/binary_tree_path_entry.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "entries" +type = "std::vector<::FlexFlow::BinaryTreePathEntry>" diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml new file mode 100644 index 0000000000..6c81123dcf --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "BinaryTreePathEntry" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT_CHILD" +key = "left" + +[[values]] +name = "RIGHT_CHILD" +key = "right" diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..9cf5d63210 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + Leaf const &needle) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform( + find_paths_to_leaf(impl.get_left_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform( + find_paths_to_leaf(impl.get_right_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + if (leaf == needle) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml new file mode 100644 index 0000000000..bf08701840 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "get_left_child" +type = "std::function" + +[[fields]] +name = "get_right_child" +type = "std::function" + +[[fields]] +name = "is_leaf" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" + +[[fields]] +name = "require_parent" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml new file mode 100644 index 0000000000..1f8af17cf3 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeNodeType" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "PARENT" +key = "parent" + +[[values]] +name = "LEAF" +key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml new file mode 100644 index 0000000000..7418d7a016 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeVisitor" +features = [] + +template_params = [ + "Result", + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "parent_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..822acfe9ee --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(get_all_leaf_paths(impl.get_left_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(impl.get_right_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + return {binary_tree_root_path()}; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h new file mode 100644 index 0000000000..7517028ec0 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include + +namespace FlexFlow { + +template +Tree get_child(Parent const &parent, + FullBinaryTreeImplementation const &impl, + BinaryTreePathEntry const &e) { + switch (e) { + case BinaryTreePathEntry::LEFT_CHILD: + return impl.get_left_child(parent); + case BinaryTreePathEntry::RIGHT_CHILD: + return impl.get_right_child(parent); + default: + throw mk_runtime_error( + fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h new file mode 100644 index 0000000000..8f9d8e919f --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H + +#include "utils/containers/multiset_union.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = + FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_multiset { + return multiset_union( + get_leaves(impl.get_left_child(parent), impl), + get_leaves(impl.get_right_child(parent), impl)); + }, + [](Leaf const &leaf) -> std::unordered_multiset { + return {leaf}; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..922a42242c --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H + +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = FullBinaryTreeVisitor{ + [&](Parent const &parent) -> int { + return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + + get_num_tree_nodes(impl.get_right_child(parent), impl); + }, + [](Leaf const &) -> int { return 1; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..83ce1367b9 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::optional get_subtree_at_path( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + BinaryTreePath const &p) { + if (p == binary_tree_root_path()) { + return tree; + } + + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::optional { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); + }, + [](Leaf const &leaf) -> std::optional { return std::nullopt; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h new file mode 100644 index 0000000000..832d39bdff --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" + +namespace FlexFlow { + +template +Result visit(Tree const &tree, + FullBinaryTreeImplementation const &impl, + FullBinaryTreeVisitor const &visitor) { + if (impl.is_leaf(tree)) { + return visitor.leaf_func(impl.require_leaf(tree)); + } else { + return visitor.parent_func(impl.require_parent(tree)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h new file mode 100644 index 0000000000..de7ead8fb6 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h new file mode 100644 index 0000000000..be0e57435a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h new file mode 100644 index 0000000000..e53bb876a1 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h new file mode 100644 index 0000000000..ad8eadda0e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml new file mode 100644 index 0000000000..32582a6b74 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "SplitBoundaryNodes" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h new file mode 100644 index 0000000000..916e8f7896 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml new file mode 100644 index 0000000000..54c710b26e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TransitiveReducedDataflowGraphView" +features = [] + +includes = [ + "utils/graph/dataflow_graph/dataflow_graph_view.h", + "utils/graph/digraph/digraph_view.h", +] + +[[fields]] +name = "full_dataflow_graph" +type = "::FlexFlow::DataflowGraphView" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h new file mode 100644 index 0000000000..240fc66426 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" +namespace FlexFlow { + +std::unordered_set + get_edges_from_subgraph_to_subgraph(DiGraphView const &, + std::unordered_set const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h similarity index 88% rename from lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h rename to lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h index a8e08cb995..b9894fbac3 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" @@ -8,6 +8,10 @@ namespace FlexFlow { +// NOTE(@lockshaw) This code is not tested and I don't necessarily trust it. +// Figuring out what to do with it is tracked in +// https://github.com/flexflow/FlexFlow/issues/1513 + template struct LazyLabelledDataflowGraph final : public ILabelledDataflowGraph { @@ -42,11 +46,11 @@ struct LazyLabelledDataflowGraph final return this->get_view().query_outputs(q); } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->get_view().at(n); } - ValueLabel const &at(DataflowOutput const &v) const override { + ValueLabel at(DataflowOutput const &v) const override { return this->get_view().at(v); } @@ -95,7 +99,7 @@ template static typename std::enable_if< std::is_base_of, T>::value, LabelledDataflowGraph>::type - make_lazy_copy_of( + create_lazy_copy_of_labelled_dataflow_graph_view( LabelledDataflowGraphView const &view) { std::function( LabelledDataflowGraphView const &)> diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h new file mode 100644 index 0000000000..07aa64aa62 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" + +namespace FlexFlow { + +template > +LabelledDataflowGraphView rewrite_node_labels( + LabelledDataflowGraphView const &g, F f) { + return rewrite_node_labels( + view_as_labelled_open_dataflow_graph(g), f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml new file mode 100644 index 0000000000..37e3bbee09 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml new file mode 100644 index 0000000000..7e6e86ba76 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index b1607e7a76..de48cd17e9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -1,23 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include namespace FlexFlow { -BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &, - BinarySPDecompositionTree const &); -BinarySPDecompositionTree - make_parallel_split(BinarySPDecompositionTree const &, - BinarySPDecompositionTree const &); -BinarySPDecompositionTree make_leaf_node(Node const &); +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree(); bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); +SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml deleted file mode 100644 index 1241311150..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "BinarySPDecompositionTree" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml new file mode 100644 index 0000000000..c586b49d9d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::Node" +key = "node" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..105f5490a4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + Leaf const &needle) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return find_paths_to_leaf(tree, full_binary_impl, needle); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h deleted file mode 100644 index 42d71ce54e..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace FlexFlow { - -template -std::string format_as(GenericBinarySeriesSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySeriesSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinaryParallelSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinaryParallelSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return fmt::format("", s); - }, - [](GenericBinaryParallelSplit const &s) { - return fmt::format("", s); - }, - [](T const &t) { - return fmt::format("", t); - }, - }); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySPDecompositionTree const &t) { - return (s << fmt::to_string(t)); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h deleted file mode 100644 index 74f5ba5d8a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h +++ /dev/null @@ -1,155 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H - -#include -#include -#include - -namespace FlexFlow { - -template -struct GenericBinarySPDecompositionTree; - -template -struct GenericBinarySeriesSplit { -public: - GenericBinarySeriesSplit() = delete; - explicit GenericBinarySeriesSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; - - bool operator==(GenericBinarySeriesSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySeriesSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySeriesSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinaryParallelSplit { -public: - GenericBinaryParallelSplit() = delete; - explicit GenericBinaryParallelSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; - - bool operator==(GenericBinaryParallelSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinaryParallelSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinaryParallelSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinarySPDecompositionTree { -public: - GenericBinarySPDecompositionTree() = delete; - explicit GenericBinarySPDecompositionTree( - GenericBinarySeriesSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree( - GenericBinaryParallelSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} - - GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = - default; - - bool operator==(GenericBinarySPDecompositionTree const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySPDecompositionTree const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySPDecompositionTree const &other) const { - return this->tie() < other.tie(); - } - -public: - std::variant, GenericBinaryParallelSplit, T> - root; - -private: - std::tuple tie() const { - return std::tie(this->root); - } - - friend std::hash; -}; - -} // namespace FlexFlow - -// namespace rc { -// -// template <> -// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { -// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { -// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { -// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); -// }; -// -// } // namespace rc - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h new file mode 100644 index 0000000000..0bddbee81c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &impl) { + + using Parent = std::variant; + + auto full_binary_impl = FullBinaryTreeImplementation{ + /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_left_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_left_child(parallel); + }, + }, + parent); + }, + /*get_right_child=*/ + [impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_right_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_right_child(parallel); + }, + }, + parent); + }, + /*is_leaf=*/ + [impl](Tree const &tree) -> bool { + return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; + }, + /*require_leaf=*/ + [impl](Tree const &tree) -> Leaf const & { + return impl.require_leaf(tree); + }, + /*require_parent=*/ + [impl](Tree const &tree) -> Parent { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return Parent{impl.require_series(tree)}; + case SPDecompositionTreeNodeType::PARALLEL: + return Parent{impl.require_parallel(tree)}; + default: + throw mk_runtime_error(fmt::format( + "Unexpected SPDecompositionTreeNodeType: {}", node_type)); + } + }}; + + return full_binary_impl; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml new file mode 100644 index 0000000000..3ccbfd959b --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", + "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h", +] + +[[fields]] +name = "series_get_left_child" +type = "std::function" + +[[fields]] +name = "parallel_get_left_child" +type = "std::function" + +[[fields]] +name = "series_get_right_child" +type = "std::function" + +[[fields]] +name = "parallel_get_right_child" +type = "std::function" + +[[fields]] +name = "get_node_type" +type = "std::function<::FlexFlow::SPDecompositionTreeNodeType(Tree const &)>" + +[[fields]] +name = "require_series" +type = "std::function" + +[[fields]] +name = "require_parallel" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml new file mode 100644 index 0000000000..6275c82a0c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeVisitor" +features = [] + +template_params = [ + "ReturnType", + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "series_func" +type = "std::function" + +[[fields]] +name = "parallel_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h deleted file mode 100644 index c6c1186d3d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -TT const &get(GenericBinarySPDecompositionTree const &t) { - return std::get(t.root); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..b0bb8355db --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_all_leaf_paths(tree, full_binary_impl); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index 51e1e20bac..c543375148 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -1,38 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H -#include "utils/containers/multiset_union.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" -#include +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template -std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](T const &t) { return std::unordered_multiset{t}; }, - [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, - [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, - }); -} +template +std::unordered_multiset get_leaves( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { -template -std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { - return multiset_union(get_leaves(get_left_child(s)), - get_leaves(get_right_child(s))); -} + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); -template -std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { - return multiset_union(get_leaves(get_left_child(p)), - get_leaves(get_right_child(p))); + return get_leaves(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h deleted file mode 100644 index 46a460b64e..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &s) { - return *s.left_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &p) { - return *p.left_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_left_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_left_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_left_child incorrectly called on leaf node"); - }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h deleted file mode 100644 index 883acda480..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](GenericBinaryParallelSplit const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](T const &) { return SPDecompositionTreeNodeType::NODE; }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index 7c6d28d7b4..4678e0c0f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -1,38 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template -int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { - return visit(tt, - overload{ - [](T const &t) { return 1; }, - [](GenericBinarySeriesSplit const &s) { - return get_num_tree_nodes(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_num_tree_nodes(p); - }, - }); -} +template +int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { -template -int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { - return 1 + get_num_tree_nodes(get_left_child(s)) + - get_num_tree_nodes(get_right_child(s)); -} + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); -template -int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { - return 1 + get_num_tree_nodes(get_left_child(p)) + - get_num_tree_nodes(get_right_child(p)); + return get_num_tree_nodes(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h deleted file mode 100644 index f0bfba43a2..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &s) { - return *s.right_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &p) { - return *p.right_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_right_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_right_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_right_child incorrectly called on leaf node"); - }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..c48185fb7f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include + +namespace FlexFlow { + +template +std::optional get_subtree_at_path( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + BinaryTreePath const &path) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_subtree_at_path(tree, full_binary_impl, path); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h deleted file mode 100644 index 983dc4a572..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace std { - -template -struct hash<::FlexFlow::GenericBinarySeriesSplit> { - size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinaryParallelSplit> { - size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { - size_t operator()( - ::FlexFlow::GenericBinarySPDecompositionTree const &s) const { - return get_std_hash(s.tie()); - } -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h deleted file mode 100644 index 8086f38244..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -bool is_series_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); -} - -template -bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); -} - -template -bool is_leaf(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative(t.root); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 3ffa63753a..68e0a3af32 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,32 +1,44 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" namespace FlexFlow { -template +template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](T const &) { return true; }, - [](GenericBinarySeriesSplit const &s) { - return !is_series_split(get_right_child(s)) && - is_binary_sp_tree_left_associative(get_left_child(s)) && - is_binary_sp_tree_left_associative(get_right_child(s)); - }, - [](GenericBinaryParallelSplit const &p) { - return !is_parallel_split(get_right_child(p)) && - is_binary_sp_tree_left_associative(get_left_child(p)) && - is_binary_sp_tree_left_associative(get_right_child(p)); - }, - }); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_right_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_left_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_right_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_left_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index d88459b432..7042765203 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,32 +1,43 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" namespace FlexFlow { -template +template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](T const &t) { return true; }, - [](GenericBinarySeriesSplit const &s) { - return !is_series_split(get_left_child(s)) && - is_binary_sp_tree_right_associative(get_left_child(s)) && - is_binary_sp_tree_right_associative(get_right_child(s)); - }, - [](GenericBinaryParallelSplit const &p) { - return !is_parallel_split(get_left_child(p)) && - is_binary_sp_tree_right_associative(get_left_child(p)) && - is_binary_sp_tree_right_associative(get_right_child(p)); - }, - }); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_left_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_right_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_left_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_right_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h deleted file mode 100644 index 4f1f8266e1..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace nlohmann { - -template -struct adl_serializer<::FlexFlow::GenericBinarySeriesSplit> { - static ::FlexFlow::GenericBinarySeriesSplit from_json(json const &j) { - return ::FlexFlow::GenericBinarySeriesSplit{ - j.at("left_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - j.at("right_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - }; - } - - static void to_json(json &j, - ::FlexFlow::GenericBinarySeriesSplit const &v) { - j["__type"] = "GenericBinarySeriesSplit"; - j["left_child"] = get_left_child(v); - j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::GenericBinaryParallelSplit> { - static ::FlexFlow::GenericBinaryParallelSplit from_json(json const &j) { - return ::FlexFlow::GenericBinaryParallelSplit{ - j.at("left_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - j.at("right_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - }; - } - - static void to_json(json &j, - ::FlexFlow::GenericBinaryParallelSplit const &v) { - j["__type"] = "GenericBinaryParallelSplit"; - j["left_child"] = get_left_child(v); - j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::GenericBinarySPDecompositionTree> { - static ::FlexFlow::GenericBinarySPDecompositionTree - from_json(json const &j) { - std::string key = j.at("type").get(); - - if (key == "series") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get<::FlexFlow::GenericBinarySeriesSplit>(), - }; - } else if (key == "parallel") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get<::FlexFlow::GenericBinaryParallelSplit>(), - }; - } else if (key == "leaf") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get(), - }; - } else { - throw ::FlexFlow::mk_runtime_error( - fmt::format("Unknown json type key: {}", key)); - } - } - - static void - to_json(json &j, - ::FlexFlow::GenericBinarySPDecompositionTree const &v) { - j["__type"] = "GenericBinarySPDecompositionTree"; - ::FlexFlow::visit( - v, - ::FlexFlow::overload{ - [&](::FlexFlow::GenericBinarySeriesSplit const &s) { - j["type"] = "series"; - j["value"] = s; - return std::monostate{}; - }, - [&](::FlexFlow::GenericBinaryParallelSplit const &p) { - j["type"] = "parallel"; - j["value"] = p; - return std::monostate{}; - }, - [&](T const &t) { - j["type"] = "leaf"; - j["value"] = t; - return std::monostate{}; - }, - }); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h deleted file mode 100644 index f55b71146a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree make_generic_binary_series_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - lhs, - rhs, - }, - }; -} - -template -GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - lhs, - rhs, - }, - }; -} - -template -GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { - return GenericBinarySPDecompositionTree{t}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h deleted file mode 100644 index a8de1ee8f8..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" - -namespace FlexFlow { - -template -GenericBinarySeriesSplit const & - require_series(GenericBinarySPDecompositionTree const &t) { - return get>(t); -} - -template -GenericBinaryParallelSplit const & - require_parallel(GenericBinarySPDecompositionTree const &t) { - return get>(t); -} - -template -T const &require_node(GenericBinarySPDecompositionTree const &t) { - return get(t); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h deleted file mode 100644 index 4d7fa05960..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" - -namespace FlexFlow { - -template > -GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, F f) { - return visit>( - tt, - overload{ - [&](GenericBinarySeriesSplit const &s) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - transform(get_left_child(s), f), - transform(get_right_child(s), f), - }, - }; - }, - [&](GenericBinaryParallelSplit const &s) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - transform(get_left_child(s), f), - transform(get_right_child(s), f), - }, - }; - }, - [&](T const &t) { - return GenericBinarySPDecompositionTree{ - f(t), - }; - }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 0d9503e59f..c06db135b2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -2,34 +2,45 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" namespace FlexFlow { -template -Result visit(GenericBinarySPDecompositionTree const &tt, F f) { - if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative(tt.root)) { - return f(std::get(tt.root)); - } else { - throw mk_runtime_error( - "Unexpected case in visit(GenericBinarySPDecompositionTree)"); +template +ReturnType + visit(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + GenericBinarySPDecompositionTreeVisitor const &visitor) { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + ReturnType result = visitor.series_func(impl.require_series(tree)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + ReturnType result = visitor.parallel_func(impl.require_parallel(tree)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + ReturnType result = visitor.leaf_func(impl.require_leaf(tree)); + return result; + } + default: + throw mk_runtime_error(fmt::format( + "Unknown SPDecompositionTreeNodeType value: {}", node_type)); } - - // return std::visit(tt.root, overload { - // [&](GenericBinarySeriesSplit const &s) -> Result { - // return f(s); - // }, - // [&](GenericBinaryParallelSplit const &p) -> Result { - // return f(p); - // }, - // [&](T const &t) -> Result { - // return f(t); - // }, - // }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml new file mode 100644 index 0000000000..dd68adf3f6 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct SeriesSplit" +] + +post_includes = [ + "utils/graph/series_parallel/series_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "children" +type = "std::unordered_multiset>" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 18434d2b67..7374b45a60 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,80 +1,76 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#include "utils/graph/node/node.dtg.h" -#include -#include +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" namespace FlexFlow { -struct SeriesSplit; -struct ParallelSplit; - -struct SeriesSplit { -public: - SeriesSplit() = delete; - explicit SeriesSplit(std::vector> const &); - explicit SeriesSplit( - std::initializer_list> const &); - - bool operator==(SeriesSplit const &) const; - bool operator!=(SeriesSplit const &) const; - -public: - std::vector> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(SeriesSplit const &); -std::ostream &operator<<(std::ostream &, SeriesSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::SeriesSplit> { - size_t operator()(::FlexFlow::SeriesSplit const &) const; -}; - -} // namespace std - -namespace FlexFlow { - -struct ParallelSplit { -public: - ParallelSplit() = delete; - explicit ParallelSplit( - std::unordered_multiset> const &); - explicit ParallelSplit( - std::initializer_list> const &); - - bool operator==(ParallelSplit const &) const; - bool operator!=(ParallelSplit const &) const; - -public: - std::unordered_multiset> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(ParallelSplit const &); -std::ostream &operator<<(std::ostream &, ParallelSplit const &); +// struct SeriesSplit { +// public: +// SeriesSplit() = delete; +// explicit SeriesSplit(std::vector> const +// &); explicit SeriesSplit( +// std::initializer_list> const &); +// +// bool operator==(SeriesSplit const &) const; +// bool operator!=(SeriesSplit const &) const; +// +// public: +// std::vector> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(SeriesSplit const &); +// std::ostream &operator<<(std::ostream &, SeriesSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::SeriesSplit> { +// size_t operator()(::FlexFlow::SeriesSplit const &) const; +// }; +// +// } // namespace std +// +// namespace FlexFlow { +// +// struct ParallelSplit { +// public: +// ParallelSplit() = delete; +// explicit ParallelSplit( +// std::unordered_multiset> const &); +// explicit ParallelSplit( +// std::initializer_list> const &); +// +// bool operator==(ParallelSplit const &) const; +// bool operator!=(ParallelSplit const &) const; +// +// public: +// std::unordered_multiset> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(ParallelSplit const &); +// std::ostream &operator<<(std::ostream &, ParallelSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::ParallelSplit> { +// size_t operator()(::FlexFlow::ParallelSplit const &) const; +// }; } // namespace FlexFlow -namespace std { - -template <> -struct hash<::FlexFlow::ParallelSplit> { - size_t operator()(::FlexFlow::ParallelSplit const &) const; -}; - -} // namespace std - #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml new file mode 100644 index 0000000000..fdb0a29972 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "SeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ParallelSplit" +] + +post_includes = [ + "utils/graph/series_parallel/parallel_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h b/lib/utils/include/utils/json/check_is_json_deserializable.h new file mode 100644 index 0000000000..dd5f397c19 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H + +#include "utils/json/is_json_deserializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h new file mode 100644 index 0000000000..dfcb26081d --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H + +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_serializable::value, \ + #TYPENAME " should be json serializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/sequence.h b/lib/utils/include/utils/sequence.h index 6c66949fd8..07e4554299 100644 --- a/lib/utils/include/utils/sequence.h +++ b/lib/utils/include/utils/sequence.h @@ -135,7 +135,7 @@ auto seq_get(F const &f, int i, seq const &s) template auto seq_get(F const &f, int i, seq<> const &) -> decltype(f(std::declval>())) { - throw mk_runtime_error("Failed seq_get for index {}", i); + throw mk_runtime_error(fmt::format("Failed seq_get for index {}", i)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index afc16d4c4b..0296e365a3 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -67,8 +67,8 @@ template std::any get(std::tuple const &t, int idx) { size_t tuple_size = std::tuple_size::value; if (idx < 0 || idx >= tuple_size) { - throw mk_runtime_error( - "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size); + throw mk_runtime_error(fmt::format( + "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size)); } std::any result; visit_tuple(t, tuple_get_visitor{idx, result}); diff --git a/lib/utils/src/utils/any_value_type/any_value_type.cc b/lib/utils/src/utils/any_value_type/any_value_type.cc new file mode 100644 index 0000000000..d4c605c441 --- /dev/null +++ b/lib/utils/src/utils/any_value_type/any_value_type.cc @@ -0,0 +1,34 @@ +#include "utils/any_value_type/any_value_type.h" + +namespace FlexFlow { + +any_value_type::any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string) + : value(value), eq(eq), neq(neq), hash(hash), to_string(to_string) {} + +bool any_value_type::operator==(any_value_type const &other) const { + return this->eq(this->value, other.value); +} + +bool any_value_type::operator!=(any_value_type const &other) const { + return this->neq(this->value, other.value); +} + +std::string format_as(any_value_type const &v) { + return v.to_string(v.value); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::any_value_type>::operator()( + ::FlexFlow::any_value_type const &v) const { + return v.hash(v); +} + +} // namespace std diff --git a/lib/utils/src/utils/archetypes/value_type.cc b/lib/utils/src/utils/archetypes/value_type.cc new file mode 100644 index 0000000000..f7da47d8f9 --- /dev/null +++ b/lib/utils/src/utils/archetypes/value_type.cc @@ -0,0 +1,7 @@ +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template struct value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/cartesian_product.cc b/lib/utils/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..b716a49ad5 --- /dev/null +++ b/lib/utils/src/utils/containers/cartesian_product.cc @@ -0,0 +1 @@ +#include "utils/containers/cartesian_product.h" diff --git a/lib/utils/src/utils/containers/get_all_assignments.cc b/lib/utils/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..3a7cf6377a --- /dev/null +++ b/lib/utils/src/utils/containers/get_all_assignments.cc @@ -0,0 +1 @@ +#include "utils/containers/get_all_assignments.h" diff --git a/lib/utils/src/utils/containers/try_at.cc b/lib/utils/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..0d1ed3b04a --- /dev/null +++ b/lib/utils/src/utils/containers/try_at.cc @@ -0,0 +1 @@ +#include "utils/containers/try_at.h" diff --git a/lib/utils/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..60cc978be7 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1 @@ +#include "utils/containers/unordered_map_from_pairs.h" diff --git a/lib/utils/src/utils/exception.cc b/lib/utils/src/utils/exception.cc index 9bbf780fd8..c645f241aa 100644 --- a/lib/utils/src/utils/exception.cc +++ b/lib/utils/src/utils/exception.cc @@ -1 +1,9 @@ #include "utils/exception.h" + +namespace FlexFlow { + +std::runtime_error mk_runtime_error(std::string const &s) { + return std::runtime_error(s); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/json.cc b/lib/utils/src/utils/fmt/json.cc new file mode 100644 index 0000000000..49ad57fba7 --- /dev/null +++ b/lib/utils/src/utils/fmt/json.cc @@ -0,0 +1,7 @@ +#include "utils/fmt/json.h" + +namespace fmt { + +template struct formatter<::nlohmann::json, char>; + +} diff --git a/lib/utils/src/utils/fmt/monostate.cc b/lib/utils/src/utils/fmt/monostate.cc new file mode 100644 index 0000000000..55988cdce0 --- /dev/null +++ b/lib/utils/src/utils/fmt/monostate.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/monostate.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::monostate const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc new file mode 100644 index 0000000000..8445a2721a --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc @@ -0,0 +1,34 @@ +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/subvec.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path() { + return BinaryTreePath{{}}; +} + +BinaryTreePath nest_inside_left_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::LEFT_CHILD); + return result; +} + +BinaryTreePath nest_inside_right_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::RIGHT_CHILD); + return result; +} + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &p) { + return p.entries.at(0); +} + +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &p) { + return BinaryTreePath{ + subvec(p.entries, 1, std::nullopt), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..47845720ed --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::unordered_set + find_paths_to_leaf(Tree const &, + FullBinaryTreeImplementation const &, + Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..b4d8aa1011 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc @@ -0,0 +1,12 @@ +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template std::unordered_set + get_all_leaf_paths(value_type<0> const &, + FullBinaryTreeImplementation, + value_type<1>, + value_type<2>> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_child.cc b/lib/utils/src/utils/full_binary_tree/get_child.cc new file mode 100644 index 0000000000..19362ae510 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_child.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_child.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template Tree + get_child(Parent const &, + FullBinaryTreeImplementation const &, + BinaryTreePathEntry const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_leaves.cc b/lib/utils/src/utils/full_binary_tree/get_leaves.cc new file mode 100644 index 0000000000..0d7e9106f6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_leaves.cc @@ -0,0 +1,14 @@ +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::unordered_multiset + get_leaves(Tree const &, + FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..7a99dd60fa --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc @@ -0,0 +1,13 @@ +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template int get_num_tree_nodes( + Tree const &, FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..1eea13fedd --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::optional get_subtree_at_path( + Tree const &, + FullBinaryTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc new file mode 100644 index 0000000000..4a4f7c9302 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -0,0 +1,9 @@ +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template int visit(std::string const &, + FullBinaryTreeImplementation const &, + FullBinaryTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..c07d344d05 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,15 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst) { + return g.query_edges(DataflowEdgeQuery{ + /*src_nodes=*/query_set{src}, + /*src_idxs=*/query_set::matchall(), + /*dst_nodes=*/query_set{dst}, + /*dst_idxs=*/query_set::matchall(), + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..70a66c9a21 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set edges = + get_transitive_reduced_edges_across_split(tr_g, split); + + std::unordered_set src_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.src.node; }); + + std::unordered_set dst_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.dst.node; }); + + return SplitBoundaryNodes{ + /*pre_split_boundary=*/src_boundary_nodes, + /*post_split_boundary=*/dst_boundary_nodes, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..8a4adf0b3a --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,27 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set src_subgraph = + unordered_set_of(get_leaves(split.get_left_child())); + std::unordered_set dst_subgraph = + unordered_set_of(get_leaves(split.get_right_child())); + + std::unordered_set raw_edges = + get_edges_from_subgraph_to_subgraph( + tr_g.transitive_reduction, src_subgraph, dst_subgraph); + + return flatmap(raw_edges, [&](DirectedEdge const &e) { + return get_dataflow_edges_from_node_to_node( + tr_g.full_dataflow_graph, e.src, e.dst); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..0bb94c87f4 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,14 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + return transform(get_transitive_reduced_edges_across_split(tr_g, split), + [](DataflowEdge const &e) { return e.src; }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc new file mode 100644 index 0000000000..81751702a2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc @@ -0,0 +1,17 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &g) { + DiGraphView as_digraph = g; + DiGraphView transitive_reduced = transitive_reduction(as_digraph); + + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/g, + /*transitive_reduction=*/transitive_reduced, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..2c6606a06b --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,25 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/containers/are_disjoint.h" + +namespace FlexFlow { + +std::unordered_set get_edges_from_subgraph_to_subgraph( + DiGraphView const &g, + std::unordered_set const &src_subgraph, + std::unordered_set const &dst_subgraph) { + if (!are_disjoint(src_subgraph, dst_subgraph)) { + throw mk_runtime_error( + fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) " + "expected src_subgraph and dst_subgraph to be disjoint, " + "but found src_subgraph={}, dst_subgraph={}", + src_subgraph, + dst_subgraph)); + } + + return g.query_edges(DirectedEdgeQuery{ + /*srcs=*/query_set{src_subgraph}, + /*dsts=*/query_set{dst_subgraph}, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc index 4f4c846433..941c8e8e3e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc @@ -6,6 +6,7 @@ #include "utils/containers/values.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 92a6d0b9eb..61c4f80763 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -23,12 +23,12 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.bigger); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.bigger)); } if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.smaller); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.smaller)); } this->adjacency.at(e.bigger).insert(e.smaller); diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc new file mode 100644 index 0000000000..28d63f9ee1 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc new file mode 100644 index 0000000000..dc5ce4fbda --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index 4ade34941c..08dda09698 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -5,6 +5,7 @@ #include "utils/containers/values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/hash/vector.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 18d1f922c6..62489ff75f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -2,42 +2,84 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" namespace FlexFlow { -BinarySPDecompositionTree - make_series_split(BinarySPDecompositionTree const &lhs, - BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{ - make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), - }; -} +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree() { -BinarySPDecompositionTree - make_parallel_split(BinarySPDecompositionTree const &lhs, - BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{ - make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + return GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node>{ + /*series_get_left_child=*/[](BinarySeriesSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { + return tree.require_series(); + }, + /*require_parallel=*/ + [](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { + return tree.require_parallel(); + }, + /*require_leaf=*/ + [](BinarySPDecompositionTree const &tree) -> Node const & { + return tree.require_node(); + }, }; } -BinarySPDecompositionTree make_leaf_node(Node const &n) { - return BinarySPDecompositionTree{ - make_generic_binary_sp_leaf(n), - }; +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_left_associative(tree, + generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tt) { - return is_binary_sp_tree_left_associative(tt.raw_tree); +bool is_binary_sp_tree_right_associative( + BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_right_associative(tree, + generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tt) { - return is_binary_sp_tree_right_associative(tt.raw_tree); +std::unordered_multiset + get_leaves(BinarySPDecompositionTree const &tree) { + return get_leaves(tree, generic_impl_for_binary_sp_tree()); } -std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { - return get_leaves(tt.raw_tree); +SPDecompositionTreeNodeType + get_node_type(BinarySPDecompositionTree const &tree) { + return tree.visit(overload{ + [](BinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](BinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..07e2c3e3e3 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_set find_paths_to_leaf( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc deleted file mode 100644 index 4cd7206408..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc deleted file mode 100644 index 3a4dbad8ec..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc new file mode 100644 index 0000000000..56a6d0cc85 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc deleted file mode 100644 index 4ee18af5be..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..71d3f6ac31 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 71b67acc54..3bb90bfa32 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_multiset + get_leaves(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc deleted file mode 100644 index 227e5bd79c..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc deleted file mode 100644 index 1618128226..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 05ec6b5925..3d166145c1 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc deleted file mode 100644 index f168ba1e2f..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..d1d8079c0b --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::optional get_subtree_at_path( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc deleted file mode 100644 index 75c472c435..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc deleted file mode 100644 index 3da024743c..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 8fe9397003..69cbb28582 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index d202f55964..584099e33e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index b569ff9265..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc deleted file mode 100644 index fb1532b3ef..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc deleted file mode 100644 index 3fee45fcf5..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index cabd66cff7..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc index 25409333f2..056ae2a8d4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -1 +1,24 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using ReturnType = value_type<0>; +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template ReturnType + visit(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + GenericBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 02e541b7e4..69b2ebea8e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -2,50 +2,53 @@ #include "utils/containers/foldl1.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( SeriesParallelDecomposition const &nary) { - std::function( + std::function const &)> from_series_child; - std::function( + std::function const &)> from_parallel_child; - auto from_node = [](Node const &n) -> GenericBinarySPDecompositionTree { - return GenericBinarySPDecompositionTree{n}; + auto from_node = [](Node const &n) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{n}; }; - auto from_series = - [&](SeriesSplit const &s) -> GenericBinarySPDecompositionTree { - std::vector> children = + auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { + std::vector children = transform(s.children, from_series_child); - return foldl1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{accum, x}, - }; - }); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{accum, x}, + }; + }); }; auto from_parallel = - [&](ParallelSplit const &s) -> GenericBinarySPDecompositionTree { - std::vector> children = - transform(vector_of(s.children), from_parallel_child); - return foldl1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{accum, x}}; - }); + [&](ParallelSplit const &s) -> BinarySPDecompositionTree { + std::vector children = + transform(vector_of(s.get_children()), from_parallel_child); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{accum, x}, + }; + }); }; from_parallel_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit(overload{ [&](Node const &n) { return from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, @@ -54,7 +57,7 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( }; from_series_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit( overload{ [&](Node const &n) { return from_node(n); }, @@ -63,13 +66,11 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( v); }; - return BinarySPDecompositionTree{ - nary.visit>(overload{ - [&](Node const &n) { return from_node(n); }, - [&](SeriesSplit const &s) { return from_series(s); }, - [&](ParallelSplit const &p) { return from_parallel(p); }, - }), - }; + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 673a4118a6..478d90e0c3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -2,47 +2,50 @@ #include "utils/containers/foldr1.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( SeriesParallelDecomposition const &nary) { - std::function( + std::function const &)> from_series_child; - std::function( + std::function const &)> from_parallel_child; - auto from_node = [](Node const &n) { - return GenericBinarySPDecompositionTree{n}; - }; + auto from_node = [](Node const &n) { return BinarySPDecompositionTree{n}; }; auto from_series = [&](SeriesSplit const &s) { - std::vector> children = + std::vector children = transform(s.children, from_series_child); - return foldr1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{x, accum}}; - }); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{x, accum}, + }; + }); }; auto from_parallel = [&](ParallelSplit const &s) { - std::vector> children = - transform(vector_of(s.children), from_parallel_child); - return foldr1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{x, accum}}; - }); + std::vector children = + transform(vector_of(s.get_children()), from_parallel_child); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{x, accum}, + }; + }); }; from_parallel_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit(overload{ [&](Node const &n) { return from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, @@ -51,7 +54,7 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( }; from_series_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit( overload{ [&](Node const &n) { return from_node(n); }, @@ -60,13 +63,11 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( v); }; - return BinarySPDecompositionTree{ - nary.visit>(overload{ - [&](Node const &n) { return from_node(n); }, - [&](SeriesSplit const &s) { return from_series(s); }, - [&](ParallelSplit const &p) { return from_parallel(p); }, - }), - }; + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index ab231f256c..cd29af59a0 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -33,10 +33,10 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); std::unordered_map - ttsp_edge_to_sp_tree = - map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { return make_leaf_node(n); }); + ttsp_edge_to_sp_tree = map_values( + inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return BinarySPDecompositionTree{n}; }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -46,8 +46,12 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - BinarySPDecompositionTree new_tree = make_parallel_split( - ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -62,8 +66,12 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - BinarySPDecompositionTree new_tree = make_series_split( - ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 48c936ec39..410a40236d 100644 --- a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,8 +1,6 @@ #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/containers/extend.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { @@ -49,36 +47,31 @@ std::variant flatten_ast( return std::visit(FlattenAST{}, ast); } -std::variant - from_binary_sp_tree(GenericBinarySPDecompositionTree const &binary) { - return visit>( - binary, - overload{ - [](Node const &n) { return n; }, - [](GenericBinarySeriesSplit const &s) { - return IntermediateSpDecompositionTree{ - SplitType::SERIES, - { - from_binary_sp_tree(get_left_child(s)), - from_binary_sp_tree(get_right_child(s)), - }, - }; - }, - [](GenericBinaryParallelSplit const &p) { - return IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - { - from_binary_sp_tree(get_left_child(p)), - from_binary_sp_tree(get_right_child(p)), - }, - }; - }, - }); -} - std::variant from_binary_sp_tree(BinarySPDecompositionTree const &binary) { - return from_binary_sp_tree(binary.raw_tree); + return binary + .template visit>( + overload{ + [](Node const &n) { return n; }, + [](BinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(s.get_left_child()), + from_binary_sp_tree(s.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(p.get_left_child()), + from_binary_sp_tree(p.get_right_child()), + }, + }; + }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index e697533054..b7a84b871a 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -64,7 +64,7 @@ std::unordered_multiset get_nodes(SeriesSplit const &serial) { std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { return multiset_union(transform( - vector_of(parallel.children), + vector_of(parallel.get_children()), [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc index 0e04a4f904..7d36371e49 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc @@ -1,85 +1,85 @@ #include "utils/graph/series_parallel/series_parallel_splits.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" -#include "utils/hash-utils.h" -#include "utils/hash/unordered_multiset.h" -#include "utils/hash/vector.h" - -namespace FlexFlow { - -SeriesSplit::SeriesSplit( - std::vector> const &children) - : children(children) {} - -SeriesSplit::SeriesSplit( - std::initializer_list> const &children) - : children(children) {} - -bool SeriesSplit::operator==(SeriesSplit const &other) const { - return this->tie() == other.tie(); -} - -bool SeriesSplit::operator!=(SeriesSplit const &other) const { - return this->tie() != other.tie(); -} - -SeriesSplit::Tie SeriesSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(SeriesSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { - return s << fmt::to_string(split); -} - -ParallelSplit::ParallelSplit( - std::unordered_multiset> const &children) - : children(children) {} - -ParallelSplit::ParallelSplit( - std::initializer_list> const &children) - : children(children) {} - -bool ParallelSplit::operator==(ParallelSplit const &other) const { - return this->tie() == other.tie(); -} - -bool ParallelSplit::operator!=(ParallelSplit const &other) const { - return this->tie() != other.tie(); -} - -ParallelSplit::Tie ParallelSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(ParallelSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { - return s << fmt::to_string(split); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::SeriesSplit>::operator()( - ::FlexFlow::SeriesSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -size_t hash<::FlexFlow::ParallelSplit>::operator()( - ::FlexFlow::ParallelSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -} // namespace std +// #include "utils/fmt/unordered_multiset.h" +// #include "utils/fmt/variant.h" +// #include "utils/fmt/vector.h" +// #include "utils/hash-utils.h" +// #include "utils/hash/unordered_multiset.h" +// #include "utils/hash/vector.h" +// +// namespace FlexFlow { +// +// SeriesSplit::SeriesSplit( +// std::vector> const &children) +// : children(children) {} +// +// SeriesSplit::SeriesSplit( +// std::initializer_list> const &children) +// : children(children) {} +// +// bool SeriesSplit::operator==(SeriesSplit const &other) const { +// return this->tie() == other.tie(); +// } +// +// bool SeriesSplit::operator!=(SeriesSplit const &other) const { +// return this->tie() != other.tie(); +// } +// +// SeriesSplit::Tie SeriesSplit::tie() const { +// return std::tie(this->children); +// } +// +// std::string format_as(SeriesSplit const &split) { +// return fmt::format("", split.children); +// } +// +// std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { +// return s << fmt::to_string(split); +// } +// +// ParallelSplit::ParallelSplit( +// std::unordered_multiset> const &children) +// : children(children) {} +// +// ParallelSplit::ParallelSplit( +// std::initializer_list> const &children) +// : children(children) {} +// +// bool ParallelSplit::operator==(ParallelSplit const &other) const { +// return this->tie() == other.tie(); +// } +// +// bool ParallelSplit::operator!=(ParallelSplit const &other) const { +// return this->tie() != other.tie(); +// } +// +// ParallelSplit::Tie ParallelSplit::tie() const { +// return std::tie(this->children); +// } +// +// std::string format_as(ParallelSplit const &split) { +// return fmt::format("", split.children); +// } +// +// std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { +// return s << fmt::to_string(split); +// } +// +// } // namespace FlexFlow +// +// namespace std { +// +// size_t hash<::FlexFlow::SeriesSplit>::operator()( +// ::FlexFlow::SeriesSplit const &s) const { +// size_t result = 0; +// ::FlexFlow::hash_combine(result, s.children); +// return result; +// } +// +// size_t hash<::FlexFlow::ParallelSplit>::operator()( +// ::FlexFlow::ParallelSplit const &s) const { +// size_t result = 0; +// ::FlexFlow::hash_combine(result, s.children); +// return result; +// } +// +// } // namespace std diff --git a/lib/utils/src/utils/json/check_is_json_deserializable.cc b/lib/utils/src/utils/json/check_is_json_deserializable.cc new file mode 100644 index 0000000000..7e17ced7e5 --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/check_is_json_serializable.cc b/lib/utils/src/utils/json/check_is_json_serializable.cc new file mode 100644 index 0000000000..1c9af4d3cb --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_serializable.h" diff --git a/lib/utils/test/src/utils/containers/cartesian_product.cc b/lib/utils/test/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..42b8a10439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/cartesian_product.cc @@ -0,0 +1,62 @@ +#include "utils/containers/cartesian_product.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cartesian_product") { + + SUBCASE("empty") { + std::vector> containers = {}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{}}; + CHECK(result == correct); + } + + SUBCASE("single container, one element") { + std::vector> containers = {{1}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1}}; + CHECK(result == correct); + } + + SUBCASE("single container, multiple elements") { + std::vector> containers = {{1, 2, 3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1}, {2}, {3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, one element each") { + std::vector> containers = {{1}, {2}, {3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1, 2, 3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, multiple elements") { + std::vector> containers = {{1, 2}, {3, 4}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = { + {1, 3}, {1, 4}, {2, 3}, {2, 4}}; + CHECK(result == correct); + } + + SUBCASE("1 empty container, 1 non-empty container") { + std::vector> containers = {{}, {2, 3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc new file mode 100644 index 0000000000..c10cc5ae75 --- /dev/null +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -0,0 +1,105 @@ +#include "utils/containers/flatmap.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/containers/map_keys.h" +#include "utils/hash/pair.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatmap(std::unordered_set, F)") { + auto get_chars = [](std::string const &s) { + std::unordered_set result; + for (char c : s) { + result.insert(c); + } + return result; + }; + + SUBCASE("type changing") { + std::unordered_set input = {"hello", " ", "", "world", "!"}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = { + 'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; + + CHECK(result == correct); + } + + SUBCASE("input is empty") { + std::unordered_set input = {}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("flatmap(std::unordered_map, F)") { + auto de_nest_keys = [](int k1, + std::unordered_map const &v) { + return map_keys(v, [&](int k2) { return std::pair{k1, k2}; }); + }; + + SUBCASE("input is empty") { + std::unordered_map> input = {}; + + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::unordered_map> input = { + { + 1, + { + {2, "a"}, + {3, "b"}, + }, + }, + { + 2, + {}, + }, + { + 3, + { + {3, "a"}, + }, + }, + }; + + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = { + {{1, 2}, "a"}, + {{1, 3}, "b"}, + {{3, 3}, "a"}, + }; + + CHECK(result == correct); + } + + SUBCASE("duplicate result keys") { + auto always_return_same_map = [](int, std::string const &) { + return std::unordered_map{ + {"mykey", 10000}, + }; + }; + + std::unordered_map input = { + {1, "a"}, + {2, "b"}, + }; + + CHECK_THROWS(flatmap(input, always_return_same_map)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_assignments.cc b/lib/utils/test/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..d5f989318f --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_assignments.cc @@ -0,0 +1,53 @@ +#include "utils/containers/get_all_assignments.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_assignments") { + SUBCASE("empty input") { + std::unordered_map> input = {}; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = {{}}; + + CHECK(result == correct); + } + + SUBCASE("non-empty input") { + std::unordered_map> input = { + {"a", {1, 2, 3}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = { + {{"a", 1}, {"b", 2}}, + {{"a", 1}, {"b", 3}}, + {{"a", 2}, {"b", 2}}, + {{"a", 2}, {"b", 3}}, + {{"a", 3}, {"b", 2}}, + {{"a", 3}, {"b", 3}}, + }; + + CHECK(result == correct); + } + + SUBCASE("one possible-values set is empty") { + std::unordered_map> input = { + {"a", {}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/try_at.cc b/lib/utils/test/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..548c9b0c79 --- /dev/null +++ b/lib/utils/test/src/utils/containers/try_at.cc @@ -0,0 +1,29 @@ +#include "utils/containers/try_at.h" +#include "test/utils/doctest/fmt/optional.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("try_at(T, K)", + T, + std::unordered_map, + std::map) { + T m = {{1, "one"}, {2, "two"}}; + + SUBCASE("map contains key") { + std::optional result = try_at(m, 1); + std::optional correct = "one"; + + CHECK(result == correct); + } + + SUBCASE("map does not contain key") { + std::optional result = try_at(m, 3); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..f0cdb19611 --- /dev/null +++ b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1,57 @@ +#include "utils/containers/unordered_map_from_pairs.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "utils/containers/contains.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("unordered_map_from_pairs") { + SUBCASE("nonempty input") { + std::vector> input = { + {1, "hello"}, + {3, "world"}, + }; + + std::unordered_map result = + unordered_map_from_pairs(input); + std::unordered_map correct = { + {1, "hello"}, + {3, "world"}, + }; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector> input = {}; + + std::unordered_map result = + unordered_map_from_pairs(input); + std::unordered_map correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input with duplicate keys") { + std::vector> input = { + {1, "a"}, + {2, "c"}, + {1, "b"}, + }; + + std::unordered_map result = + unordered_map_from_pairs(input); + + std::vector> + possible_correct_values = { + {{1, "a"}, {2, "c"}}, + {{1, "b"}, {2, "c"}}, + }; + + CHECK(contains(possible_correct_values, result)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..fec5d3401e --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,104 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dataflow_edges_from_node_to_node") { + DataflowGraph g = DataflowGraph::create(); + + SUBCASE("gets edges if there are multiple") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o0 = n1_added.outputs.at(0); + DataflowOutput n1_o1 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o0, n1_o0, n1_o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = { + DataflowEdge{ + n1_o0, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o0, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not get edges to/from other nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n3); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE( + "does not get flipped edges (i.e., respects from vs to direction)") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n2, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if no edges exist between the given nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if src node == dst node (as cycles cannot exist " + "in DataflowGraph") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..c35789044d --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,55 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + SplitBoundaryNodes result = + get_transitive_reduced_boundary_nodes_for_split(tr_g, split); + SplitBoundaryNodes correct = SplitBoundaryNodes{ + /*pre_split_boundary=*/{n2}, + /*post_split_boundary=*/{n3}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..1f8f66b932 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,146 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_edges_across_split") { + DataflowGraph g = DataflowGraph::create(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("multiple nodes with edges across") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o1, + DataflowInput{n3, 1}, + }, + DataflowEdge{ + o2, + DataflowInput{n3, 0}, + }, + DataflowEdge{ + o1, + DataflowInput{n4, 0}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("nodes each have multiple edges across") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o1 = n1_added.outputs.at(0); + DataflowOutput n1_o2 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o1, n1_o2, n1_o1}, 1); + Node n2 = n2_added.node; + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_leaf(n1), + make_leaf(n2), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + n1_o1, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o2, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not return edges eliminated by transitive reduction") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o2, + DataflowInput{n3, 1}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..0e77739434 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,52 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_outputs_across_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_outputs_across_split(tr_g, split); + std::unordered_set correct = {o2}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..5a1ea99671 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,142 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_edges_from_subgraph_to_subgraph") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + SUBCASE("basic tests") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(4)}; + std::unordered_set dst_subgraph = {n.at(2), n.at(3)}; + + SUBCASE("returns all edges between subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(4), n.at(2)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = unordered_set_of(e); + + CHECK(result == correct); + } + + SUBCASE("does not return reverse edges") { + std::vector e = { + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(0)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(0)}; + + CHECK(result == correct); + } + + SUBCASE("does not return edges within subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if there are no edges from src_subgraph to " + "dst_subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(2), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("empty subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + SUBCASE("returns no edges if no nodes in src_subgraph") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, unordered_set_of(n)); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if no nodes in dst_subgraph") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, unordered_set_of(n), {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if both subgraphs are empty") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("if subgraphs do not cover graph, then does not return external " + "edges") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set src_subgraph = {n.at(0)}; + std::unordered_set dst_subgraph = {n.at(3)}; + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("throws an error if subgraphs are not disjoint") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(2)}; + std::unordered_set dst_subgraph = {n.at(1), n.at(3)}; + CHECK_THROWS( + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc deleted file mode 100644 index 66b657eaaa..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("fmt GenericBinarySPDecompositionTree") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - std::string result = fmt::to_string(input); - std::string correct = ""; - - CHECK(result == correct); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - - std::string result = fmt::to_string(input); - std::string correct = (" " - "" - ">" - ">"); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - - std::string result = fmt::to_string(input); - std::string correct = (" " - "" - ">" - ">"); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index abae9286b6..9ca869b2b0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1,41 +1,61 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "test/utils/doctest/fmt/unordered_multiset.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { + TEST_CASE("get_leaves") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto generic_get_leaves = [&](BinarySPDecompositionTree const &tree) { + return get_leaves(tree, impl); + }; + SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); + BinarySPDecompositionTree input = BinarySPDecompositionTree{n1}; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1}; CHECK(result == correct); } SUBCASE("series split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 6}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 5}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; CHECK(result == correct); } @@ -43,42 +63,54 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("parallel split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 6}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 5}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; CHECK(result == correct); } } + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5))), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(2))); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {2, 2, 4, 4, 5}; + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1, n2, n2, n3}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc deleted file mode 100644 index 92c556ad28..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_left_child(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - CHECK_THROWS(get_left_child(input)); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(3)); - - GenericBinarySPDecompositionTree result = get_left_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(5); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(7)); - - GenericBinarySPDecompositionTree result = get_left_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(4); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 3de61d3313..ad7e1c2609 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1,16 +1,43 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { + TEST_CASE("get_num_tree_nodes") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + auto generic_get_num_tree_nodes = + [&](BinarySPDecompositionTree const &tree) { + return get_num_tree_nodes(tree, impl); + }; + SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); + BinarySPDecompositionTree input = make_leaf(n1); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 1; CHECK(result == correct); @@ -18,22 +45,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("series split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n2)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n1)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); @@ -42,22 +67,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("parallel split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n2)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n1)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); @@ -65,18 +88,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5))), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(2))); - - int result = get_num_tree_nodes(input); + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); + + int result = generic_get_num_tree_nodes(input); int correct = 9; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc deleted file mode 100644 index 33b5d37955..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_right_child(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - CHECK_THROWS(get_right_child(input)); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(3)); - - GenericBinarySPDecompositionTree result = get_right_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(3); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(7)); - - GenericBinarySPDecompositionTree result = get_right_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(7); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc deleted file mode 100644 index e7025dbfad..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ /dev/null @@ -1,117 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("std::hash>") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree leaf_5 = - make_generic_binary_sp_leaf(5); - size_t leaf_5_hash = get_std_hash(leaf_5); - - SUBCASE("leaves with same labels hash to the same value") { - GenericBinarySPDecompositionTree also_leaf_5 = - make_generic_binary_sp_leaf(5); - size_t also_leaf_5_hash = get_std_hash(also_leaf_5); - - CHECK(leaf_5_hash == also_leaf_5_hash); - } - - SUBCASE("leaves with different labels hash to different values") { - GenericBinarySPDecompositionTree leaf_6 = - make_generic_binary_sp_leaf(6); - size_t leaf_6_hash = get_std_hash(leaf_6); - - CHECK(leaf_5_hash != leaf_6_hash); - } - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree series_5_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t series_5_6_hash = get_std_hash(series_5_6); - - SUBCASE("same children lead to the same hash") { - GenericBinarySPDecompositionTree also_series_5_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t also_series_5_6_hash = get_std_hash(also_series_5_6); - - CHECK(series_5_6_hash == also_series_5_6_hash); - } - - SUBCASE("hash is order dependent") { - GenericBinarySPDecompositionTree series_6_5 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(6), - make_generic_binary_sp_leaf(5)); - size_t series_6_5_hash = get_std_hash(series_6_5); - - CHECK(series_5_6_hash != series_6_5_hash); - } - - SUBCASE("different left child leads to different hash") { - GenericBinarySPDecompositionTree series_4_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(6)); - size_t series_4_6_hash = get_std_hash(series_4_6); - - CHECK(series_5_6_hash != series_4_6_hash); - } - - SUBCASE("different right child leads to different hash") { - GenericBinarySPDecompositionTree series_5_7 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - size_t series_5_7_hash = get_std_hash(series_5_7); - - CHECK(series_5_6_hash != series_5_7_hash); - } - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree parallel_5_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t parallel_5_6_hash = get_std_hash(parallel_5_6); - - SUBCASE("same children lead to the same hash") { - GenericBinarySPDecompositionTree also_parallel_5_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); - - CHECK(parallel_5_6_hash == also_parallel_5_6_hash); - } - - SUBCASE("hash is order dependent") { - GenericBinarySPDecompositionTree parallel_6_5 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), - make_generic_binary_sp_leaf(5)); - size_t parallel_6_5_hash = get_std_hash(parallel_6_5); - - CHECK(parallel_5_6_hash != parallel_6_5_hash); - } - - SUBCASE("different left child leads to different hash") { - GenericBinarySPDecompositionTree parallel_4_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(6)); - size_t parallel_4_6_hash = get_std_hash(parallel_4_6); - - CHECK(parallel_5_6_hash != parallel_4_6_hash); - } - - SUBCASE("different right child leads to different hash") { - GenericBinarySPDecompositionTree parallel_5_7 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - size_t parallel_5_7_hash = get_std_hash(parallel_5_7); - - CHECK(parallel_5_6_hash != parallel_5_7_hash); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 7a8756c6cc..3fae155280 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1,22 +1,38 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_binary_sp_tree_left_associative(" - "GenericBinarySPDecompositionTree)") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; + TEST_CASE("is_binary_sp_tree_left_associative") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually left associative") { SUBCASE("just node") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -25,12 +41,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -39,12 +51,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -53,14 +61,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n3), - make_generic_binary_sp_leaf(n4))); + BinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -71,12 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not left associative") { SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); bool correct = false; @@ -85,12 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); bool correct = false; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 3cf87368ab..5b4e26107e 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1,22 +1,38 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_binary_sp_tree_right_associative(" - "GenericBinarySPDecompositionTree)") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; + TEST_CASE("is_binary_sp_tree_right_associative") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually right associative") { SUBCASE("just node") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -25,12 +41,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -39,12 +51,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -53,14 +61,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n3), - make_generic_binary_sp_leaf(n4))); + BinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -71,12 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not right associative") { SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); bool correct = false; @@ -85,12 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); bool correct = false; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index cc234bacf8..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1,131 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("adl_serializer>") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree tt = make_generic_binary_sp_leaf(5); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree tt = - make_generic_binary_series_split(make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5)); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "series"}, - { - "value", - { - {"__type", "GenericBinarySeriesSplit"}, - { - "left_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 2}, - }, - }, - { - "right_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }, - }, - }, - }, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree tt = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5)); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "parallel"}, - { - "value", - { - {"__type", "GenericBinaryParallelSplit"}, - { - "left_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 2}, - }, - }, - { - "right_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }, - }, - }, - }, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index 4ede4e84b5..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1,28 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split(make_generic_binary_sp_leaf(1), - make_generic_binary_sp_leaf(4)), - make_generic_binary_sp_leaf(2)); - - GenericBinarySPDecompositionTree result = - transform(input, [](int x) { return std::to_string(x); }); - - GenericBinarySPDecompositionTree correct = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(std::string{"1"}), - make_generic_binary_sp_leaf(std::string{"4"})), - make_generic_binary_sp_leaf(std::string{"2"})); - - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 1e3217a2de..fee971e5e0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -18,34 +18,45 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); - BinarySPDecompositionTree correct = make_leaf_node(n1); + BinarySPDecompositionTree correct = make_leaf(n1); CHECK(result == correct); } SUBCASE("only serial") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - SeriesSplit{n1, n2, n3}, + SeriesSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_series_split(make_leaf_node(n1), make_leaf_node(n2)), - make_leaf_node(n3)); + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); CHECK(result == correct); } SUBCASE("only parallel") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{n1, n2, n3}, + ParallelSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = @@ -64,20 +75,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nested") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; BinarySPDecompositionTree result = diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc index 0befbde5cc..fd540f853f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -14,8 +14,20 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("leaf") { - BinarySPDecompositionTree input = make_leaf_node(n1); + BinarySPDecompositionTree input = make_leaf(n1); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; @@ -25,35 +37,33 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative series") { BinarySPDecompositionTree input = make_series_split( - make_series_split(make_leaf_node(n2), make_leaf_node(n1)), - make_leaf_node(n3)); + make_series_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("right associative series") { BinarySPDecompositionTree input = make_series_split( - make_leaf_node(n2), - make_series_split(make_leaf_node(n1), make_leaf_node(n3))); + make_leaf(n2), make_series_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("series with duplicate children") { BinarySPDecompositionTree input = - make_series_split(make_leaf_node(n1), make_leaf_node(n1)); + make_series_split(make_leaf(n1), make_leaf(n1)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n1, n1}}; + SeriesParallelDecomposition{SeriesSplit{{n1, n1}}}; CHECK(get_nodes(result).size() == 2); CHECK(result == correct); @@ -61,35 +71,33 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_parallel_split(make_leaf_node(n2), make_leaf_node(n1)), - make_leaf_node(n3)); + make_parallel_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("right associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_leaf_node(n2), - make_parallel_split(make_leaf_node(n1), make_leaf_node(n3))); + make_leaf(n2), make_parallel_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("parallel with duplicate children") { BinarySPDecompositionTree input = - make_parallel_split(make_leaf_node(n1), make_leaf_node(n1)); + make_parallel_split(make_leaf(n1), make_leaf(n1)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n1, n1}}; + SeriesParallelDecomposition{ParallelSplit{{n1, n1}}}; CHECK(get_nodes(result).size() == 2); CHECK(result == correct); @@ -99,31 +107,31 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree input = make_parallel_split( make_parallel_split( make_parallel_split( - make_leaf_node(n1), + make_leaf(n1), make_series_split( - make_series_split(make_series_split(make_leaf_node(n2), - make_leaf_node(n3)), - make_leaf_node(n3)), - make_leaf_node(n5))), - make_series_split(make_leaf_node(n6), make_leaf_node(n4))), - make_leaf_node(n5)); + make_series_split( + make_series_split(make_leaf(n2), make_leaf(n3)), + make_leaf(n3)), + make_leaf(n5))), + make_series_split(make_leaf(n6), make_leaf(n4))), + make_leaf(n5)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index db1b440481..532ff86c90 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,34 +16,45 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); - BinarySPDecompositionTree correct = make_leaf_node(n1); + BinarySPDecompositionTree correct = make_leaf(n1); CHECK(result == correct); } SUBCASE("only serial") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - SeriesSplit{n1, n2, n3}, + SeriesSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_leaf_node(n1), - make_series_split(make_leaf_node(n2), make_leaf_node(n3))); + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); CHECK(result == correct); } SUBCASE("only parallel") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{n1, n2, n3}, + ParallelSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = @@ -62,20 +73,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nested") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; BinarySPDecompositionTree result = diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 45f796c824..e5b9045739 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -24,10 +24,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_series_parallel_decomposition(g); std::optional correct = - SeriesParallelDecomposition{ParallelSplit{ + SeriesParallelDecomposition{ParallelSplit{{ n.at(0), n.at(1), - }}; + }}}; CHECK(result == correct); } @@ -39,10 +39,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_series_parallel_decomposition(g); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ + SeriesParallelDecomposition{SeriesSplit{{ n.at(0), n.at(1), - }}; + }}}; CHECK(result == correct); } @@ -59,13 +59,13 @@ TEST_SUITE(FF_TEST_SUITE) { get_series_parallel_decomposition(g); std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ + SeriesSplit{{ n.at(0), - ParallelSplit{ + ParallelSplit{{ n.at(1), n.at(2), - }, - }, + }}, + }}, }; CHECK(result == correct); } @@ -86,20 +86,20 @@ TEST_SUITE(FF_TEST_SUITE) { }); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ + SeriesParallelDecomposition{SeriesSplit{{ n.at(0), - ParallelSplit{ - SeriesSplit{ + ParallelSplit{{ + SeriesSplit{{ n.at(1), n.at(3), - }, - SeriesSplit{ + }}, + SeriesSplit{{ n.at(2), n.at(4), - }, - }, + }}, + }}, n.at(5), - }}; + }}}; std::optional result = get_series_parallel_decomposition(g); @@ -122,16 +122,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ - ParallelSplit{ + SeriesSplit{{ + ParallelSplit{{ n.at(0), n.at(1), - }, - ParallelSplit{ + }}, + ParallelSplit{{ n.at(2), n.at(3), - }, - }, + }}, + }}, }; std::optional result = @@ -177,12 +177,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ + SeriesSplit{{ n.at(0), n.at(1), n.at(2), n.at(3), - }, + }}, }; std::optional result = get_series_parallel_decomposition(g);