From 510c2d9b3c7d84736d3d309de44d08a30afc3d68 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 21:03:41 -0700 Subject: [PATCH 01/71] Start on pcg builder --- .../pcg/parallel_computation_graph_builder.h | 20 ++++++ .../include/pcg/parallel_tensor_guid_t.dtg.h | 46 +++++++++++++ lib/pcg/include/pcg/parallel_tensor_guid_t.h | 18 ----- .../pcg/parallel_tensor_guid_t.struct.toml | 16 +++++ lib/pcg/include/pcg/tensor_guid_t.dtg.h | 2 +- lib/pcg/include/pcg/tensor_guid_t.struct.toml | 2 - lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc | 66 +++++++++++++++++++ lib/pcg/src/pcg/tensor_guid_t.dtg.cc | 2 +- 8 files changed, 150 insertions(+), 22 deletions(-) create mode 100644 lib/pcg/include/pcg/parallel_computation_graph_builder.h create mode 100644 lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h delete mode 100644 lib/pcg/include/pcg/parallel_tensor_guid_t.h create mode 100644 lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml create mode 100644 lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc diff --git a/lib/pcg/include/pcg/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph_builder.h new file mode 100644 index 0000000000..f1b0734f6c --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph_builder.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_BUILDER_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_BUILDER_H + +#include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +struct ParallelComputationGraphBuilder { +public: + ParallelComputationGraphBuilder(); + + +public: + ParallelComputationGraph pcg; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h b/lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h new file mode 100644 index 0000000000..4041544903 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "de2c2d33bfa5cd72f0e51954d6879f38" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidiedge.h" +#include +#include +#include + +namespace FlexFlow { +struct parallel_tensor_guid_t { + parallel_tensor_guid_t() = delete; + parallel_tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + + bool operator==(parallel_tensor_guid_t const &) const; + bool operator!=(parallel_tensor_guid_t const &) const; + bool operator<(parallel_tensor_guid_t const &) const; + bool operator>(parallel_tensor_guid_t const &) const; + bool operator<=(parallel_tensor_guid_t const &) const; + bool operator>=(parallel_tensor_guid_t const &) const; + ::FlexFlow::MultiDiOutput raw_graph_output; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::parallel_tensor_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(parallel_tensor_guid_t const &); +std::ostream &operator<<(std::ostream &, parallel_tensor_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/parallel_tensor_guid_t.h b/lib/pcg/include/pcg/parallel_tensor_guid_t.h deleted file mode 100644 index db8f84b7e2..0000000000 --- a/lib/pcg/include/pcg/parallel_tensor_guid_t.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_H - -#include "utils/graph/multidiedge.h" -#include "utils/strong_typedef.h" - -namespace FlexFlow { - -struct parallel_tensor_guid_t - : strong_typedef { - using strong_typedef::strong_typedef; -}; -FF_TYPEDEF_HASHABLE(parallel_tensor_guid_t); -FF_TYPEDEF_PRINTABLE(parallel_tensor_guid_t, "parallel_tensor_guid"); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml new file mode 100644 index 0000000000..7837d7b39b --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "parallel_tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidiedge.h" +] + +[[fields]] +name = "raw_graph_output" +type = "::FlexFlow::MultiDiOutput" diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.h b/lib/pcg/include/pcg/tensor_guid_t.dtg.h index c6109c6103..f9841a4d06 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.dtg.h +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/tensor_guid_t.struct.toml /* proj-data { - "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" + "generated_from": "1e3914b97a465f1752ce510614145b37" } */ diff --git a/lib/pcg/include/pcg/tensor_guid_t.struct.toml b/lib/pcg/include/pcg/tensor_guid_t.struct.toml index aea4fad108..795c0166eb 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.struct.toml +++ b/lib/pcg/include/pcg/tensor_guid_t.struct.toml @@ -4,8 +4,6 @@ features = [ "eq", "ord", "hash", - # "json", - # "rapidcheck", "fmt", ] diff --git a/lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc new file mode 100644 index 0000000000..b64cf2901f --- /dev/null +++ b/lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "de2c2d33bfa5cd72f0e51954d6879f38" +} +*/ + +#include "pcg/parallel_tensor_guid_t.dtg.h" + +#include "utils/graph/multidiedge.h" +#include + +namespace FlexFlow { +parallel_tensor_guid_t::parallel_tensor_guid_t( + ::FlexFlow::MultiDiOutput const &raw_graph_output) + : raw_graph_output(raw_graph_output) {} +bool parallel_tensor_guid_t::operator==( + parallel_tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) == std::tie(other.raw_graph_output); +} +bool parallel_tensor_guid_t::operator!=( + parallel_tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) != std::tie(other.raw_graph_output); +} +bool parallel_tensor_guid_t::operator<( + parallel_tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) < std::tie(other.raw_graph_output); +} +bool parallel_tensor_guid_t::operator>( + parallel_tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) > std::tie(other.raw_graph_output); +} +bool parallel_tensor_guid_t::operator<=( + parallel_tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) <= std::tie(other.raw_graph_output); +} +bool parallel_tensor_guid_t::operator>=( + parallel_tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) >= std::tie(other.raw_graph_output); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::parallel_tensor_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(parallel_tensor_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, parallel_tensor_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc index 9d57291112..779018296d 100644 --- a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/tensor_guid_t.struct.toml /* proj-data { - "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" + "generated_from": "1e3914b97a465f1752ce510614145b37" } */ From 7b55ed1eaa7de68a33f5cfb7038190206fb013e6 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 23:06:32 -0700 Subject: [PATCH 02/71] Add tests and some implementation for pcg builder --- lib/op-attrs/include/op-attrs/datatype.h | 2 + .../include/op-attrs/pcg_operator_attrs.dtg.h | 24 +- .../op-attrs/pcg_operator_attrs.variant.toml | 5 + lib/op-attrs/src/{ => op-attrs}/datatype.cc | 4 + lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 9 - .../src/op-attrs/pcg_operator_attrs.dtg.cc | 24 +- lib/op-attrs/test/src/datatype.cc | 23 ++ lib/op-attrs/test/src/test_conv_2d.cc | 183 ++++++++++- lib/pcg/include/pcg/operator_guid_t.dtg.h | 46 --- .../include/pcg/parallel_computation_graph.h | 14 +- .../pcg/parallel_computation_graph_builder.h | 107 +++++++ .../include/pcg/parallel_layer_guid_t.dtg.h | 46 +++ ...toml => parallel_layer_guid_t.struct.toml} | 2 +- lib/pcg/src/pcg/parallel_computation_graph.cc | 20 ++ .../pcg/parallel_computation_graph_builder.cc | 293 ++++++++++++++++++ ..._t.dtg.cc => parallel_layer_guid_t.dtg.cc} | 37 ++- .../pcg/parallel_computation_graph_builder.cc | 32 ++ 17 files changed, 778 insertions(+), 93 deletions(-) rename lib/op-attrs/src/{ => op-attrs}/datatype.cc (85%) create mode 100644 lib/op-attrs/test/src/datatype.cc delete mode 100644 lib/pcg/include/pcg/operator_guid_t.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h rename lib/pcg/include/pcg/{operator_guid_t.struct.toml => parallel_layer_guid_t.struct.toml} (86%) create mode 100644 lib/pcg/src/pcg/parallel_computation_graph.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph_builder.cc rename lib/pcg/src/pcg/{operator_guid_t.dtg.cc => parallel_layer_guid_t.dtg.cc} (51%) create mode 100644 lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index a435c1bc12..6204b9ca49 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -58,6 +58,8 @@ using DataTypeValue = std::variant, size_t size_of_datatype(DataType); +bool can_strictly_promote_datatype_from_to(DataType, DataType); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h index 5370773a45..76fdaab919 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml /* proj-data { - "generated_from": "e1d10b0c7c98524c27886bdae0972321" + "generated_from": "befa524c61393938b5b02f2e0401a122" } */ @@ -41,6 +41,7 @@ #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" #include "rapidcheck.h" #include #include @@ -80,6 +81,7 @@ struct PCGOperatorAttrs { explicit PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::TopKAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::WeightAttrs const &); template static constexpr bool IsPartOfPCGOperatorAttrs_v = std::is_same_v || @@ -110,7 +112,8 @@ struct PCGOperatorAttrs { std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v; + std::is_same_v || + std::is_same_v; template ReturnType visit(Visitor &&v) const { switch (this->index()) { @@ -230,6 +233,10 @@ struct PCGOperatorAttrs { ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); return result; } + case 29: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } default: { throw std::runtime_error(fmt::format( "Unknown index {} for type PCGOperatorAttrs", this->index())); @@ -355,6 +362,10 @@ struct PCGOperatorAttrs { ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); return result; } + case 29: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } default: { throw std::runtime_error(fmt::format( "Unknown index {} for type PCGOperatorAttrs", this->index())); @@ -380,7 +391,7 @@ struct PCGOperatorAttrs { "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " - "::FlexFlow::TransposeAttrs], received T"); + "::FlexFlow::TransposeAttrs, ::FlexFlow::WeightAttrs], received T"); return std::holds_alternative(this->raw_variant); } template @@ -402,7 +413,7 @@ struct PCGOperatorAttrs { "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " - "::FlexFlow::TransposeAttrs], received T"); + "::FlexFlow::TransposeAttrs, ::FlexFlow::WeightAttrs], received T"); return std::get(this->raw_variant); } template @@ -424,7 +435,7 @@ struct PCGOperatorAttrs { "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " - "::FlexFlow::TransposeAttrs], received T"); + "::FlexFlow::TransposeAttrs, ::FlexFlow::WeightAttrs], received T"); return std::get(this->raw_variant); } size_t index() const { @@ -464,7 +475,8 @@ struct PCGOperatorAttrs { ::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, - ::FlexFlow::TransposeAttrs> + ::FlexFlow::TransposeAttrs, + ::FlexFlow::WeightAttrs> raw_variant; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml index ddb8a109d8..a2936b56c5 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -39,6 +39,7 @@ includes = [ "op-attrs/ops/split_attrs.dtg.h", "op-attrs/ops/topk_attrs.dtg.h", "op-attrs/ops/transpose_attrs.dtg.h", + "op-attrs/ops/weight_attrs.dtg.h", ] [[values]] @@ -156,3 +157,7 @@ key = "topk" [[values]] type = "::FlexFlow::TransposeAttrs" key = "transpose" + +[[values]] +type = "::FlexFlow::WeightAttrs" +key = "weight" diff --git a/lib/op-attrs/src/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc similarity index 85% rename from lib/op-attrs/src/datatype.cc rename to lib/op-attrs/src/op-attrs/datatype.cc index 20e55a641f..e382ea298d 100644 --- a/lib/op-attrs/src/datatype.cc +++ b/lib/op-attrs/src/op-attrs/datatype.cc @@ -21,4 +21,8 @@ size_t size_of_datatype(DataType data_type) { } } +bool can_strictly_promote_datatype_from_to(DataType src, DataType dst) { + return src < dst; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index c9ec467af4..d20690d705 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -90,9 +90,6 @@ ParallelTensorShape input.datatype, }; - assert(total_parallel_degree(result.dims) == - total_parallel_degree(raw_input_shape.dims)); - return result; } @@ -122,9 +119,6 @@ ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, input.datatype, }; - assert(total_parallel_degree(result.dims) == - total_parallel_degree(raw_input_shape.dims)); - return result; } @@ -167,9 +161,6 @@ ParallelTensorShape input.datatype, }; - assert(total_parallel_degree(result.dims) == - total_parallel_degree(raw_input_shape.dims)); - return result; } diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc index 5334c8a7ab..3e01858163 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml /* proj-data { - "generated_from": "e1d10b0c7c98524c27886bdae0972321" + "generated_from": "befa524c61393938b5b02f2e0401a122" } */ @@ -72,6 +72,8 @@ PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TopKAttrs const &v) : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &v) : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::WeightAttrs const &v) + : raw_variant(v) {} bool PCGOperatorAttrs::operator==(PCGOperatorAttrs const &other) const { return this->raw_variant == other.raw_variant; } @@ -122,7 +124,8 @@ size_t hash<::FlexFlow::PCGOperatorAttrs>::operator()( ::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, - ::FlexFlow::TransposeAttrs>>{}(x.raw_variant); + ::FlexFlow::TransposeAttrs, + ::FlexFlow::WeightAttrs>>{}(x.raw_variant); } } // namespace std namespace nlohmann { @@ -216,6 +219,9 @@ ::FlexFlow::PCGOperatorAttrs } else if (key == "transpose") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else if (key == "weight") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::WeightAttrs>()}; } else { throw std::runtime_error(fmt::format("Unknown type key {}", key)); } @@ -369,6 +375,11 @@ void adl_serializer<::FlexFlow::PCGOperatorAttrs>::to_json( j["value"] = x.get<::FlexFlow::TransposeAttrs>(); break; } + case 29: { + j["type"] = "weight"; + j["value"] = x.get<::FlexFlow::WeightAttrs>(); + break; + } default: { throw std::runtime_error( fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index())); @@ -436,7 +447,9 @@ Gen<::FlexFlow::PCGOperatorAttrs> gen::construct<::FlexFlow::PCGOperatorAttrs>( gen::arbitrary<::FlexFlow::TopKAttrs>()), gen::construct<::FlexFlow::PCGOperatorAttrs>( - gen::arbitrary<::FlexFlow::TransposeAttrs>())); + gen::arbitrary<::FlexFlow::TransposeAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::WeightAttrs>())); } } // namespace rc namespace FlexFlow { @@ -584,6 +597,11 @@ std::string format_as(::FlexFlow::PCGOperatorAttrs const &x) { << x.get<::FlexFlow::TransposeAttrs>() << ">"; break; } + case 29: { + oss << ""; + break; + } default: { throw std::runtime_error( fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index())); diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/datatype.cc new file mode 100644 index 0000000000..90aa0c20f6 --- /dev/null +++ b/lib/op-attrs/test/src/datatype.cc @@ -0,0 +1,23 @@ +#include "test/utils/doctest.h" +#include "op-attrs/datatype.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("can_promote_datatype_from_to(DataType, DataType)") { + CHECK(can_strictly_promote_datatype_from_to(DataType::BOOL, DataType::INT32)); + CHECK(can_strictly_promote_datatype_from_to(DataType::INT32, DataType::INT64)); + CHECK(can_strictly_promote_datatype_from_to(DataType::FLOAT, DataType::DOUBLE)); + + SUBCASE("is strict") { + rc::check([](DataType d) { + RC_ASSERT(!can_strictly_promote_datatype_from_to(d, d)); + }); + } + + SUBCASE("is asymmetric") { + rc::check([](DataType l, DataType r) { + RC_PRE(can_strictly_promote_datatype_from_to(l, r)); + RC_ASSERT(!can_strictly_promote_datatype_from_to(r, l)); + }); + } + } +} diff --git a/lib/op-attrs/test/src/test_conv_2d.cc b/lib/op-attrs/test/src/test_conv_2d.cc index b16a26a7b1..85d95b42cb 100644 --- a/lib/op-attrs/test/src/test_conv_2d.cc +++ b/lib/op-attrs/test/src/test_conv_2d.cc @@ -1,8 +1,9 @@ #include "doctest/doctest.h" #include "op-attrs/ops/conv_2d.h" +#include "utils/integer_conversions.h" TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_output_shape(Conv2DAttrs, TensorShape)") { + TEST_CASE("Conv2D shape inference") { int out_channels = 4; int kernel_h = 3; int kernel_w = 2; @@ -32,7 +33,7 @@ TEST_SUITE(FF_TEST_SUITE) { size_t input_height = 10; size_t input_width = 15; - TensorShape input_shape = { + TensorShape input = { TensorDims{FFOrdered{ num_samples, input_channels, @@ -42,21 +43,181 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape result = get_output_shape(attrs, input_shape); + size_t output_height = 3; + size_t output_width = 6; - size_t correct_output_height = 3; - size_t correct_output_width = 6; - - TensorShape correct_output_shape = { + TensorShape output = { TensorDims{FFOrdered{ num_samples, - static_cast(out_channels), - correct_output_height, - correct_output_width, + size_t_from_int(out_channels), + output_height, + output_width, }}, DataType::FLOAT, }; - CHECK(result == correct_output_shape); + TensorShape kernel = { + TensorDims{FFOrdered{ + size_t_from_int(out_channels), + input_channels, + size_t_from_int(kernel_h), + size_t_from_int(kernel_w), + }}, + DataType::FLOAT, + }; + + TensorShape bias = { + TensorDims{FFOrdered{ + size_t_from_int(out_channels), + }}, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape(Conv2DAttrs, TensorShape)") { + TensorShape result_output = get_output_shape(attrs, input); + TensorShape correct_output = output; + CHECK(result_output == correct_output); + } + + SUBCASE("get_kernel_shape(Conv2DAttrs, TensorShape)") { + TensorShape result_kernel = get_kernel_shape(attrs, input); + TensorShape correct_kernel = kernel; + CHECK(result_kernel == correct_kernel); + } + + SUBCASE("get_bias_shape(Conv2DAttrs, TensorShape)") { + TensorShape result_bias = get_bias_shape(attrs, input); + TensorShape correct_bias = bias; + CHECK(result_bias == correct_bias); + } + + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_n, + int o_c, + int o_h, + int o_w) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_n, o_c, o_h, o_w}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_n, + int o_c, + int o_h, + int o_w) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_n, o_c, o_h, o_w}); + }; + + auto make_kernel = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_outchannels, + int o_inchannels, + int o_kernel_h, + int o_kernel_w) { + return lift_to_parallel_with_degrees( + kernel, o_sum, o_eq, FFOrdered{o_outchannels, o_inchannels, o_kernel_h, o_kernel_w}); + }; + + auto make_bias = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_outchannels) { + return lift_to_parallel_with_degrees( + bias, o_sum, o_eq, FFOrdered{o_outchannels}); + }; + + SUBCASE("data parallelism") { + int degree = 2; + ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); + + SUBCASE("get_output_shape") { + ParallelTensorShape result = get_output_shape(attrs, par_input); + ParallelTensorShape correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_kernel_shape") { + ParallelTensorShape result = get_kernel_shape(attrs, par_input); + ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_bias_shape") { + ParallelTensorShape result = get_bias_shape(attrs, par_input); + ParallelTensorShape correct = make_bias(SumDegree{1}, DiscardCopyDegree{degree}, 1); + CHECK(result == correct); + } + } + + SUBCASE("input channel parallelism") { + int degree = 2; + ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); + + SUBCASE("get_output_shape") { + ParallelTensorShape result = get_output_shape(attrs, par_input); + ParallelTensorShape correct = make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_kernel_shape") { + ParallelTensorShape result = get_kernel_shape(attrs, par_input); + ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_bias_shape") { + ParallelTensorShape result = get_bias_shape(attrs, par_input); + ParallelTensorShape correct = make_bias(SumDegree{degree}, DiscardCopyDegree{1}, 1); + CHECK(result == correct); + } + } + + SUBCASE("output channel parallelism") { + int degree = 2; + ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); + + SUBCASE("get_output_shape") { + ParallelTensorShape result = get_output_shape(attrs, par_input); + ParallelTensorShape correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_kernel_shape") { + ParallelTensorShape result = get_kernel_shape(attrs, par_input); + ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_bias_shape") { + ParallelTensorShape result = get_bias_shape(attrs, par_input); + ParallelTensorShape correct = make_bias(SumDegree{1}, DiscardCopyDegree{1}, degree); + CHECK(result == correct); + } + } + + SUBCASE("propagating sum degree") { + int degree = 2; + ParallelTensorShape par_input = make_input(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); + + SUBCASE("get_output_shape") { + ParallelTensorShape result = get_output_shape(attrs, par_input); + ParallelTensorShape correct = make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_kernel_shape") { + ParallelTensorShape result = get_kernel_shape(attrs, par_input); + ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_bias_shape") { + ParallelTensorShape result = get_bias_shape(attrs, par_input); + ParallelTensorShape correct = make_bias(SumDegree{degree}, DiscardCopyDegree{1}, 1); + CHECK(result == correct); + } + } } } diff --git a/lib/pcg/include/pcg/operator_guid_t.dtg.h b/lib/pcg/include/pcg/operator_guid_t.dtg.h deleted file mode 100644 index bf08150e5e..0000000000 --- a/lib/pcg/include/pcg/operator_guid_t.dtg.h +++ /dev/null @@ -1,46 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/operator_guid_t.struct.toml -/* proj-data -{ - "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" -} -*/ - -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H - -#include "fmt/format.h" -#include "utils/graph.h" -#include -#include -#include - -namespace FlexFlow { -struct operator_guid_t { - operator_guid_t() = delete; - operator_guid_t(::FlexFlow::Node const &raw_graph_node); - - bool operator==(operator_guid_t const &) const; - bool operator!=(operator_guid_t const &) const; - bool operator<(operator_guid_t const &) const; - bool operator>(operator_guid_t const &) const; - bool operator<=(operator_guid_t const &) const; - bool operator>=(operator_guid_t const &) const; - ::FlexFlow::Node raw_graph_node; -}; -} // namespace FlexFlow - -namespace std { -template <> -struct hash { - size_t operator()(FlexFlow::operator_guid_t const &) const; -}; -} // namespace std - -namespace FlexFlow { -std::string format_as(operator_guid_t const &); -std::ostream &operator<<(std::ostream &, operator_guid_t const &); -} // namespace FlexFlow - -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 9d7103f4fd..aae5122671 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -1,8 +1,18 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H -#include "pcg/parallel_computation_graph_t.h" +#include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_tensor_guid_t.dtg.h" -namespace FlexFlow {} +namespace FlexFlow { + +ParallelComputationGraph empty_parallel_computation_graph(); + +std::unordered_set get_parallel_layers(ParallelComputationGraph const &); + +ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); + +} #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph_builder.h index f1b0734f6c..6e21110e0e 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph_builder.h @@ -3,6 +3,7 @@ #include "pcg/parallel_computation_graph.dtg.h" #include "pcg/parallel_tensor_guid_t.dtg.h" +#include namespace FlexFlow { @@ -10,7 +11,113 @@ struct ParallelComputationGraphBuilder { public: ParallelComputationGraphBuilder(); + parallel_tensor_guid_t create_input_tensor(ParallelTensorShape const &shape, + bool create_grad = true, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t add(parallel_tensor_guid_t const &lhs, + parallel_tensor_guid_t const &rhs, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + batch_matmul(parallel_tensor_guid_t const &a, + parallel_tensor_guid_t const &b, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t cast(parallel_tensor_guid_t const &input, + DataType result_type, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t conv2d( + parallel_tensor_guid_t const &input, + int outChannels, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + std::optional const &activation = std::nullopt, + int groups = 1, + bool use_bias = true, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &kernel_regularizer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t dense( + parallel_tensor_guid_t const &input, + int outDim, + std::optional activation = std::nullopt, + bool use_bias = true, + DataType data_type = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t embedding( + parallel_tensor_guid_t const &input, + int num_entries, + int outDim, + AggregateOp aggr, + DataType dtype = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t multihead_attention( + parallel_tensor_guid_t const &query, + parallel_tensor_guid_t const &key, + parallel_tensor_guid_t const &value, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = true, + bool add_bias_kv = false, + bool add_zero_attn = false, + std::optional initializer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t relu(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t parallel_partition(parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t parallel_combine(parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t parallel_replicate(parallel_tensor_guid_t const &x, + int degree, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t parallel_reduce(parallel_tensor_guid_t const &x, + int degree, + std::optional const &name = std::nullopt); + +private: + parallel_tensor_guid_t as_type(parallel_tensor_guid_t const &, DataType, std::string const &name); +private: + ParallelTensorShape get_shape(parallel_tensor_guid_t const &) const; +private: + std::vector add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + std::vector add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &output); + parallel_tensor_guid_t add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorAttrs const &output); + parallel_tensor_guid_t add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorShape const &output); public: ParallelComputationGraph pcg; }; diff --git a/lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h b/lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h new file mode 100644 index 0000000000..8fc81cee05 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "c31301efeb92e151b04943786aa7bec1" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct parallel_layer_guid_t { + parallel_layer_guid_t() = delete; + parallel_layer_guid_t(::FlexFlow::Node const &raw_graph_node); + + bool operator==(parallel_layer_guid_t const &) const; + bool operator!=(parallel_layer_guid_t const &) const; + bool operator<(parallel_layer_guid_t const &) const; + bool operator>(parallel_layer_guid_t const &) const; + bool operator<=(parallel_layer_guid_t const &) const; + bool operator>=(parallel_layer_guid_t const &) const; + ::FlexFlow::Node raw_graph_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::parallel_layer_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(parallel_layer_guid_t const &); +std::ostream &operator<<(std::ostream &, parallel_layer_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/operator_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml similarity index 86% rename from lib/pcg/include/pcg/operator_guid_t.struct.toml rename to lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml index f89d30137e..63fb25a45b 100644 --- a/lib/pcg/include/pcg/operator_guid_t.struct.toml +++ b/lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "operator_guid_t" +name = "parallel_layer_guid_t" features = [ "eq", "ord", diff --git a/lib/pcg/src/pcg/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph.cc new file mode 100644 index 0000000000..c5557488b8 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph.cc @@ -0,0 +1,20 @@ +#include "pcg/parallel_computation_graph.h" +#include "utils/containers.h" + +namespace FlexFlow { + +ParallelComputationGraph empty_parallel_computation_graph() { + return ParallelComputationGraph{DataflowGraph{}}; +} + +std::unordered_set get_parallel_layers(ParallelComputationGraph const &pcg) { + return transform(get_nodes(pcg.raw_graph), + [&](Node const &n) { return parallel_layer_guid_t{n}; }); +} + +ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + return pcg.raw_graph.at(t.raw_graph_output); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph_builder.cc new file mode 100644 index 0000000000..81f479e8f6 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph_builder.cc @@ -0,0 +1,293 @@ +#include "pcg/parallel_computation_graph_builder.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers.h" + +namespace FlexFlow { + +static std::string get_default_name(OperatorType op_type) { + return get_operator_type_name(op_type); +} + +static std::string get_default_name(PCGOperatorAttrs const &attrs) { + return get_default_name(get_op_type(attrs)); +} + +static ParallelTensorAttrs make_weight_attrs( + ParallelTensorShape const &shape, + std::optional const &initializer_attrs) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/initializer_attrs, + /*create_gradients=*/CreateGrad::YES, + }; +} + + +ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() + : pcg(empty_parallel_computation_graph()) { } + +parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor(ParallelTensorShape const &shape, + bool create_grad, + std::optional const &name) { + ParallelTensorAttrs tensor_attrs = { + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/(create_grad ? CreateGrad::YES : CreateGrad::NO), + }; + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{InputAttrs{}}, + name, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::add(parallel_tensor_guid_t const &lhs, + parallel_tensor_guid_t const &rhs, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::batch_matmul(parallel_tensor_guid_t const &a, + parallel_tensor_guid_t const &b, + /* int a_seq_length_dim = -1, */ + /* int b_seq_length_dim = -1, */ + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::cast(parallel_tensor_guid_t const &input, + DataType result_type, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::conv2d(parallel_tensor_guid_t const &raw_input, + int outChannels, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + std::optional const &activation, + int groups, + bool use_bias, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, + std::optional const &kernel_regularizer, + std::optional const &maybe_name) { + Conv2DAttrs attrs = {outChannels, + kernelH, + kernelW, + strideH, + strideW, + paddingH, + paddingW, + groups, + activation, + use_bias}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + parallel_tensor_guid_t input = + this->as_type(raw_input, DataType::FLOAT, name + "input_pre_cast"); + + ParallelLayerAttrs layer = {PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + + std::vector weights; + + weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), + kernel_initializer)); + + if (use_bias) { + weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), + bias_initializer)); + } + + return this->add_layer(layer, {input}, weights, output_shape); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::dense(parallel_tensor_guid_t const &input, + int outDim, + std::optional activation, + bool use_bias, + DataType data_type, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::embedding( + parallel_tensor_guid_t const &input, + int num_entries, + int outDim, + AggregateOp aggr, + DataType dtype, + std::optional const &kernel_initializer, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::multihead_attention( + parallel_tensor_guid_t const &query, + parallel_tensor_guid_t const &key, + parallel_tensor_guid_t const &value, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool bias, + bool add_bias_kv, + bool add_zero_attn, + std::optional initializer, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::relu(parallel_tensor_guid_t const &input, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::parallel_partition(parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::parallel_combine(parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::parallel_replicate(parallel_tensor_guid_t const &x, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::parallel_reduce(parallel_tensor_guid_t const &x, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::as_type(parallel_tensor_guid_t const &input, DataType goal_datatype, std::string const &name) { + DataType input_datatype = this->get_shape(input).data_type; + if (input_datatype == goal_datatype) { + return input; + } else if (can_strictly_promote_datatype_from_to(input_datatype, goal_datatype)) { + return this->cast(input, goal_datatype, name); + } else { + throw mk_runtime_error( + fmt::format("Could not convert provided tensor data type {} to " + "desired data type {}", + input_datatype, + goal_datatype)); + } +} + +ParallelTensorShape +ParallelComputationGraphBuilder::get_shape(parallel_tensor_guid_t const &t) const { + return get_parallel_tensor_attrs(this->pcg, t).shape; +} + +std::vector +ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + std::vector raw_weight_tensors; + for (auto const &kv : enumerate_vector(weights)) { + int weight_idx = kv.first; + ParallelTensorAttrs weight_tensor_attrs = kv.second; + + std::optional weight_name = + transform(layer.name, [&](std::string const &layer_name) { + return fmt::format("{}.weights[{}]", layer_name, weight_idx); + }); + ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{}}, + weight_name, + }; + std::vector weight_layer_inputs = {}; + std::vector weight_output_attrs = {weight_tensor_attrs}; + raw_weight_tensors.push_back( + get_only(this->pcg.raw_graph.add_operator( + weight_layer_attrs, weight_layer_inputs, weight_output_attrs))); + } + + std::vector raw_inputs = transform( + inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = + this->pcg.raw_graph.add_operator( + layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); + return transform(raw_outputs, + [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); +} + +std::vector +ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + return this->add_layer( + layer, inputs, weights, transform(outputs, [](ParallelTensorShape const &s) { + return ParallelTensorAttrs{ + /*shape=*/s, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + })); +} + + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorAttrs const &output) { + std::vector outputs = {output}; + return get_only(this->add_layer(layer, inputs, weights, outputs)); +} + +parallel_tensor_guid_t +ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorShape const &output) { + std::vector outputs = {output}; + return get_only(this->add_layer(layer, inputs, weights, outputs)); +} + + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_layer_guid_t.dtg.cc similarity index 51% rename from lib/pcg/src/pcg/operator_guid_t.dtg.cc rename to lib/pcg/src/pcg/parallel_layer_guid_t.dtg.cc index 46b031f7e1..876a735b14 100644 --- a/lib/pcg/src/pcg/operator_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_layer_guid_t.dtg.cc @@ -1,43 +1,50 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/operator_guid_t.struct.toml +// lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml /* proj-data { - "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" + "generated_from": "c31301efeb92e151b04943786aa7bec1" } */ -#include "pcg/operator_guid_t.dtg.h" +#include "pcg/parallel_layer_guid_t.dtg.h" #include "utils/graph.h" #include namespace FlexFlow { -operator_guid_t::operator_guid_t(::FlexFlow::Node const &raw_graph_node) +parallel_layer_guid_t::parallel_layer_guid_t( + ::FlexFlow::Node const &raw_graph_node) : raw_graph_node(raw_graph_node) {} -bool operator_guid_t::operator==(operator_guid_t const &other) const { +bool parallel_layer_guid_t::operator==( + parallel_layer_guid_t const &other) const { return std::tie(this->raw_graph_node) == std::tie(other.raw_graph_node); } -bool operator_guid_t::operator!=(operator_guid_t const &other) const { +bool parallel_layer_guid_t::operator!=( + parallel_layer_guid_t const &other) const { return std::tie(this->raw_graph_node) != std::tie(other.raw_graph_node); } -bool operator_guid_t::operator<(operator_guid_t const &other) const { +bool parallel_layer_guid_t::operator<( + parallel_layer_guid_t const &other) const { return std::tie(this->raw_graph_node) < std::tie(other.raw_graph_node); } -bool operator_guid_t::operator>(operator_guid_t const &other) const { +bool parallel_layer_guid_t::operator>( + parallel_layer_guid_t const &other) const { return std::tie(this->raw_graph_node) > std::tie(other.raw_graph_node); } -bool operator_guid_t::operator<=(operator_guid_t const &other) const { +bool parallel_layer_guid_t::operator<=( + parallel_layer_guid_t const &other) const { return std::tie(this->raw_graph_node) <= std::tie(other.raw_graph_node); } -bool operator_guid_t::operator>=(operator_guid_t const &other) const { +bool parallel_layer_guid_t::operator>=( + parallel_layer_guid_t const &other) const { return std::tie(this->raw_graph_node) >= std::tie(other.raw_graph_node); } } // namespace FlexFlow namespace std { -size_t hash::operator()( - FlexFlow::operator_guid_t const &x) const { +size_t hash::operator()( + FlexFlow::parallel_layer_guid_t const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.raw_graph_node) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,14 +53,14 @@ size_t hash::operator()( } // namespace std namespace FlexFlow { -std::string format_as(operator_guid_t const &x) { +std::string format_as(parallel_layer_guid_t const &x) { std::ostringstream oss; - oss << ""; return oss.str(); } -std::ostream &operator<<(std::ostream &s, operator_guid_t const &x) { +std::ostream &operator<<(std::ostream &s, parallel_layer_guid_t const &x) { return s << fmt::to_string(x); } } // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc new file mode 100644 index 0000000000..b11cb504e2 --- /dev/null +++ b/lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc @@ -0,0 +1,32 @@ +#include "test/utils/doctest.h" +#include "pcg/parallel_computation_graph_builder.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ParallelComputationGraphBuilder") { + ParallelComputationGraphBuilder b; + + size_t batch_size = 2; + + TensorShape unpar_input_shape = { + TensorDims{FFOrdered{batch_size, 3, 10, 10}}, + DataType::FLOAT, + }; + + ParallelTensorShape input_shape = lift_to_parallel_with_degrees(unpar_input_shape, SumDegree{1}, DiscardCopyDegree{1}, FFOrdered{2, 1, 1, 1}); + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + + parallel_tensor_guid_t output = b.conv2d(input, + /*outChannels=*/5, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0); + + CHECK(get_parallel_layers(b.pcg).size() == 1); + }; +} From c379efd7c2ebce14f0326cb185b8098ccbbe20b3 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 9 Jun 2024 20:38:22 -0700 Subject: [PATCH 03/71] Add pcg tests, make dtgen constructors explicit to fix bug --- flake.lock | 6 +- .../include/kernels/legion_dim_t.dtg.h | 12 +- lib/kernels/src/kernels/legion_dim_t.dtg.cc | 12 +- lib/op-attrs/include/op-attrs/ff_dim.dtg.h | 12 +- .../op-attrs/l1_regularizer_attrs.dtg.h | 16 +- .../op-attrs/l2_regularizer_attrs.dtg.h | 16 +- .../multihead_attention_inputs.dtg.h | 26 +- .../multihead_attention_parallel_inputs.dtg.h | 16 +- .../op-attrs/ops/attention_attrs.dtg.h | 30 +- .../include/op-attrs/ops/batch_matmul.dtg.h | 17 +- .../op-attrs/ops/batch_norm_attrs.dtg.h | 16 +- .../include/op-attrs/ops/broadcast.dtg.h | 16 +- .../include/op-attrs/ops/cast_attrs.dtg.h | 16 +- .../include/op-attrs/ops/combine_attrs.dtg.h | 18 +- .../include/op-attrs/ops/concat_attrs.dtg.h | 16 +- .../ops/conv_2d/conv_2d_input_shape.dtg.h | 24 +- .../conv_2d_parallel_input_shape.dtg.h | 29 +- .../include/op-attrs/ops/conv_2d_attrs.dtg.h | 34 +- .../include/op-attrs/ops/dropout_attrs.dtg.h | 16 +- .../op-attrs/ops/element_binary_attrs.dtg.h | 22 +- .../ops/element_scalar_unary_attrs.dtg.h | 18 +- .../op-attrs/ops/element_unary_attrs.dtg.h | 16 +- .../op-attrs/ops/embedding_attrs.dtg.h | 22 +- .../include/op-attrs/ops/flat_attrs.dtg.h | 14 +- .../include/op-attrs/ops/gather_attrs.dtg.h | 16 +- .../include/op-attrs/ops/input_attrs.dtg.h | 14 +- .../op-attrs/ops/layer_norm_attrs.dtg.h | 22 +- .../include/op-attrs/ops/linear_attrs.dtg.h | 25 +- .../include/op-attrs/ops/noop_attrs.dtg.h | 14 +- .../ops/parallel_attention_inputs.dtg.h | 16 +- .../include/op-attrs/ops/pool_2d_attrs.dtg.h | 30 +- .../include/op-attrs/ops/reduce_attrs.dtg.h | 22 +- .../op-attrs/ops/reduction_attrs.dtg.h | 16 +- .../op-attrs/ops/repartition_attrs.dtg.h | 18 +- .../op-attrs/ops/replicate_attrs.dtg.h | 16 +- .../include/op-attrs/ops/reshape_attrs.dtg.h | 16 +- .../include/op-attrs/ops/reverse_attrs.dtg.h | 16 +- .../include/op-attrs/ops/softmax_attrs.dtg.h | 16 +- .../include/op-attrs/ops/split_attrs.dtg.h | 19 +- .../include/op-attrs/ops/topk_attrs.dtg.h | 16 +- .../op-attrs/ops/transpose_attrs.dtg.h | 17 +- .../include/op-attrs/ops/weight_attrs.dtg.h | 14 +- .../op-attrs/parallel_tensor_dims.dtg.h | 16 +- .../op-attrs/parallel_tensor_shape.dtg.h | 18 +- .../discard_copy_degree.dtg.h | 16 +- .../parallel_tensor_shape/sum_degree.dtg.h | 16 +- .../op-attrs/replica_parallel_dim.dtg.h | 18 +- .../op-attrs/replica_parallel_dim_set.dtg.h | 16 +- .../include/op-attrs/shard_parallel_dim.dtg.h | 16 +- .../include/op-attrs/tensor_dims.dtg.h | 16 +- .../include/op-attrs/tensor_shape.dtg.h | 18 +- lib/op-attrs/src/op-attrs/datatype.cc | 2 +- lib/op-attrs/src/op-attrs/ff_dim.dtg.cc | 13 +- .../src/op-attrs/l1_regularizer_attrs.dtg.cc | 19 +- .../src/op-attrs/l2_regularizer_attrs.dtg.cc | 19 +- .../multihead_attention_inputs.dtg.cc | 29 +- .../multihead_attention_parallel_inputs.cc | 2 +- ...multihead_attention_parallel_inputs.dtg.cc | 18 +- .../src/op-attrs/ops/attention_attrs.dtg.cc | 33 +- lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 4 +- .../src/op-attrs/ops/batch_matmul.dtg.cc | 23 +- .../src/op-attrs/ops/batch_norm_attrs.dtg.cc | 17 +- .../src/op-attrs/ops/broadcast.dtg.cc | 20 +- .../src/op-attrs/ops/cast_attrs.dtg.cc | 18 +- .../src/op-attrs/ops/combine_attrs.dtg.cc | 19 +- .../src/op-attrs/ops/concat_attrs.dtg.cc | 19 +- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 146 ++++----- .../ops/conv_2d/conv_2d_input_shape.dtg.cc | 27 +- .../conv_2d/conv_2d_parallel_input_shape.cc | 7 +- .../conv_2d_parallel_input_shape.dtg.cc | 31 +- .../src/op-attrs/ops/conv_2d_attrs.dtg.cc | 16 +- .../src/op-attrs/ops/dropout_attrs.dtg.cc | 19 +- .../op-attrs/ops/element_binary_attrs.dtg.cc | 25 +- .../ops/element_scalar_unary_attrs.dtg.cc | 21 +- .../op-attrs/ops/element_unary_attrs.dtg.cc | 19 +- lib/op-attrs/src/op-attrs/ops/embedding.cc | 7 +- .../src/op-attrs/ops/embedding_attrs.dtg.cc | 24 +- .../src/op-attrs/ops/flat_attrs.dtg.cc | 18 +- .../src/op-attrs/ops/gather_attrs.dtg.cc | 17 +- .../src/op-attrs/ops/input_attrs.dtg.cc | 16 +- .../src/op-attrs/ops/layer_norm_attrs.dtg.cc | 17 +- lib/op-attrs/src/op-attrs/ops/linear.cc | 16 +- .../src/op-attrs/ops/linear_attrs.dtg.cc | 16 +- .../src/op-attrs/ops/noop_attrs.dtg.cc | 18 +- .../ops/parallel_attention_inputs.dtg.cc | 23 +- .../src/op-attrs/ops/pool_2d_attrs.dtg.cc | 31 +- .../src/op-attrs/ops/reduce_attrs.dtg.cc | 16 +- .../src/op-attrs/ops/reduction_attrs.dtg.cc | 18 +- .../src/op-attrs/ops/repartition_attrs.dtg.cc | 21 +- .../src/op-attrs/ops/replicate_attrs.dtg.cc | 18 +- .../src/op-attrs/ops/reshape_attrs.dtg.cc | 17 +- .../src/op-attrs/ops/reverse_attrs.dtg.cc | 17 +- .../src/op-attrs/ops/softmax_attrs.dtg.cc | 17 +- .../src/op-attrs/ops/split_attrs.dtg.cc | 21 +- .../src/op-attrs/ops/topk_attrs.dtg.cc | 21 +- .../src/op-attrs/ops/transpose_attrs.dtg.cc | 17 +- .../src/op-attrs/ops/weight_attrs.dtg.cc | 16 +- .../src/op-attrs/parallel_tensor_dims.dtg.cc | 18 +- .../src/op-attrs/parallel_tensor_shape.cc | 2 +- .../src/op-attrs/parallel_tensor_shape.dtg.cc | 21 +- .../discard_copy_degree.dtg.cc | 18 +- .../parallel_tensor_shape/sum_degree.dtg.cc | 18 +- .../src/op-attrs/replica_parallel_dim.dtg.cc | 21 +- .../src/op-attrs/replica_parallel_dim_set.cc | 2 +- .../op-attrs/replica_parallel_dim_set.dtg.cc | 24 +- .../src/op-attrs/shard_parallel_dim.dtg.cc | 22 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 3 +- lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc | 17 +- lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc | 19 +- lib/op-attrs/test/src/datatype.cc | 11 +- .../{test_attention.cc => ops/attention.cc} | 70 ++-- .../batch_matmul.cc} | 115 ++++--- lib/op-attrs/test/src/ops/combine.cc | 6 +- .../src/{test_conv_2d.cc => ops/conv_2d.cc} | 92 +++--- .../element_binary.cc} | 26 +- .../element_unary.cc} | 9 +- .../{test_embedding.cc => ops/embedding.cc} | 2 +- lib/op-attrs/test/src/ops/reduction.cc | 2 +- lib/op-attrs/test/src/ops/repartition.cc | 4 +- lib/op-attrs/test/src/ops/replicate.cc | 4 +- lib/pcg/include/pcg/computation_graph.dtg.h | 6 +- .../include/pcg/computation_graph.struct.toml | 2 +- .../layer_added_result.dtg.h | 8 +- .../layer_added_result.struct.toml | 1 + lib/pcg/include/pcg/cpu_id_t.dtg.h | 16 +- lib/pcg/include/pcg/dataflow_graph.h | 77 ----- .../pcg/dataflow_graph/dataflow_graph.h | 123 +++++++ .../operator_added_result.dtg.h | 43 +++ .../operator_added_result.struct.toml | 22 ++ lib/pcg/include/pcg/file_format/v1/graphs.h | 6 +- .../file_format/v1/graphs/v1_graph_edge.dtg.h | 18 +- .../v1/graphs/v1_graph_output.dtg.h | 12 +- .../v1/graphs/v1_jsonable_graph.dtg.h | 18 +- .../v1/graphs/v1_multidigraph.dtg.h | 18 +- .../v1/graphs/v1_multidigraph.struct.toml | 3 +- .../v1/graphs/v1_operator_graph.dtg.h | 16 +- .../v1/graphs/v1_operator_graph.struct.toml | 3 +- lib/pcg/include/pcg/gpu_id_t.dtg.h | 16 +- .../constant_initializer_attrs.dtg.h | 12 +- .../initializers/glorot_uniform_attrs.dtg.h | 16 +- .../initializers/norm_initializer_attrs.dtg.h | 18 +- .../uniform_initializer_attrs.dtg.h | 16 +- .../initializers/zero_initializer_attrs.dtg.h | 14 +- lib/pcg/include/pcg/layer_attrs.dtg.h | 15 +- lib/pcg/include/pcg/layer_guid_t.dtg.h | 6 +- .../include/pcg/machine_specification.dtg.h | 20 +- lib/pcg/include/pcg/machine_view.dtg.h | 14 +- lib/pcg/include/pcg/num_points_t.dtg.h | 16 +- .../operator_graph/operator_graph_input.dtg.h | 6 +- .../operator_graph_output.dtg.h | 6 +- .../pcg/optimizers/adam_optimizer_attrs.dtg.h | 28 +- .../pcg/optimizers/sgd_optimizer_attrs.dtg.h | 22 +- .../include/pcg/parallel_computation_graph.h | 18 -- .../parallel_computation_graph.dtg.h | 18 +- .../parallel_computation_graph.h | 32 ++ .../parallel_computation_graph.struct.toml | 6 +- .../parallel_computation_graph_builder.h | 144 +++++++++ .../parallel_layer_attrs.dtg.h | 26 +- .../parallel_layer_attrs.h | 12 + .../parallel_layer_attrs.struct.toml | 2 +- .../parallel_layer_guid_t.dtg.h | 14 +- .../parallel_layer_guid_t.struct.toml | 0 .../parallel_tensor_attrs.dtg.h | 20 +- .../parallel_tensor_attrs.struct.toml | 0 .../parallel_tensor_guid_t.dtg.h | 15 +- .../parallel_tensor_guid_t.struct.toml | 0 .../pcg/parallel_computation_graph_builder.h | 127 -------- lib/pcg/include/pcg/parallel_tensor.h | 32 -- lib/pcg/include/pcg/side_size_t.dtg.h | 16 +- lib/pcg/include/pcg/strided_rectangle.dtg.h | 16 +- .../include/pcg/strided_rectangle_side.dtg.h | 18 +- lib/pcg/include/pcg/tensor_attrs.dtg.h | 19 +- lib/pcg/include/pcg/tensor_guid_t.dtg.h | 6 +- lib/pcg/src/file_format/v1/graphs.cc | 16 +- lib/pcg/src/pcg/computation_graph.dtg.cc | 4 +- .../layer_added_result.dtg.cc | 3 +- lib/pcg/src/pcg/computation_graph_builder.cc | 64 ++-- lib/pcg/src/pcg/cpu_id_t.dtg.cc | 17 +- .../operator_added_result.dtg.cc | 60 ++++ .../v1/graphs/v1_graph_edge.dtg.cc | 18 +- .../v1/graphs/v1_graph_output.dtg.cc | 14 +- .../v1/graphs/v1_multidigraph.dtg.cc | 22 +- .../v1/graphs/v1_operator_graph.dtg.cc | 20 +- lib/pcg/src/pcg/gpu_id_t.dtg.cc | 17 +- .../constant_initializer_attrs.dtg.cc | 13 +- .../initializers/glorot_uniform_attrs.dtg.cc | 18 +- .../norm_initializer_attrs.dtg.cc | 22 +- .../uniform_initializer_attrs.dtg.cc | 17 +- .../zero_initializer_attrs.dtg.cc | 18 +- lib/pcg/src/pcg/layer_attrs.dtg.cc | 12 +- lib/pcg/src/pcg/layer_guid_t.dtg.cc | 2 +- lib/pcg/src/pcg/machine_specification.dtg.cc | 21 +- lib/pcg/src/pcg/machine_view.cc | 7 +- lib/pcg/src/pcg/machine_view.dtg.cc | 15 +- lib/pcg/src/pcg/num_points_t.dtg.cc | 16 +- .../operator_graph_input.dtg.cc | 2 +- .../operator_graph_output.dtg.cc | 2 +- .../optimizers/adam_optimizer_attrs.dtg.cc | 44 +-- .../pcg/optimizers/sgd_optimizer_attrs.dtg.cc | 32 +- lib/pcg/src/pcg/parallel_computation_graph.cc | 20 -- .../parallel_computation_graph.cc | 49 +++ .../parallel_computation_graph.dtg.cc | 12 +- .../parallel_computation_graph_builder.cc | 302 ++++++++++++++++++ .../parallel_layer_attrs.cc | 10 + .../parallel_layer_attrs.dtg.cc | 48 +-- .../parallel_layer_guid_t.dtg.cc | 6 +- .../parallel_tensor_attrs.dtg.cc | 16 +- .../parallel_tensor_guid_t.dtg.cc | 6 +- .../pcg/parallel_computation_graph_builder.cc | 293 ----------------- lib/pcg/src/pcg/side_size_t.dtg.cc | 16 +- lib/pcg/src/pcg/strided_rectangle.dtg.cc | 23 +- lib/pcg/src/pcg/strided_rectangle_side.cc | 2 +- lib/pcg/src/pcg/strided_rectangle_side.dtg.cc | 21 +- lib/pcg/src/pcg/tensor_attrs.dtg.cc | 12 +- lib/pcg/src/pcg/tensor_guid_t.dtg.cc | 2 +- lib/pcg/test/src/pcg/dataflow_graph.cc | 48 +++ .../parallel_computation_graph_builder.cc | 125 ++++++++ .../pcg/parallel_computation_graph_builder.cc | 32 -- .../src/test_computation_graph_builder.cc | 2 +- .../operator_attribute_constraint.dtg.h | 12 +- .../operator_attribute_list_access.dtg.h | 16 +- .../operator_attribute_list_size.dtg.h | 16 +- .../operator_attribute_pattern.dtg.h | 12 +- .../output_graph/attr_constant.dtg.h | 6 +- .../output_graph/output_graph_expr.dtg.h | 5 +- .../output_operator_attr_access.dtg.h | 9 +- .../output_operator_attrs_assignment.dtg.h | 6 +- .../include/substitutions/pcg_pattern.dtg.h | 6 +- .../sub_parallel_computation_graph.dtg.h | 2 +- .../include/substitutions/substitution.dtg.h | 16 +- .../tensor_attribute_constraint.dtg.h | 12 +- .../tensor_attribute_list_access.dtg.h | 17 +- .../tensor_attribute_list_size.dtg.h | 17 +- .../tensor_attribute_pattern.dtg.h | 12 +- .../unlabelled/closed_pattern_edge.dtg.h | 6 +- .../downward_open_pattern_edge.dtg.h | 7 +- .../unlabelled/edge_splits.dtg.h | 2 +- .../unlabelled/input_pattern_edge.dtg.h | 6 +- .../match_additional_criterion.dtg.h | 2 +- .../unlabelled/match_split.dtg.h | 4 +- .../multidigraph_pattern_match.dtg.h | 2 +- .../unlabelled/output_pattern_edge.dtg.h | 6 +- .../unlabelled/pattern_edge.dtg.h | 6 +- .../unlabelled/pattern_node.dtg.h | 6 +- .../unlabelled/pattern_split.dtg.h | 11 +- .../unlabelled/unlabelled_graph_pattern.dtg.h | 3 +- .../unlabelled/upward_open_pattern_edge.dtg.h | 7 +- .../operator_attribute_constraint.dtg.cc | 12 +- .../operator_attribute_list_access.dtg.cc | 18 +- .../operator_attribute_list_size.dtg.cc | 18 +- .../operator_attribute_pattern.dtg.cc | 12 +- .../output_graph/attr_constant.dtg.cc | 2 +- .../output_operator_attr_access.dtg.cc | 2 +- .../output_operator_attrs_assignment.dtg.cc | 2 +- .../tensor_attribute_constraint.dtg.cc | 12 +- .../tensor_attribute_list_access.dtg.cc | 21 +- .../tensor_attribute_list_size.dtg.cc | 19 +- .../tensor_attribute_pattern.dtg.cc | 18 +- .../unlabelled/closed_pattern_edge.dtg.cc | 2 +- .../downward_open_pattern_edge.dtg.cc | 2 +- .../unlabelled/input_pattern_edge.dtg.cc | 2 +- .../unlabelled/output_pattern_edge.dtg.cc | 2 +- .../unlabelled/pattern_edge.dtg.cc | 2 +- .../unlabelled/pattern_node.dtg.cc | 2 +- .../unlabelled/pattern_split.dtg.cc | 10 +- .../upward_open_pattern_edge.dtg.cc | 2 +- .../utils/containers/without_nullopts.h | 22 ++ lib/utils/include/utils/fmt.decl.h | 25 -- lib/utils/include/utils/fmt.h | 41 --- lib/utils/include/utils/fmt/unordered_set.h | 41 +++ lib/utils/include/utils/fmt/vector.h | 42 +++ lib/utils/include/utils/graph/multidiedge.h | 1 + lib/utils/include/utils/stack_vector.h | 9 +- 273 files changed, 3115 insertions(+), 2564 deletions(-) rename lib/op-attrs/test/src/{test_attention.cc => ops/attention.cc} (78%) rename lib/op-attrs/test/src/{test_batch_matmul.cc => ops/batch_matmul.cc} (59%) rename lib/op-attrs/test/src/{test_conv_2d.cc => ops/conv_2d.cc} (67%) rename lib/op-attrs/test/src/{test_element_binary.cc => ops/element_binary.cc} (82%) rename lib/op-attrs/test/src/{test_element_unary.cc => ops/element_unary.cc} (83%) rename lib/op-attrs/test/src/{test_embedding.cc => ops/embedding.cc} (99%) delete mode 100644 lib/pcg/include/pcg/dataflow_graph.h create mode 100644 lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h create mode 100644 lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h create mode 100644 lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml delete mode 100644 lib/pcg/include/pcg/parallel_computation_graph.h rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_computation_graph.dtg.h (61%) create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_computation_graph.struct.toml (56%) create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_layer_attrs.dtg.h (58%) create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.h rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_layer_attrs.struct.toml (95%) rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_layer_guid_t.dtg.h (64%) rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_layer_guid_t.struct.toml (100%) rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_tensor_attrs.dtg.h (69%) rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_tensor_attrs.struct.toml (100%) rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_tensor_guid_t.dtg.h (64%) rename lib/pcg/include/pcg/{ => parallel_computation_graph}/parallel_tensor_guid_t.struct.toml (100%) delete mode 100644 lib/pcg/include/pcg/parallel_computation_graph_builder.h delete mode 100644 lib/pcg/include/pcg/parallel_tensor.h create mode 100644 lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc delete mode 100644 lib/pcg/src/pcg/parallel_computation_graph.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc rename lib/pcg/src/pcg/{ => parallel_computation_graph}/parallel_computation_graph.dtg.cc (50%) create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc rename lib/pcg/src/pcg/{ => parallel_computation_graph}/parallel_layer_attrs.dtg.cc (55%) rename lib/pcg/src/pcg/{ => parallel_computation_graph}/parallel_layer_guid_t.dtg.cc (90%) rename lib/pcg/src/pcg/{ => parallel_computation_graph}/parallel_tensor_attrs.dtg.cc (91%) rename lib/pcg/src/pcg/{ => parallel_computation_graph}/parallel_tensor_guid_t.dtg.cc (90%) delete mode 100644 lib/pcg/src/pcg/parallel_computation_graph_builder.cc create mode 100644 lib/pcg/test/src/pcg/dataflow_graph.cc create mode 100644 lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc delete mode 100644 lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc create mode 100644 lib/utils/include/utils/containers/without_nullopts.h create mode 100644 lib/utils/include/utils/fmt/unordered_set.h create mode 100644 lib/utils/include/utils/fmt/vector.h diff --git a/flake.lock b/flake.lock index f0fc292a5e..dde0c989c3 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1717449667, - "narHash": "sha256-xFGnB44WadxlCa2LnlH82g1c89+7UAomVgytIewSwO0=", + "lastModified": 1717990636, + "narHash": "sha256-wqIc2qAkRfVp2d+NAVIYPKMx7YYpu8iBGHHT1U5sxhE=", "owner": "lockshaw", "repo": "proj", - "rev": "28b37a9bd993d3de3d80695eb3834a0436c805a4", + "rev": "f7e20a9c232dda1b945a775d91e1ed4f525b5f51", "type": "github" }, "original": { diff --git a/lib/kernels/include/kernels/legion_dim_t.dtg.h b/lib/kernels/include/kernels/legion_dim_t.dtg.h index 622f9c240a..3dbdfb55d8 100644 --- a/lib/kernels/include/kernels/legion_dim_t.dtg.h +++ b/lib/kernels/include/kernels/legion_dim_t.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct legion_dim_t { legion_dim_t() = delete; - legion_dim_t(int const &value); + explicit legion_dim_t(int const &value); bool operator==(legion_dim_t const &) const; bool operator!=(legion_dim_t const &) const; @@ -33,16 +33,16 @@ struct legion_dim_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::legion_dim_t const &) const; +struct hash<::FlexFlow::legion_dim_t> { + size_t operator()(::FlexFlow::legion_dim_t const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::legion_dim_t from_json(json const &); - static void to_json(json &, FlexFlow::legion_dim_t const &); +struct adl_serializer<::FlexFlow::legion_dim_t> { + static ::FlexFlow::legion_dim_t from_json(json const &); + static void to_json(json &, ::FlexFlow::legion_dim_t const &); }; } // namespace nlohmann diff --git a/lib/kernels/src/kernels/legion_dim_t.dtg.cc b/lib/kernels/src/kernels/legion_dim_t.dtg.cc index 99c1a3b3a2..bb85e4b9dd 100644 --- a/lib/kernels/src/kernels/legion_dim_t.dtg.cc +++ b/lib/kernels/src/kernels/legion_dim_t.dtg.cc @@ -35,7 +35,7 @@ bool legion_dim_t::operator>=(legion_dim_t const &other) const { namespace std { size_t hash::operator()( - FlexFlow::legion_dim_t const &x) const { + ::FlexFlow::legion_dim_t const &x) const { size_t result = 0; result ^= std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,12 +44,12 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::legion_dim_t - adl_serializer::from_json(json const &j) { - return {j.at("value").template get()}; +::FlexFlow::legion_dim_t + adl_serializer<::FlexFlow::legion_dim_t>::from_json(json const &j) { + return ::FlexFlow::legion_dim_t{j.at("value").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::legion_dim_t const &v) { +void adl_serializer<::FlexFlow::legion_dim_t>::to_json( + json &j, ::FlexFlow::legion_dim_t const &v) { j["__type"] = "legion_dim_t"; j["value"] = v.value; } diff --git a/lib/op-attrs/include/op-attrs/ff_dim.dtg.h b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h index 1697f78196..f7df8f414b 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.dtg.h +++ b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct ff_dim_t { ff_dim_t() = delete; - ff_dim_t(int const &value); + explicit ff_dim_t(int const &value); bool operator==(ff_dim_t const &) const; bool operator!=(ff_dim_t const &) const; @@ -33,16 +33,16 @@ struct ff_dim_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ff_dim_t const &) const; +struct hash<::FlexFlow::ff_dim_t> { + size_t operator()(::FlexFlow::ff_dim_t const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ff_dim_t from_json(json const &); - static void to_json(json &, FlexFlow::ff_dim_t const &); +struct adl_serializer<::FlexFlow::ff_dim_t> { + static ::FlexFlow::ff_dim_t from_json(json const &); + static void to_json(json &, ::FlexFlow::ff_dim_t const &); }; } // namespace nlohmann diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h index 1d4747db7e..9981219ca4 100644 --- a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct L1RegularizerAttrs { L1RegularizerAttrs() = delete; - L1RegularizerAttrs(float const &lambda); + explicit L1RegularizerAttrs(float const &lambda); bool operator==(L1RegularizerAttrs const &) const; bool operator!=(L1RegularizerAttrs const &) const; @@ -34,23 +34,23 @@ struct L1RegularizerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::L1RegularizerAttrs const &) const; +struct hash<::FlexFlow::L1RegularizerAttrs> { + size_t operator()(::FlexFlow::L1RegularizerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::L1RegularizerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::L1RegularizerAttrs const &); +struct adl_serializer<::FlexFlow::L1RegularizerAttrs> { + static ::FlexFlow::L1RegularizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::L1RegularizerAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::L1RegularizerAttrs> { + static Gen<::FlexFlow::L1RegularizerAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h index 981d3f4905..cd26069de1 100644 --- a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct L2RegularizerAttrs { L2RegularizerAttrs() = delete; - L2RegularizerAttrs(float const &lambda); + explicit L2RegularizerAttrs(float const &lambda); bool operator==(L2RegularizerAttrs const &) const; bool operator!=(L2RegularizerAttrs const &) const; @@ -34,23 +34,23 @@ struct L2RegularizerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::L2RegularizerAttrs const &) const; +struct hash<::FlexFlow::L2RegularizerAttrs> { + size_t operator()(::FlexFlow::L2RegularizerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::L2RegularizerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::L2RegularizerAttrs const &); +struct adl_serializer<::FlexFlow::L2RegularizerAttrs> { + static ::FlexFlow::L2RegularizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::L2RegularizerAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::L2RegularizerAttrs> { + static Gen<::FlexFlow::L2RegularizerAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h index 7b61305a1a..815ca5edea 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h @@ -22,12 +22,12 @@ namespace FlexFlow { struct MultiHeadAttentionInputs { MultiHeadAttentionInputs() = delete; - MultiHeadAttentionInputs(size_t const &batch_size, - size_t const &sequence_length, - size_t const &query_size, - size_t const &key_size, - size_t const &value_size, - ::FlexFlow::DataType const &datatype); + explicit MultiHeadAttentionInputs(size_t const &batch_size, + size_t const &sequence_length, + size_t const &query_size, + size_t const &key_size, + size_t const &value_size, + ::FlexFlow::DataType const &datatype); bool operator==(MultiHeadAttentionInputs const &) const; bool operator!=(MultiHeadAttentionInputs const &) const; @@ -46,23 +46,23 @@ struct MultiHeadAttentionInputs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::MultiHeadAttentionInputs const &) const; +struct hash<::FlexFlow::MultiHeadAttentionInputs> { + size_t operator()(::FlexFlow::MultiHeadAttentionInputs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::MultiHeadAttentionInputs from_json(json const &); - static void to_json(json &, FlexFlow::MultiHeadAttentionInputs const &); +struct adl_serializer<::FlexFlow::MultiHeadAttentionInputs> { + static ::FlexFlow::MultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, ::FlexFlow::MultiHeadAttentionInputs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::MultiHeadAttentionInputs> { + static Gen<::FlexFlow::MultiHeadAttentionInputs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h index 297b1f8f1c..fa7c83a881 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h @@ -25,7 +25,7 @@ namespace FlexFlow { struct MultiHeadAttentionParallelInputs { MultiHeadAttentionParallelInputs() = delete; - MultiHeadAttentionParallelInputs( + explicit MultiHeadAttentionParallelInputs( ::FlexFlow::ShardParallelDim const &batch_dim, ::FlexFlow::ShardParallelDim const &sequence_dim, ::FlexFlow::ShardParallelDim const &query_dim, @@ -52,24 +52,24 @@ struct MultiHeadAttentionParallelInputs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::MultiHeadAttentionParallelInputs const &) const; +struct hash<::FlexFlow::MultiHeadAttentionParallelInputs> { + size_t operator()(::FlexFlow::MultiHeadAttentionParallelInputs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::MultiHeadAttentionParallelInputs from_json(json const &); +struct adl_serializer<::FlexFlow::MultiHeadAttentionParallelInputs> { + static ::FlexFlow::MultiHeadAttentionParallelInputs from_json(json const &); static void to_json(json &, - FlexFlow::MultiHeadAttentionParallelInputs const &); + ::FlexFlow::MultiHeadAttentionParallelInputs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::MultiHeadAttentionParallelInputs> { + static Gen<::FlexFlow::MultiHeadAttentionParallelInputs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h index 18b2906759..8eef2df2eb 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h @@ -20,14 +20,14 @@ namespace FlexFlow { struct MultiHeadAttentionAttrs { MultiHeadAttentionAttrs() = delete; - MultiHeadAttentionAttrs(int const &embed_dim, - int const &num_heads, - int const &kdim, - int const &vdim, - float const &dropout, - bool const &bias, - bool const &add_bias_kv, - bool const &add_zero_attn); + explicit MultiHeadAttentionAttrs(int const &embed_dim, + int const &num_heads, + int const &kdim, + int const &vdim, + float const &dropout, + bool const &bias, + bool const &add_bias_kv, + bool const &add_zero_attn); bool operator==(MultiHeadAttentionAttrs const &) const; bool operator!=(MultiHeadAttentionAttrs const &) const; @@ -48,23 +48,23 @@ struct MultiHeadAttentionAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::MultiHeadAttentionAttrs const &) const; +struct hash<::FlexFlow::MultiHeadAttentionAttrs> { + size_t operator()(::FlexFlow::MultiHeadAttentionAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::MultiHeadAttentionAttrs from_json(json const &); - static void to_json(json &, FlexFlow::MultiHeadAttentionAttrs const &); +struct adl_serializer<::FlexFlow::MultiHeadAttentionAttrs> { + static ::FlexFlow::MultiHeadAttentionAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::MultiHeadAttentionAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::MultiHeadAttentionAttrs> { + static Gen<::FlexFlow::MultiHeadAttentionAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h index a8ab52d2b3..64c4dd9ae3 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h @@ -20,7 +20,8 @@ namespace FlexFlow { struct BatchMatmulAttrs { BatchMatmulAttrs() = delete; - BatchMatmulAttrs(int const &a_seq_length_dim, int const &b_seq_length_dim); + explicit BatchMatmulAttrs(int const &a_seq_length_dim, + int const &b_seq_length_dim); bool operator==(BatchMatmulAttrs const &) const; bool operator!=(BatchMatmulAttrs const &) const; @@ -35,23 +36,23 @@ struct BatchMatmulAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::BatchMatmulAttrs const &) const; +struct hash<::FlexFlow::BatchMatmulAttrs> { + size_t operator()(::FlexFlow::BatchMatmulAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::BatchMatmulAttrs from_json(json const &); - static void to_json(json &, FlexFlow::BatchMatmulAttrs const &); +struct adl_serializer<::FlexFlow::BatchMatmulAttrs> { + static ::FlexFlow::BatchMatmulAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::BatchMatmulAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::BatchMatmulAttrs> { + static Gen<::FlexFlow::BatchMatmulAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h index f153bfde7e..a7d29d565c 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct BatchNormAttrs { BatchNormAttrs() = delete; - BatchNormAttrs(bool const &relu); + explicit BatchNormAttrs(bool const &relu); bool operator==(BatchNormAttrs const &) const; bool operator!=(BatchNormAttrs const &) const; @@ -34,23 +34,23 @@ struct BatchNormAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::BatchNormAttrs const &) const; +struct hash<::FlexFlow::BatchNormAttrs> { + size_t operator()(::FlexFlow::BatchNormAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::BatchNormAttrs from_json(json const &); - static void to_json(json &, FlexFlow::BatchNormAttrs const &); +struct adl_serializer<::FlexFlow::BatchNormAttrs> { + static ::FlexFlow::BatchNormAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::BatchNormAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::BatchNormAttrs> { + static Gen<::FlexFlow::BatchNormAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h index e4de3dcc75..baff0fdad5 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct BroadcastAttrs { BroadcastAttrs() = delete; - BroadcastAttrs( + explicit BroadcastAttrs( ::FlexFlow::stack_vector const &target_dims); bool operator==(BroadcastAttrs const &) const; @@ -36,23 +36,23 @@ struct BroadcastAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::BroadcastAttrs const &) const; +struct hash<::FlexFlow::BroadcastAttrs> { + size_t operator()(::FlexFlow::BroadcastAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::BroadcastAttrs from_json(json const &); - static void to_json(json &, FlexFlow::BroadcastAttrs const &); +struct adl_serializer<::FlexFlow::BroadcastAttrs> { + static ::FlexFlow::BroadcastAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::BroadcastAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::BroadcastAttrs> { + static Gen<::FlexFlow::BroadcastAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h index 33391eb221..0cfb1c2161 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct CastAttrs { CastAttrs() = delete; - CastAttrs(DataType const &dtype); + explicit CastAttrs(DataType const &dtype); bool operator==(CastAttrs const &) const; bool operator!=(CastAttrs const &) const; @@ -35,23 +35,23 @@ struct CastAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::CastAttrs const &) const; +struct hash<::FlexFlow::CastAttrs> { + size_t operator()(::FlexFlow::CastAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::CastAttrs from_json(json const &); - static void to_json(json &, FlexFlow::CastAttrs const &); +struct adl_serializer<::FlexFlow::CastAttrs> { + static ::FlexFlow::CastAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::CastAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::CastAttrs> { + static Gen<::FlexFlow::CastAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h index 43db204bc5..a9f2385fed 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h @@ -22,8 +22,8 @@ namespace FlexFlow { struct CombineAttrs { CombineAttrs() = delete; - CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, - int const &combine_degree); + explicit CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree); bool operator==(CombineAttrs const &) const; bool operator!=(CombineAttrs const &) const; @@ -38,23 +38,23 @@ struct CombineAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::CombineAttrs const &) const; +struct hash<::FlexFlow::CombineAttrs> { + size_t operator()(::FlexFlow::CombineAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::CombineAttrs from_json(json const &); - static void to_json(json &, FlexFlow::CombineAttrs const &); +struct adl_serializer<::FlexFlow::CombineAttrs> { + static ::FlexFlow::CombineAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::CombineAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::CombineAttrs> { + static Gen<::FlexFlow::CombineAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h index 3c26473a4e..435cc08f90 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct ConcatAttrs { ConcatAttrs() = delete; - ConcatAttrs(::FlexFlow::ff_dim_t const &axis, int const &num_inputs); + explicit ConcatAttrs(::FlexFlow::ff_dim_t const &axis, int const &num_inputs); bool operator==(ConcatAttrs const &) const; bool operator!=(ConcatAttrs const &) const; @@ -37,23 +37,23 @@ struct ConcatAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ConcatAttrs const &) const; +struct hash<::FlexFlow::ConcatAttrs> { + size_t operator()(::FlexFlow::ConcatAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ConcatAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ConcatAttrs const &); +struct adl_serializer<::FlexFlow::ConcatAttrs> { + static ::FlexFlow::ConcatAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ConcatAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ConcatAttrs> { + static Gen<::FlexFlow::ConcatAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h index 2e7833064c..353213e33f 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h @@ -22,11 +22,11 @@ namespace FlexFlow { struct Conv2DInputShape { Conv2DInputShape() = delete; - Conv2DInputShape(size_t const &num_samples, - size_t const &num_channels, - size_t const &height, - size_t const &width, - ::FlexFlow::DataType const &datatype); + explicit Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype); bool operator==(Conv2DInputShape const &) const; bool operator!=(Conv2DInputShape const &) const; @@ -44,23 +44,23 @@ struct Conv2DInputShape { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::Conv2DInputShape const &) const; +struct hash<::FlexFlow::Conv2DInputShape> { + size_t operator()(::FlexFlow::Conv2DInputShape const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::Conv2DInputShape from_json(json const &); - static void to_json(json &, FlexFlow::Conv2DInputShape const &); +struct adl_serializer<::FlexFlow::Conv2DInputShape> { + static ::FlexFlow::Conv2DInputShape from_json(json const &); + static void to_json(json &, ::FlexFlow::Conv2DInputShape const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::Conv2DInputShape> { + static Gen<::FlexFlow::Conv2DInputShape> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h index 846c9e413a..0b02d74a4b 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h @@ -22,13 +22,14 @@ namespace FlexFlow { struct Conv2DParallelInputShape { Conv2DParallelInputShape() = delete; - Conv2DParallelInputShape(::FlexFlow::ShardParallelDim const &sample_dim, - ::FlexFlow::ShardParallelDim const &channel_dim, - ::FlexFlow::ShardParallelDim const &height_dim, - ::FlexFlow::ShardParallelDim const &width_dim, - int const &sum_reduction_degree, - int const &discard_copy_reduction_degree, - ::FlexFlow::DataType const &datatype); + explicit Conv2DParallelInputShape( + ::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype); bool operator==(Conv2DParallelInputShape const &) const; bool operator!=(Conv2DParallelInputShape const &) const; @@ -48,23 +49,23 @@ struct Conv2DParallelInputShape { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::Conv2DParallelInputShape const &) const; +struct hash<::FlexFlow::Conv2DParallelInputShape> { + size_t operator()(::FlexFlow::Conv2DParallelInputShape const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::Conv2DParallelInputShape from_json(json const &); - static void to_json(json &, FlexFlow::Conv2DParallelInputShape const &); +struct adl_serializer<::FlexFlow::Conv2DParallelInputShape> { + static ::FlexFlow::Conv2DParallelInputShape from_json(json const &); + static void to_json(json &, ::FlexFlow::Conv2DParallelInputShape const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::Conv2DParallelInputShape> { + static Gen<::FlexFlow::Conv2DParallelInputShape> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h index 06827656da..0602a6eb92 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h @@ -23,16 +23,16 @@ namespace FlexFlow { struct Conv2DAttrs { Conv2DAttrs() = delete; - Conv2DAttrs(int const &out_channels, - int const &kernel_h, - int const &kernel_w, - int const &stride_h, - int const &stride_w, - int const &padding_h, - int const &padding_w, - int const &groups, - std::optional<::FlexFlow::Activation> const &activation, - bool const &use_bias); + explicit Conv2DAttrs(int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias); bool operator==(Conv2DAttrs const &) const; bool operator!=(Conv2DAttrs const &) const; @@ -55,23 +55,23 @@ struct Conv2DAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::Conv2DAttrs const &) const; +struct hash<::FlexFlow::Conv2DAttrs> { + size_t operator()(::FlexFlow::Conv2DAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::Conv2DAttrs from_json(json const &); - static void to_json(json &, FlexFlow::Conv2DAttrs const &); +struct adl_serializer<::FlexFlow::Conv2DAttrs> { + static ::FlexFlow::Conv2DAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::Conv2DAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::Conv2DAttrs> { + static Gen<::FlexFlow::Conv2DAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h index ef86e49560..433e2c8aa7 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct DropoutAttrs { DropoutAttrs() = delete; - DropoutAttrs(float const &rate, unsigned long long const &seed); + explicit DropoutAttrs(float const &rate, unsigned long long const &seed); bool operator==(DropoutAttrs const &) const; bool operator!=(DropoutAttrs const &) const; @@ -35,23 +35,23 @@ struct DropoutAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::DropoutAttrs const &) const; +struct hash<::FlexFlow::DropoutAttrs> { + size_t operator()(::FlexFlow::DropoutAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::DropoutAttrs from_json(json const &); - static void to_json(json &, FlexFlow::DropoutAttrs const &); +struct adl_serializer<::FlexFlow::DropoutAttrs> { + static ::FlexFlow::DropoutAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::DropoutAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::DropoutAttrs> { + static Gen<::FlexFlow::DropoutAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h index 10d93c87d3..c4049f9c8d 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h @@ -22,10 +22,10 @@ namespace FlexFlow { struct ElementBinaryAttrs { ElementBinaryAttrs() = delete; - ElementBinaryAttrs(::FlexFlow::OperatorType const &type, - ::FlexFlow::DataType const &compute_type, - bool const &should_broadcast_lhs, - bool const &should_broadcast_rhs); + explicit ElementBinaryAttrs(::FlexFlow::OperatorType const &type, + ::FlexFlow::DataType const &compute_type, + bool const &should_broadcast_lhs, + bool const &should_broadcast_rhs); bool operator==(ElementBinaryAttrs const &) const; bool operator!=(ElementBinaryAttrs const &) const; @@ -42,23 +42,23 @@ struct ElementBinaryAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ElementBinaryAttrs const &) const; +struct hash<::FlexFlow::ElementBinaryAttrs> { + size_t operator()(::FlexFlow::ElementBinaryAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ElementBinaryAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ElementBinaryAttrs const &); +struct adl_serializer<::FlexFlow::ElementBinaryAttrs> { + static ::FlexFlow::ElementBinaryAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ElementBinaryAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ElementBinaryAttrs> { + static Gen<::FlexFlow::ElementBinaryAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h index a9fe63ca71..b05185e3d3 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h @@ -21,8 +21,8 @@ namespace FlexFlow { struct ElementScalarUnaryAttrs { ElementScalarUnaryAttrs() = delete; - ElementScalarUnaryAttrs(::FlexFlow::OperatorType const &op_type, - float const &scalar); + explicit ElementScalarUnaryAttrs(::FlexFlow::OperatorType const &op_type, + float const &scalar); bool operator==(ElementScalarUnaryAttrs const &) const; bool operator!=(ElementScalarUnaryAttrs const &) const; @@ -37,23 +37,23 @@ struct ElementScalarUnaryAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ElementScalarUnaryAttrs const &) const; +struct hash<::FlexFlow::ElementScalarUnaryAttrs> { + size_t operator()(::FlexFlow::ElementScalarUnaryAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ElementScalarUnaryAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ElementScalarUnaryAttrs const &); +struct adl_serializer<::FlexFlow::ElementScalarUnaryAttrs> { + static ::FlexFlow::ElementScalarUnaryAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ElementScalarUnaryAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ElementScalarUnaryAttrs> { + static Gen<::FlexFlow::ElementScalarUnaryAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h index 3220234bd1..87b8940706 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct ElementUnaryAttrs { ElementUnaryAttrs() = delete; - ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type); + explicit ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type); bool operator==(ElementUnaryAttrs const &) const; bool operator!=(ElementUnaryAttrs const &) const; @@ -35,23 +35,23 @@ struct ElementUnaryAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ElementUnaryAttrs const &) const; +struct hash<::FlexFlow::ElementUnaryAttrs> { + size_t operator()(::FlexFlow::ElementUnaryAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ElementUnaryAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ElementUnaryAttrs const &); +struct adl_serializer<::FlexFlow::ElementUnaryAttrs> { + static ::FlexFlow::ElementUnaryAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ElementUnaryAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ElementUnaryAttrs> { + static Gen<::FlexFlow::ElementUnaryAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h index f1cae86460..7b1eb8d2f7 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h @@ -23,10 +23,10 @@ namespace FlexFlow { struct EmbeddingAttrs { EmbeddingAttrs() = delete; - EmbeddingAttrs(int const &num_entries, - int const &out_channels, - std::optional<::FlexFlow::AggregateOp> const &aggr, - ::FlexFlow::DataType const &data_type); + explicit EmbeddingAttrs(int const &num_entries, + int const &out_channels, + std::optional<::FlexFlow::AggregateOp> const &aggr, + ::FlexFlow::DataType const &data_type); bool operator==(EmbeddingAttrs const &) const; bool operator!=(EmbeddingAttrs const &) const; @@ -43,23 +43,23 @@ struct EmbeddingAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::EmbeddingAttrs const &) const; +struct hash<::FlexFlow::EmbeddingAttrs> { + size_t operator()(::FlexFlow::EmbeddingAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::EmbeddingAttrs from_json(json const &); - static void to_json(json &, FlexFlow::EmbeddingAttrs const &); +struct adl_serializer<::FlexFlow::EmbeddingAttrs> { + static ::FlexFlow::EmbeddingAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::EmbeddingAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::EmbeddingAttrs> { + static Gen<::FlexFlow::EmbeddingAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h index a94c0aeff3..a8b74af565 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h @@ -30,23 +30,23 @@ struct FlatAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::FlatAttrs const &) const; +struct hash<::FlexFlow::FlatAttrs> { + size_t operator()(::FlexFlow::FlatAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::FlatAttrs from_json(json const &); - static void to_json(json &, FlexFlow::FlatAttrs const &); +struct adl_serializer<::FlexFlow::FlatAttrs> { + static ::FlexFlow::FlatAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::FlatAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::FlatAttrs> { + static Gen<::FlexFlow::FlatAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h index e7a35e5800..84835bc850 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct GatherAttrs { GatherAttrs() = delete; - GatherAttrs(::FlexFlow::ff_dim_t const &dim); + explicit GatherAttrs(::FlexFlow::ff_dim_t const &dim); bool operator==(GatherAttrs const &) const; bool operator!=(GatherAttrs const &) const; @@ -36,23 +36,23 @@ struct GatherAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::GatherAttrs const &) const; +struct hash<::FlexFlow::GatherAttrs> { + size_t operator()(::FlexFlow::GatherAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::GatherAttrs from_json(json const &); - static void to_json(json &, FlexFlow::GatherAttrs const &); +struct adl_serializer<::FlexFlow::GatherAttrs> { + static ::FlexFlow::GatherAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::GatherAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::GatherAttrs> { + static Gen<::FlexFlow::GatherAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h index aa2ca1e933..729b47dedc 100644 --- a/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h @@ -30,23 +30,23 @@ struct InputAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::InputAttrs const &) const; +struct hash<::FlexFlow::InputAttrs> { + size_t operator()(::FlexFlow::InputAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::InputAttrs from_json(json const &); - static void to_json(json &, FlexFlow::InputAttrs const &); +struct adl_serializer<::FlexFlow::InputAttrs> { + static ::FlexFlow::InputAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::InputAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::InputAttrs> { + static Gen<::FlexFlow::InputAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h index c945206863..e480544815 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h @@ -23,10 +23,10 @@ namespace FlexFlow { struct LayerNormAttrs { LayerNormAttrs() = delete; - LayerNormAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, - MAX_TENSOR_DIM> const &axes, - bool const &elementwise_affine, - float const &eps); + explicit LayerNormAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &axes, + bool const &elementwise_affine, + float const &eps); bool operator==(LayerNormAttrs const &) const; bool operator!=(LayerNormAttrs const &) const; @@ -42,23 +42,23 @@ struct LayerNormAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::LayerNormAttrs const &) const; +struct hash<::FlexFlow::LayerNormAttrs> { + size_t operator()(::FlexFlow::LayerNormAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::LayerNormAttrs from_json(json const &); - static void to_json(json &, FlexFlow::LayerNormAttrs const &); +struct adl_serializer<::FlexFlow::LayerNormAttrs> { + static ::FlexFlow::LayerNormAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::LayerNormAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::LayerNormAttrs> { + static Gen<::FlexFlow::LayerNormAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h index 28cd2a8b33..a00dc65ccb 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h @@ -24,11 +24,12 @@ namespace FlexFlow { struct LinearAttrs { LinearAttrs() = delete; - LinearAttrs(int const &out_channels, - bool const &use_bias, - ::FlexFlow::DataType const &data_type, - std::optional<::FlexFlow::Activation> const &activation, - std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer); + explicit LinearAttrs( + int const &out_channels, + bool const &use_bias, + ::FlexFlow::DataType const &data_type, + std::optional<::FlexFlow::Activation> const &activation, + std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer); bool operator==(LinearAttrs const &) const; bool operator!=(LinearAttrs const &) const; @@ -46,23 +47,23 @@ struct LinearAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::LinearAttrs const &) const; +struct hash<::FlexFlow::LinearAttrs> { + size_t operator()(::FlexFlow::LinearAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::LinearAttrs from_json(json const &); - static void to_json(json &, FlexFlow::LinearAttrs const &); +struct adl_serializer<::FlexFlow::LinearAttrs> { + static ::FlexFlow::LinearAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::LinearAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::LinearAttrs> { + static Gen<::FlexFlow::LinearAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h index ed0d8c9348..528926cc0c 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h @@ -30,23 +30,23 @@ struct NoopAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::NoopAttrs const &) const; +struct hash<::FlexFlow::NoopAttrs> { + size_t operator()(::FlexFlow::NoopAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::NoopAttrs from_json(json const &); - static void to_json(json &, FlexFlow::NoopAttrs const &); +struct adl_serializer<::FlexFlow::NoopAttrs> { + static ::FlexFlow::NoopAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::NoopAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::NoopAttrs> { + static Gen<::FlexFlow::NoopAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h index d3903bd3b2..f6a739473a 100644 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct ParallelMultiHeadAttentionInputs { ParallelMultiHeadAttentionInputs() = delete; - ParallelMultiHeadAttentionInputs( + explicit ParallelMultiHeadAttentionInputs( ::FlexFlow::ParallelTensorShape const &query, ::FlexFlow::ParallelTensorShape const &key, ::FlexFlow::ParallelTensorShape const &value); @@ -36,24 +36,24 @@ struct ParallelMultiHeadAttentionInputs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ParallelMultiHeadAttentionInputs const &) const; +struct hash<::FlexFlow::ParallelMultiHeadAttentionInputs> { + size_t operator()(::FlexFlow::ParallelMultiHeadAttentionInputs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ParallelMultiHeadAttentionInputs from_json(json const &); +struct adl_serializer<::FlexFlow::ParallelMultiHeadAttentionInputs> { + static ::FlexFlow::ParallelMultiHeadAttentionInputs from_json(json const &); static void to_json(json &, - FlexFlow::ParallelMultiHeadAttentionInputs const &); + ::FlexFlow::ParallelMultiHeadAttentionInputs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ParallelMultiHeadAttentionInputs> { + static Gen<::FlexFlow::ParallelMultiHeadAttentionInputs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h index a5c6603302..ef779217cd 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h @@ -22,14 +22,14 @@ namespace FlexFlow { struct Pool2DAttrs { Pool2DAttrs() = delete; - Pool2DAttrs(int const &kernel_h, - int const &kernel_w, - int const &stride_h, - int const &stride_w, - int const &padding_h, - int const &padding_w, - ::FlexFlow::PoolOp const &pool_type, - ::FlexFlow::Activation const &activation); + explicit Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation); bool operator==(Pool2DAttrs const &) const; bool operator!=(Pool2DAttrs const &) const; @@ -50,23 +50,23 @@ struct Pool2DAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::Pool2DAttrs const &) const; +struct hash<::FlexFlow::Pool2DAttrs> { + size_t operator()(::FlexFlow::Pool2DAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::Pool2DAttrs from_json(json const &); - static void to_json(json &, FlexFlow::Pool2DAttrs const &); +struct adl_serializer<::FlexFlow::Pool2DAttrs> { + static ::FlexFlow::Pool2DAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::Pool2DAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::Pool2DAttrs> { + static Gen<::FlexFlow::Pool2DAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h index af27bf35be..1710687b36 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h @@ -24,10 +24,10 @@ namespace FlexFlow { struct ReduceAttrs { ReduceAttrs() = delete; - ReduceAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, - MAX_TENSOR_DIM> const &axes, - ::FlexFlow::OperatorType const &op_type, - bool const &keepdims); + explicit ReduceAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &axes, + ::FlexFlow::OperatorType const &op_type, + bool const &keepdims); bool operator==(ReduceAttrs const &) const; bool operator!=(ReduceAttrs const &) const; @@ -43,23 +43,23 @@ struct ReduceAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReduceAttrs const &) const; +struct hash<::FlexFlow::ReduceAttrs> { + size_t operator()(::FlexFlow::ReduceAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReduceAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ReduceAttrs const &); +struct adl_serializer<::FlexFlow::ReduceAttrs> { + static ::FlexFlow::ReduceAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ReduceAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReduceAttrs> { + static Gen<::FlexFlow::ReduceAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h index 9de5eb2252..f742ce46fb 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct ReductionAttrs { ReductionAttrs() = delete; - ReductionAttrs(int const &reduction_degree); + explicit ReductionAttrs(int const &reduction_degree); bool operator==(ReductionAttrs const &) const; bool operator!=(ReductionAttrs const &) const; @@ -34,23 +34,23 @@ struct ReductionAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReductionAttrs const &) const; +struct hash<::FlexFlow::ReductionAttrs> { + size_t operator()(::FlexFlow::ReductionAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReductionAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ReductionAttrs const &); +struct adl_serializer<::FlexFlow::ReductionAttrs> { + static ::FlexFlow::ReductionAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ReductionAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReductionAttrs> { + static Gen<::FlexFlow::ReductionAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h index 66c21466f4..33f32f709c 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h @@ -22,8 +22,8 @@ namespace FlexFlow { struct RepartitionAttrs { RepartitionAttrs() = delete; - RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, - int const &repartition_degree); + explicit RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree); bool operator==(RepartitionAttrs const &) const; bool operator!=(RepartitionAttrs const &) const; @@ -38,23 +38,23 @@ struct RepartitionAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::RepartitionAttrs const &) const; +struct hash<::FlexFlow::RepartitionAttrs> { + size_t operator()(::FlexFlow::RepartitionAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::RepartitionAttrs from_json(json const &); - static void to_json(json &, FlexFlow::RepartitionAttrs const &); +struct adl_serializer<::FlexFlow::RepartitionAttrs> { + static ::FlexFlow::RepartitionAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::RepartitionAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::RepartitionAttrs> { + static Gen<::FlexFlow::RepartitionAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h index ea3f0d46c7..53a9a05337 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct ReplicateAttrs { ReplicateAttrs() = delete; - ReplicateAttrs(int const &replicate_degree); + explicit ReplicateAttrs(int const &replicate_degree); bool operator==(ReplicateAttrs const &) const; bool operator!=(ReplicateAttrs const &) const; @@ -34,23 +34,23 @@ struct ReplicateAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReplicateAttrs const &) const; +struct hash<::FlexFlow::ReplicateAttrs> { + size_t operator()(::FlexFlow::ReplicateAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReplicateAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ReplicateAttrs const &); +struct adl_serializer<::FlexFlow::ReplicateAttrs> { + static ::FlexFlow::ReplicateAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ReplicateAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReplicateAttrs> { + static Gen<::FlexFlow::ReplicateAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h index 612874790f..1d16e9eccb 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct ReshapeAttrs { ReshapeAttrs() = delete; - ReshapeAttrs(::FlexFlow::TensorShape const &shape); + explicit ReshapeAttrs(::FlexFlow::TensorShape const &shape); bool operator==(ReshapeAttrs const &) const; bool operator!=(ReshapeAttrs const &) const; @@ -35,23 +35,23 @@ struct ReshapeAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReshapeAttrs const &) const; +struct hash<::FlexFlow::ReshapeAttrs> { + size_t operator()(::FlexFlow::ReshapeAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReshapeAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ReshapeAttrs const &); +struct adl_serializer<::FlexFlow::ReshapeAttrs> { + static ::FlexFlow::ReshapeAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ReshapeAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReshapeAttrs> { + static Gen<::FlexFlow::ReshapeAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h index 8c8c8a7a9e..94037c653d 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct ReverseAttrs { ReverseAttrs() = delete; - ReverseAttrs(::FlexFlow::ff_dim_t const &axis); + explicit ReverseAttrs(::FlexFlow::ff_dim_t const &axis); bool operator==(ReverseAttrs const &) const; bool operator!=(ReverseAttrs const &) const; @@ -36,23 +36,23 @@ struct ReverseAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReverseAttrs const &) const; +struct hash<::FlexFlow::ReverseAttrs> { + size_t operator()(::FlexFlow::ReverseAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReverseAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ReverseAttrs const &); +struct adl_serializer<::FlexFlow::ReverseAttrs> { + static ::FlexFlow::ReverseAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ReverseAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReverseAttrs> { + static Gen<::FlexFlow::ReverseAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h index 1c855d90f4..5705c7a882 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct SoftmaxAttrs { SoftmaxAttrs() = delete; - SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim); + explicit SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim); bool operator==(SoftmaxAttrs const &) const; bool operator!=(SoftmaxAttrs const &) const; @@ -36,23 +36,23 @@ struct SoftmaxAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::SoftmaxAttrs const &) const; +struct hash<::FlexFlow::SoftmaxAttrs> { + size_t operator()(::FlexFlow::SoftmaxAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::SoftmaxAttrs from_json(json const &); - static void to_json(json &, FlexFlow::SoftmaxAttrs const &); +struct adl_serializer<::FlexFlow::SoftmaxAttrs> { + static ::FlexFlow::SoftmaxAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::SoftmaxAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::SoftmaxAttrs> { + static Gen<::FlexFlow::SoftmaxAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h index b602015e2e..baf0a8f305 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h @@ -23,8 +23,9 @@ namespace FlexFlow { struct SplitAttrs { SplitAttrs() = delete; - SplitAttrs(::FlexFlow::stack_vector const &splits, - ::FlexFlow::ff_dim_t const &axis); + explicit SplitAttrs( + ::FlexFlow::stack_vector const &splits, + ::FlexFlow::ff_dim_t const &axis); bool operator==(SplitAttrs const &) const; bool operator!=(SplitAttrs const &) const; @@ -39,23 +40,23 @@ struct SplitAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::SplitAttrs const &) const; +struct hash<::FlexFlow::SplitAttrs> { + size_t operator()(::FlexFlow::SplitAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::SplitAttrs from_json(json const &); - static void to_json(json &, FlexFlow::SplitAttrs const &); +struct adl_serializer<::FlexFlow::SplitAttrs> { + static ::FlexFlow::SplitAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::SplitAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::SplitAttrs> { + static Gen<::FlexFlow::SplitAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h index d1f32f67b7..ef09bc3b16 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct TopKAttrs { TopKAttrs() = delete; - TopKAttrs(int const &k, bool const &sorted); + explicit TopKAttrs(int const &k, bool const &sorted); bool operator==(TopKAttrs const &) const; bool operator!=(TopKAttrs const &) const; @@ -35,23 +35,23 @@ struct TopKAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TopKAttrs const &) const; +struct hash<::FlexFlow::TopKAttrs> { + size_t operator()(::FlexFlow::TopKAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TopKAttrs from_json(json const &); - static void to_json(json &, FlexFlow::TopKAttrs const &); +struct adl_serializer<::FlexFlow::TopKAttrs> { + static ::FlexFlow::TopKAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::TopKAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::TopKAttrs> { + static Gen<::FlexFlow::TopKAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h index f4d932845f..fac95b406b 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h @@ -23,7 +23,8 @@ namespace FlexFlow { struct TransposeAttrs { TransposeAttrs() = delete; - TransposeAttrs(::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm); + explicit TransposeAttrs( + ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm); bool operator==(TransposeAttrs const &) const; bool operator!=(TransposeAttrs const &) const; @@ -37,23 +38,23 @@ struct TransposeAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TransposeAttrs const &) const; +struct hash<::FlexFlow::TransposeAttrs> { + size_t operator()(::FlexFlow::TransposeAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TransposeAttrs from_json(json const &); - static void to_json(json &, FlexFlow::TransposeAttrs const &); +struct adl_serializer<::FlexFlow::TransposeAttrs> { + static ::FlexFlow::TransposeAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::TransposeAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::TransposeAttrs> { + static Gen<::FlexFlow::TransposeAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h index 4a19909c25..c7672267fe 100644 --- a/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h @@ -30,23 +30,23 @@ struct WeightAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::WeightAttrs const &) const; +struct hash<::FlexFlow::WeightAttrs> { + size_t operator()(::FlexFlow::WeightAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::WeightAttrs from_json(json const &); - static void to_json(json &, FlexFlow::WeightAttrs const &); +struct adl_serializer<::FlexFlow::WeightAttrs> { + static ::FlexFlow::WeightAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::WeightAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::WeightAttrs> { + static Gen<::FlexFlow::WeightAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h index 71ad517095..edb24f78f4 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h @@ -26,7 +26,7 @@ namespace FlexFlow { struct ParallelTensorDims { ParallelTensorDims() = delete; - ParallelTensorDims( + explicit ParallelTensorDims( ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims, ::FlexFlow::ReplicaParallelDimSet const &replica_dims); @@ -43,23 +43,23 @@ struct ParallelTensorDims { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ParallelTensorDims const &) const; +struct hash<::FlexFlow::ParallelTensorDims> { + size_t operator()(::FlexFlow::ParallelTensorDims const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ParallelTensorDims from_json(json const &); - static void to_json(json &, FlexFlow::ParallelTensorDims const &); +struct adl_serializer<::FlexFlow::ParallelTensorDims> { + static ::FlexFlow::ParallelTensorDims from_json(json const &); + static void to_json(json &, ::FlexFlow::ParallelTensorDims const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ParallelTensorDims> { + static Gen<::FlexFlow::ParallelTensorDims> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h index 62d291fa4f..9f56d29fbb 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h @@ -22,8 +22,8 @@ namespace FlexFlow { struct ParallelTensorShape { ParallelTensorShape() = delete; - ParallelTensorShape(::FlexFlow::ParallelTensorDims const &dims, - ::FlexFlow::DataType const &data_type); + explicit ParallelTensorShape(::FlexFlow::ParallelTensorDims const &dims, + ::FlexFlow::DataType const &data_type); bool operator==(ParallelTensorShape const &) const; bool operator!=(ParallelTensorShape const &) const; @@ -38,23 +38,23 @@ struct ParallelTensorShape { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ParallelTensorShape const &) const; +struct hash<::FlexFlow::ParallelTensorShape> { + size_t operator()(::FlexFlow::ParallelTensorShape const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ParallelTensorShape from_json(json const &); - static void to_json(json &, FlexFlow::ParallelTensorShape const &); +struct adl_serializer<::FlexFlow::ParallelTensorShape> { + static ::FlexFlow::ParallelTensorShape from_json(json const &); + static void to_json(json &, ::FlexFlow::ParallelTensorShape const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ParallelTensorShape> { + static Gen<::FlexFlow::ParallelTensorShape> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h index a820bfe81c..c5f8748cbc 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct DiscardCopyDegree { DiscardCopyDegree() = delete; - DiscardCopyDegree(int const &value); + explicit DiscardCopyDegree(int const &value); bool operator==(DiscardCopyDegree const &) const; bool operator!=(DiscardCopyDegree const &) const; @@ -34,23 +34,23 @@ struct DiscardCopyDegree { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::DiscardCopyDegree const &) const; +struct hash<::FlexFlow::DiscardCopyDegree> { + size_t operator()(::FlexFlow::DiscardCopyDegree const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::DiscardCopyDegree from_json(json const &); - static void to_json(json &, FlexFlow::DiscardCopyDegree const &); +struct adl_serializer<::FlexFlow::DiscardCopyDegree> { + static ::FlexFlow::DiscardCopyDegree from_json(json const &); + static void to_json(json &, ::FlexFlow::DiscardCopyDegree const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::DiscardCopyDegree> { + static Gen<::FlexFlow::DiscardCopyDegree> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h index 17388f8d05..9391f7743e 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct SumDegree { SumDegree() = delete; - SumDegree(int const &value); + explicit SumDegree(int const &value); bool operator==(SumDegree const &) const; bool operator!=(SumDegree const &) const; @@ -34,23 +34,23 @@ struct SumDegree { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::SumDegree const &) const; +struct hash<::FlexFlow::SumDegree> { + size_t operator()(::FlexFlow::SumDegree const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::SumDegree from_json(json const &); - static void to_json(json &, FlexFlow::SumDegree const &); +struct adl_serializer<::FlexFlow::SumDegree> { + static ::FlexFlow::SumDegree from_json(json const &); + static void to_json(json &, ::FlexFlow::SumDegree const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::SumDegree> { + static Gen<::FlexFlow::SumDegree> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h index 250ba29947..171cad2680 100644 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h @@ -21,8 +21,8 @@ namespace FlexFlow { struct ReplicaParallelDim { ReplicaParallelDim() = delete; - ReplicaParallelDim(int const °ree, - ::FlexFlow::ReplicaType const &replica_type); + explicit ReplicaParallelDim(int const °ree, + ::FlexFlow::ReplicaType const &replica_type); bool operator==(ReplicaParallelDim const &) const; bool operator!=(ReplicaParallelDim const &) const; @@ -37,23 +37,23 @@ struct ReplicaParallelDim { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReplicaParallelDim const &) const; +struct hash<::FlexFlow::ReplicaParallelDim> { + size_t operator()(::FlexFlow::ReplicaParallelDim const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReplicaParallelDim from_json(json const &); - static void to_json(json &, FlexFlow::ReplicaParallelDim const &); +struct adl_serializer<::FlexFlow::ReplicaParallelDim> { + static ::FlexFlow::ReplicaParallelDim from_json(json const &); + static void to_json(json &, ::FlexFlow::ReplicaParallelDim const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReplicaParallelDim> { + static Gen<::FlexFlow::ReplicaParallelDim> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h index 321029347f..1f964c4645 100644 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct ReplicaParallelDimSet { ReplicaParallelDimSet() = delete; - ReplicaParallelDimSet( + explicit ReplicaParallelDimSet( ::FlexFlow::SumDegree const &sum_degree, ::FlexFlow::DiscardCopyDegree const &discard_copy_degree); @@ -39,23 +39,23 @@ struct ReplicaParallelDimSet { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ReplicaParallelDimSet const &) const; +struct hash<::FlexFlow::ReplicaParallelDimSet> { + size_t operator()(::FlexFlow::ReplicaParallelDimSet const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ReplicaParallelDimSet from_json(json const &); - static void to_json(json &, FlexFlow::ReplicaParallelDimSet const &); +struct adl_serializer<::FlexFlow::ReplicaParallelDimSet> { + static ::FlexFlow::ReplicaParallelDimSet from_json(json const &); + static void to_json(json &, ::FlexFlow::ReplicaParallelDimSet const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ReplicaParallelDimSet> { + static Gen<::FlexFlow::ReplicaParallelDimSet> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h index 631852c259..a1cdea1fce 100644 --- a/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct ShardParallelDim { ShardParallelDim() = delete; - ShardParallelDim(size_t const &size, int const °ree); + explicit ShardParallelDim(size_t const &size, int const °ree); bool operator==(ShardParallelDim const &) const; bool operator!=(ShardParallelDim const &) const; @@ -35,23 +35,23 @@ struct ShardParallelDim { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ShardParallelDim const &) const; +struct hash<::FlexFlow::ShardParallelDim> { + size_t operator()(::FlexFlow::ShardParallelDim const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ShardParallelDim from_json(json const &); - static void to_json(json &, FlexFlow::ShardParallelDim const &); +struct adl_serializer<::FlexFlow::ShardParallelDim> { + static ::FlexFlow::ShardParallelDim from_json(json const &); + static void to_json(json &, ::FlexFlow::ShardParallelDim const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ShardParallelDim> { + static Gen<::FlexFlow::ShardParallelDim> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h index a8e46a4626..1d50442831 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct TensorDims { TensorDims() = delete; - TensorDims(::FlexFlow::FFOrdered const &ff_ordered); + explicit TensorDims(::FlexFlow::FFOrdered const &ff_ordered); bool operator==(TensorDims const &) const; bool operator!=(TensorDims const &) const; @@ -35,23 +35,23 @@ struct TensorDims { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorDims const &) const; +struct hash<::FlexFlow::TensorDims> { + size_t operator()(::FlexFlow::TensorDims const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorDims from_json(json const &); - static void to_json(json &, FlexFlow::TensorDims const &); +struct adl_serializer<::FlexFlow::TensorDims> { + static ::FlexFlow::TensorDims from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorDims const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::TensorDims> { + static Gen<::FlexFlow::TensorDims> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h index f36d5d1306..17a1d88994 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h @@ -22,8 +22,8 @@ namespace FlexFlow { struct TensorShape { TensorShape() = delete; - TensorShape(::FlexFlow::TensorDims const &dims, - ::FlexFlow::DataType const &data_type); + explicit TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type); bool operator==(TensorShape const &) const; bool operator!=(TensorShape const &) const; @@ -38,23 +38,23 @@ struct TensorShape { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorShape const &) const; +struct hash<::FlexFlow::TensorShape> { + size_t operator()(::FlexFlow::TensorShape const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorShape from_json(json const &); - static void to_json(json &, FlexFlow::TensorShape const &); +struct adl_serializer<::FlexFlow::TensorShape> { + static ::FlexFlow::TensorShape from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorShape const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::TensorShape> { + static Gen<::FlexFlow::TensorShape> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc index e382ea298d..17ce3452cb 100644 --- a/lib/op-attrs/src/op-attrs/datatype.cc +++ b/lib/op-attrs/src/op-attrs/datatype.cc @@ -22,7 +22,7 @@ size_t size_of_datatype(DataType data_type) { } bool can_strictly_promote_datatype_from_to(DataType src, DataType dst) { - return src < dst; + return src < dst; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc index 8b22dfd18d..8cebeaeed0 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc @@ -34,7 +34,8 @@ bool ff_dim_t::operator>=(ff_dim_t const &other) const { } // namespace FlexFlow namespace std { -size_t hash::operator()(FlexFlow::ff_dim_t const &x) const { +size_t + hash::operator()(::FlexFlow::ff_dim_t const &x) const { size_t result = 0; result ^= std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -43,12 +44,12 @@ size_t hash::operator()(FlexFlow::ff_dim_t const &x) const { } // namespace std namespace nlohmann { -FlexFlow::ff_dim_t - adl_serializer::from_json(json const &j) { - return {j.at("value").template get()}; +::FlexFlow::ff_dim_t + adl_serializer<::FlexFlow::ff_dim_t>::from_json(json const &j) { + return ::FlexFlow::ff_dim_t{j.at("value").template get()}; } -void adl_serializer::to_json(json &j, - FlexFlow::ff_dim_t const &v) { +void adl_serializer<::FlexFlow::ff_dim_t>::to_json( + json &j, ::FlexFlow::ff_dim_t const &v) { j["__type"] = "ff_dim_t"; j["value"] = v.value; } diff --git a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc index ed06df2c78..747108c386 100644 --- a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc @@ -35,7 +35,7 @@ bool L1RegularizerAttrs::operator>=(L1RegularizerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::L1RegularizerAttrs const &x) const { + ::FlexFlow::L1RegularizerAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,21 +44,22 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::L1RegularizerAttrs - adl_serializer::from_json(json const &j) { - return {j.at("lambda").template get()}; +::FlexFlow::L1RegularizerAttrs + adl_serializer<::FlexFlow::L1RegularizerAttrs>::from_json(json const &j) { + return ::FlexFlow::L1RegularizerAttrs{j.at("lambda").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::L1RegularizerAttrs const &v) { +void adl_serializer<::FlexFlow::L1RegularizerAttrs>::to_json( + json &j, ::FlexFlow::L1RegularizerAttrs const &v) { j["__type"] = "L1RegularizerAttrs"; j["lambda"] = v.lambda; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::L1RegularizerAttrs> + Arbitrary<::FlexFlow::L1RegularizerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::L1RegularizerAttrs>( + gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc index f0f3f34ee5..877f1703ca 100644 --- a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc @@ -35,7 +35,7 @@ bool L2RegularizerAttrs::operator>=(L2RegularizerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::L2RegularizerAttrs const &x) const { + ::FlexFlow::L2RegularizerAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,21 +44,22 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::L2RegularizerAttrs - adl_serializer::from_json(json const &j) { - return {j.at("lambda").template get()}; +::FlexFlow::L2RegularizerAttrs + adl_serializer<::FlexFlow::L2RegularizerAttrs>::from_json(json const &j) { + return ::FlexFlow::L2RegularizerAttrs{j.at("lambda").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::L2RegularizerAttrs const &v) { +void adl_serializer<::FlexFlow::L2RegularizerAttrs>::to_json( + json &j, ::FlexFlow::L2RegularizerAttrs const &v) { j["__type"] = "L2RegularizerAttrs"; j["lambda"] = v.lambda; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::L2RegularizerAttrs> + Arbitrary<::FlexFlow::L2RegularizerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::L2RegularizerAttrs>( + gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc index 26d3138eb4..a5a66b1a77 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc @@ -112,7 +112,7 @@ bool MultiHeadAttentionInputs::operator>=( namespace std { size_t hash::operator()( - FlexFlow::MultiHeadAttentionInputs const &x) const { + ::FlexFlow::MultiHeadAttentionInputs const &x) const { size_t result = 0; result ^= std::hash{}(x.batch_size) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -131,18 +131,19 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::MultiHeadAttentionInputs - adl_serializer::from_json( +::FlexFlow::MultiHeadAttentionInputs + adl_serializer<::FlexFlow::MultiHeadAttentionInputs>::from_json( json const &j) { - return {j.at("batch_size").template get(), - j.at("sequence_length").template get(), - j.at("query_size").template get(), - j.at("key_size").template get(), - j.at("value_size").template get(), - j.at("datatype").template get<::FlexFlow::DataType>()}; + return ::FlexFlow::MultiHeadAttentionInputs{ + j.at("batch_size").template get(), + j.at("sequence_length").template get(), + j.at("query_size").template get(), + j.at("key_size").template get(), + j.at("value_size").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::MultiHeadAttentionInputs const &v) { +void adl_serializer<::FlexFlow::MultiHeadAttentionInputs>::to_json( + json &j, ::FlexFlow::MultiHeadAttentionInputs const &v) { j["__type"] = "MultiHeadAttentionInputs"; j["batch_size"] = v.batch_size; j["sequence_length"] = v.sequence_length; @@ -154,9 +155,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::MultiHeadAttentionInputs> + Arbitrary<::FlexFlow::MultiHeadAttentionInputs>::arbitrary() { + return gen::construct<::FlexFlow::MultiHeadAttentionInputs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc index 2cd5b7ec00..b5ddeaac30 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -122,7 +122,7 @@ tl::expected query_dim, key_dim, value_dim, - discard_copy_q, + DiscardCopyDegree{discard_copy_q}, input_q.data_type, }; diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc index 94784d83cc..be4507677b 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc @@ -128,7 +128,7 @@ bool MultiHeadAttentionParallelInputs::operator>=( namespace std { size_t hash::operator()( - FlexFlow::MultiHeadAttentionParallelInputs const &x) const { + ::FlexFlow::MultiHeadAttentionParallelInputs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.batch_dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -149,10 +149,10 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::MultiHeadAttentionParallelInputs - adl_serializer::from_json( +::FlexFlow::MultiHeadAttentionParallelInputs + adl_serializer<::FlexFlow::MultiHeadAttentionParallelInputs>::from_json( json const &j) { - return { + return ::FlexFlow::MultiHeadAttentionParallelInputs{ j.at("batch_dim").template get<::FlexFlow::ShardParallelDim>(), j.at("sequence_dim").template get<::FlexFlow::ShardParallelDim>(), j.at("query_dim").template get<::FlexFlow::ShardParallelDim>(), @@ -161,8 +161,8 @@ FlexFlow::MultiHeadAttentionParallelInputs j.at("discard_copy_degree").template get<::FlexFlow::DiscardCopyDegree>(), j.at("datatype").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::MultiHeadAttentionParallelInputs const &v) { +void adl_serializer<::FlexFlow::MultiHeadAttentionParallelInputs>::to_json( + json &j, ::FlexFlow::MultiHeadAttentionParallelInputs const &v) { j["__type"] = "MultiHeadAttentionParallelInputs"; j["batch_dim"] = v.batch_dim; j["sequence_dim"] = v.sequence_dim; @@ -175,9 +175,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::MultiHeadAttentionParallelInputs> + Arbitrary<::FlexFlow::MultiHeadAttentionParallelInputs>::arbitrary() { + return gen::construct<::FlexFlow::MultiHeadAttentionParallelInputs>( gen::arbitrary<::FlexFlow::ShardParallelDim>(), gen::arbitrary<::FlexFlow::ShardParallelDim>(), gen::arbitrary<::FlexFlow::ShardParallelDim>(), diff --git a/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc index ad0c094969..a5fbcd6cf6 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc @@ -135,7 +135,7 @@ bool MultiHeadAttentionAttrs::operator>=( namespace std { size_t hash::operator()( - FlexFlow::MultiHeadAttentionAttrs const &x) const { + ::FlexFlow::MultiHeadAttentionAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.embed_dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -158,20 +158,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::MultiHeadAttentionAttrs - adl_serializer::from_json( +::FlexFlow::MultiHeadAttentionAttrs + adl_serializer<::FlexFlow::MultiHeadAttentionAttrs>::from_json( json const &j) { - return {j.at("embed_dim").template get(), - j.at("num_heads").template get(), - j.at("kdim").template get(), - j.at("vdim").template get(), - j.at("dropout").template get(), - j.at("bias").template get(), - j.at("add_bias_kv").template get(), - j.at("add_zero_attn").template get()}; + return ::FlexFlow::MultiHeadAttentionAttrs{ + j.at("embed_dim").template get(), + j.at("num_heads").template get(), + j.at("kdim").template get(), + j.at("vdim").template get(), + j.at("dropout").template get(), + j.at("bias").template get(), + j.at("add_bias_kv").template get(), + j.at("add_zero_attn").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::MultiHeadAttentionAttrs const &v) { +void adl_serializer<::FlexFlow::MultiHeadAttentionAttrs>::to_json( + json &j, ::FlexFlow::MultiHeadAttentionAttrs const &v) { j["__type"] = "MultiHeadAttentionAttrs"; j["embed_dim"] = v.embed_dim; j["num_heads"] = v.num_heads; @@ -185,9 +186,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::MultiHeadAttentionAttrs> + Arbitrary<::FlexFlow::MultiHeadAttentionAttrs>::arbitrary() { + return gen::construct<::FlexFlow::MultiHeadAttentionAttrs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index cbda4ea533..f9836bd3ed 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -163,8 +163,8 @@ tl::expected output_p, }, ReplicaParallelDimSet{ - output_sum_degree, - output_discard_copy_degree, + SumDegree{output_sum_degree}, + DiscardCopyDegree{output_discard_copy_degree}, }, }, input_lhs.data_type, diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc index f178d40696..2395bf5691 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc @@ -43,7 +43,7 @@ bool BatchMatmulAttrs::operator>=(BatchMatmulAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::BatchMatmulAttrs const &x) const { + ::FlexFlow::BatchMatmulAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.a_seq_length_dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -54,13 +54,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::BatchMatmulAttrs - adl_serializer::from_json(json const &j) { - return {j.at("a_seq_length_dim").template get(), - j.at("b_seq_length_dim").template get()}; +::FlexFlow::BatchMatmulAttrs + adl_serializer<::FlexFlow::BatchMatmulAttrs>::from_json(json const &j) { + return ::FlexFlow::BatchMatmulAttrs{ + j.at("a_seq_length_dim").template get(), + j.at("b_seq_length_dim").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::BatchMatmulAttrs const &v) { +void adl_serializer<::FlexFlow::BatchMatmulAttrs>::to_json( + json &j, ::FlexFlow::BatchMatmulAttrs const &v) { j["__type"] = "BatchMatmulAttrs"; j["a_seq_length_dim"] = v.a_seq_length_dim; j["b_seq_length_dim"] = v.b_seq_length_dim; @@ -68,10 +69,10 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary()); +Gen<::FlexFlow::BatchMatmulAttrs> + Arbitrary<::FlexFlow::BatchMatmulAttrs>::arbitrary() { + return gen::construct<::FlexFlow::BatchMatmulAttrs>(gen::arbitrary(), + gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc index cb8dcadae1..13f20a82a5 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc @@ -35,7 +35,7 @@ bool BatchNormAttrs::operator>=(BatchNormAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::BatchNormAttrs const &x) const { + ::FlexFlow::BatchNormAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.relu) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,20 +44,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::BatchNormAttrs - adl_serializer::from_json(json const &j) { - return {j.at("relu").template get()}; +::FlexFlow::BatchNormAttrs + adl_serializer<::FlexFlow::BatchNormAttrs>::from_json(json const &j) { + return ::FlexFlow::BatchNormAttrs{j.at("relu").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::BatchNormAttrs const &v) { +void adl_serializer<::FlexFlow::BatchNormAttrs>::to_json( + json &j, ::FlexFlow::BatchNormAttrs const &v) { j["__type"] = "BatchNormAttrs"; j["relu"] = v.relu; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::BatchNormAttrs> + Arbitrary<::FlexFlow::BatchNormAttrs>::arbitrary() { + return gen::construct<::FlexFlow::BatchNormAttrs>(gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc index ec08bd6a1d..85fff2518c 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc @@ -38,7 +38,7 @@ bool BroadcastAttrs::operator>=(BroadcastAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::BroadcastAttrs const &x) const { + ::FlexFlow::BroadcastAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::stack_vector>{}( x.target_dims) + @@ -48,21 +48,23 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::BroadcastAttrs - adl_serializer::from_json(json const &j) { - return {j.at("target_dims") - .template get<::FlexFlow::stack_vector>()}; +::FlexFlow::BroadcastAttrs + adl_serializer<::FlexFlow::BroadcastAttrs>::from_json(json const &j) { + return ::FlexFlow::BroadcastAttrs{ + j.at("target_dims") + .template get<::FlexFlow::stack_vector>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::BroadcastAttrs const &v) { +void adl_serializer<::FlexFlow::BroadcastAttrs>::to_json( + json &j, ::FlexFlow::BroadcastAttrs const &v) { j["__type"] = "BroadcastAttrs"; j["target_dims"] = v.target_dims; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::BroadcastAttrs> + Arbitrary<::FlexFlow::BroadcastAttrs>::arbitrary() { + return gen::construct<::FlexFlow::BroadcastAttrs>( gen::arbitrary<::FlexFlow::stack_vector>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc index 28367f3449..423fc2e046 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -35,8 +35,8 @@ bool CastAttrs::operator>=(CastAttrs const &other) const { } // namespace FlexFlow namespace std { -size_t - hash::operator()(FlexFlow::CastAttrs const &x) const { +size_t hash::operator()( + ::FlexFlow::CastAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.dtype) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -45,20 +45,20 @@ size_t } // namespace std namespace nlohmann { -FlexFlow::CastAttrs - adl_serializer::from_json(json const &j) { - return {j.at("dtype").template get()}; +::FlexFlow::CastAttrs + adl_serializer<::FlexFlow::CastAttrs>::from_json(json const &j) { + return ::FlexFlow::CastAttrs{j.at("dtype").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::CastAttrs const &v) { +void adl_serializer<::FlexFlow::CastAttrs>::to_json( + json &j, ::FlexFlow::CastAttrs const &v) { j["__type"] = "CastAttrs"; j["dtype"] = v.dtype; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::CastAttrs> Arbitrary<::FlexFlow::CastAttrs>::arbitrary() { + return gen::construct<::FlexFlow::CastAttrs>(gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc index 516d3b0318..198da728bf 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc @@ -45,7 +45,7 @@ bool CombineAttrs::operator>=(CombineAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::CombineAttrs const &x) const { + ::FlexFlow::CombineAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.combine_dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -56,13 +56,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::CombineAttrs - adl_serializer::from_json(json const &j) { - return {j.at("combine_dim").template get<::FlexFlow::ff_dim_t>(), - j.at("combine_degree").template get()}; +::FlexFlow::CombineAttrs + adl_serializer<::FlexFlow::CombineAttrs>::from_json(json const &j) { + return ::FlexFlow::CombineAttrs{ + j.at("combine_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("combine_degree").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::CombineAttrs const &v) { +void adl_serializer<::FlexFlow::CombineAttrs>::to_json( + json &j, ::FlexFlow::CombineAttrs const &v) { j["__type"] = "CombineAttrs"; j["combine_dim"] = v.combine_dim; j["combine_degree"] = v.combine_degree; @@ -70,8 +71,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::CombineAttrs> Arbitrary<::FlexFlow::CombineAttrs>::arbitrary() { + return gen::construct<::FlexFlow::CombineAttrs>( gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc index 20db25d485..2bbd9ba50e 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc @@ -45,7 +45,7 @@ bool ConcatAttrs::operator>=(ConcatAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ConcatAttrs const &x) const { + ::FlexFlow::ConcatAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -56,13 +56,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ConcatAttrs - adl_serializer::from_json(json const &j) { - return {j.at("axis").template get<::FlexFlow::ff_dim_t>(), - j.at("num_inputs").template get()}; +::FlexFlow::ConcatAttrs + adl_serializer<::FlexFlow::ConcatAttrs>::from_json(json const &j) { + return ::FlexFlow::ConcatAttrs{ + j.at("axis").template get<::FlexFlow::ff_dim_t>(), + j.at("num_inputs").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ConcatAttrs const &v) { +void adl_serializer<::FlexFlow::ConcatAttrs>::to_json( + json &j, ::FlexFlow::ConcatAttrs const &v) { j["__type"] = "ConcatAttrs"; j["axis"] = v.axis; j["num_inputs"] = v.num_inputs; @@ -70,8 +71,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ConcatAttrs> Arbitrary<::FlexFlow::ConcatAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ConcatAttrs>( gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index d20690d705..03ae18a1d9 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -57,111 +57,75 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, input.datatype}; } -ParallelTensorShape - get_kernel_shape(Conv2DAttrs const &attrs, - ParallelTensorShape const &raw_input_shape) { +ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input) { assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported - Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - - ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), - input.discard_copy_reduction_degree}; - ShardParallelDim input_channels_dim = { - size_t_from_int(input.channel_dim.size), input.channel_dim.degree}; - ShardParallelDim kernel_height_dim = {size_t_from_int(attrs.kernel_h), 1}; - ShardParallelDim kernel_width_dim = {size_t_from_int(attrs.kernel_w), 1}; - - int sum_degree = 1; - int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * - input.sum_reduction_degree; - - ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - output_channels_dim, - input_channels_dim, - kernel_height_dim, - kernel_width_dim, - }, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - }, - }, - input.datatype, + + Conv2DParallelInputShape parsed = parse_parallel_input_shape(input); + + TensorShape unpar = get_kernel_shape(attrs, get_reduced_shape(input)); + + assert(parsed.height_dim.degree == 1); + assert(parsed.width_dim.degree == 1); + + SumDegree sum_degree = SumDegree{1}; + DiscardCopyDegree discard_copy_degree = + DiscardCopyDegree{parsed.sample_dim.degree * parsed.sum_reduction_degree}; + FFOrdered shard_degrees = { + parsed.discard_copy_reduction_degree, + parsed.channel_dim.degree, + 1, + 1, }; - return result; + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, - ParallelTensorShape const &raw_input_shape) { + ParallelTensorShape const &input) { assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported - Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - - ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), - input.discard_copy_reduction_degree}; - - int sum_degree = 1; - int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * - input.sum_reduction_degree * - input.channel_dim.degree; - - ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - output_channels_dim, - }, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - }, - }, - input.datatype, + + Conv2DParallelInputShape parsed = parse_parallel_input_shape(input); + + TensorShape unpar = get_bias_shape(attrs, get_reduced_shape(input)); + + SumDegree sum_degree = + SumDegree{parsed.sum_reduction_degree * parsed.channel_dim.degree}; + DiscardCopyDegree discard_copy_degree = + DiscardCopyDegree{parsed.height_dim.degree * parsed.width_dim.degree * + parsed.sample_dim.degree}; + FFOrdered shard_degrees = { + parsed.discard_copy_reduction_degree, }; - return result; + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } -ParallelTensorShape - get_output_shape(Conv2DAttrs const &attrs, - ParallelTensorShape const &raw_input_shape) { +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input) { assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported - Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - - TensorShape unpar_output_shape = - get_output_shape(attrs, get_reduced_shape(raw_input_shape)); - - size_t num_samples = dim_at_idx(unpar_output_shape, ff_dim_t{0}); - size_t num_channels = dim_at_idx(unpar_output_shape, ff_dim_t{1}); - size_t height = dim_at_idx(unpar_output_shape, ff_dim_t{2}); - size_t width = dim_at_idx(unpar_output_shape, ff_dim_t{3}); - - ShardParallelDim sample_dim = {num_samples, input.sample_dim.degree}; - ShardParallelDim channel_dim = {num_channels, - input.discard_copy_reduction_degree}; - ShardParallelDim height_dim = {height, input.height_dim.degree}; - ShardParallelDim width_dim = {width, input.width_dim.degree}; - - int sum_degree = input.channel_dim.degree * input.sum_reduction_degree; - int discard_copy_degree = 1; - - ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - sample_dim, - channel_dim, - height_dim, - width_dim, - }, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - }, - }, - input.datatype, + + Conv2DParallelInputShape parsed = parse_parallel_input_shape(input); + + TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input)); + + assert(parsed.height_dim.degree == 1); + assert(parsed.width_dim.degree == 1); + + SumDegree sum_degree = + SumDegree{parsed.sum_reduction_degree * parsed.channel_dim.degree}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1}; + FFOrdered shard_degrees = { + parsed.sample_dim.degree, + parsed.discard_copy_reduction_degree, + 1, + 1, }; - return result; + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc index 74df30e2d7..90df5ae1a3 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc @@ -91,7 +91,7 @@ bool Conv2DInputShape::operator>=(Conv2DInputShape const &other) const { namespace std { size_t hash::operator()( - FlexFlow::Conv2DInputShape const &x) const { + ::FlexFlow::Conv2DInputShape const &x) const { size_t result = 0; result ^= std::hash{}(x.num_samples) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -108,16 +108,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::Conv2DInputShape - adl_serializer::from_json(json const &j) { - return {j.at("num_samples").template get(), - j.at("num_channels").template get(), - j.at("height").template get(), - j.at("width").template get(), - j.at("datatype").template get<::FlexFlow::DataType>()}; +::FlexFlow::Conv2DInputShape + adl_serializer<::FlexFlow::Conv2DInputShape>::from_json(json const &j) { + return ::FlexFlow::Conv2DInputShape{ + j.at("num_samples").template get(), + j.at("num_channels").template get(), + j.at("height").template get(), + j.at("width").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::Conv2DInputShape const &v) { +void adl_serializer<::FlexFlow::Conv2DInputShape>::to_json( + json &j, ::FlexFlow::Conv2DInputShape const &v) { j["__type"] = "Conv2DInputShape"; j["num_samples"] = v.num_samples; j["num_channels"] = v.num_channels; @@ -128,9 +129,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::Conv2DInputShape> + Arbitrary<::FlexFlow::Conv2DInputShape>::arbitrary() { + return gen::construct<::FlexFlow::Conv2DInputShape>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc index 32ac4547f1..98f69d14c9 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -12,7 +12,7 @@ Conv2DParallelInputShape ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); - return Conv2DParallelInputShape{ + Conv2DParallelInputShape parsed = Conv2DParallelInputShape{ sample_dim, channel_dim, height_dim, @@ -21,6 +21,11 @@ Conv2DParallelInputShape get_discard_copy_degree(input), input.data_type, }; + + assert(parsed.height_dim.degree == 1); + assert(parsed.width_dim.degree == 1); + + return parsed; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc index df854c2b8f..efb73dba1b 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc @@ -132,7 +132,7 @@ bool Conv2DParallelInputShape::operator>=( namespace std { size_t hash::operator()( - FlexFlow::Conv2DParallelInputShape const &x) const { + ::FlexFlow::Conv2DParallelInputShape const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sample_dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -153,19 +153,20 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::Conv2DParallelInputShape - adl_serializer::from_json( +::FlexFlow::Conv2DParallelInputShape + adl_serializer<::FlexFlow::Conv2DParallelInputShape>::from_json( json const &j) { - return {j.at("sample_dim").template get<::FlexFlow::ShardParallelDim>(), - j.at("channel_dim").template get<::FlexFlow::ShardParallelDim>(), - j.at("height_dim").template get<::FlexFlow::ShardParallelDim>(), - j.at("width_dim").template get<::FlexFlow::ShardParallelDim>(), - j.at("sum_reduction_degree").template get(), - j.at("discard_copy_reduction_degree").template get(), - j.at("datatype").template get<::FlexFlow::DataType>()}; + return ::FlexFlow::Conv2DParallelInputShape{ + j.at("sample_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("channel_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("height_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("width_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sum_reduction_degree").template get(), + j.at("discard_copy_reduction_degree").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::Conv2DParallelInputShape const &v) { +void adl_serializer<::FlexFlow::Conv2DParallelInputShape>::to_json( + json &j, ::FlexFlow::Conv2DParallelInputShape const &v) { j["__type"] = "Conv2DParallelInputShape"; j["sample_dim"] = v.sample_dim; j["channel_dim"] = v.channel_dim; @@ -178,9 +179,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::Conv2DParallelInputShape> + Arbitrary<::FlexFlow::Conv2DParallelInputShape>::arbitrary() { + return gen::construct<::FlexFlow::Conv2DParallelInputShape>( gen::arbitrary<::FlexFlow::ShardParallelDim>(), gen::arbitrary<::FlexFlow::ShardParallelDim>(), gen::arbitrary<::FlexFlow::ShardParallelDim>(), diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc index 238b349cbe..696fe08a6f 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc @@ -160,7 +160,7 @@ bool Conv2DAttrs::operator>=(Conv2DAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::Conv2DAttrs const &x) const { + ::FlexFlow::Conv2DAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -187,9 +187,9 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::Conv2DAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::Conv2DAttrs + adl_serializer<::FlexFlow::Conv2DAttrs>::from_json(json const &j) { + return ::FlexFlow::Conv2DAttrs{ j.at("out_channels").template get(), j.at("kernel_h").template get(), j.at("kernel_w").template get(), @@ -201,8 +201,8 @@ FlexFlow::Conv2DAttrs j.at("activation").template get>(), j.at("use_bias").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::Conv2DAttrs const &v) { +void adl_serializer<::FlexFlow::Conv2DAttrs>::to_json( + json &j, ::FlexFlow::Conv2DAttrs const &v) { j["__type"] = "Conv2DAttrs"; j["out_channels"] = v.out_channels; j["kernel_h"] = v.kernel_h; @@ -218,8 +218,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::Conv2DAttrs> Arbitrary<::FlexFlow::Conv2DAttrs>::arbitrary() { + return gen::construct<::FlexFlow::Conv2DAttrs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc index 284443a0e4..15f6ad8bb1 100644 --- a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc @@ -36,7 +36,7 @@ bool DropoutAttrs::operator>=(DropoutAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::DropoutAttrs const &x) const { + ::FlexFlow::DropoutAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.rate) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -47,13 +47,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::DropoutAttrs - adl_serializer::from_json(json const &j) { - return {j.at("rate").template get(), - j.at("seed").template get()}; +::FlexFlow::DropoutAttrs + adl_serializer<::FlexFlow::DropoutAttrs>::from_json(json const &j) { + return ::FlexFlow::DropoutAttrs{ + j.at("rate").template get(), + j.at("seed").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::DropoutAttrs const &v) { +void adl_serializer<::FlexFlow::DropoutAttrs>::to_json( + json &j, ::FlexFlow::DropoutAttrs const &v) { j["__type"] = "DropoutAttrs"; j["rate"] = v.rate; j["seed"] = v.seed; @@ -61,8 +62,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::DropoutAttrs> Arbitrary<::FlexFlow::DropoutAttrs>::arbitrary() { + return gen::construct<::FlexFlow::DropoutAttrs>( gen::arbitrary(), gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc index a0e555cb12..568371c4fe 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc @@ -85,7 +85,7 @@ bool ElementBinaryAttrs::operator>=(ElementBinaryAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ElementBinaryAttrs const &x) const { + ::FlexFlow::ElementBinaryAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OperatorType>{}(x.type) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -100,15 +100,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ElementBinaryAttrs - adl_serializer::from_json(json const &j) { - return {j.at("type").template get<::FlexFlow::OperatorType>(), - j.at("compute_type").template get<::FlexFlow::DataType>(), - j.at("should_broadcast_lhs").template get(), - j.at("should_broadcast_rhs").template get()}; +::FlexFlow::ElementBinaryAttrs + adl_serializer<::FlexFlow::ElementBinaryAttrs>::from_json(json const &j) { + return ::FlexFlow::ElementBinaryAttrs{ + j.at("type").template get<::FlexFlow::OperatorType>(), + j.at("compute_type").template get<::FlexFlow::DataType>(), + j.at("should_broadcast_lhs").template get(), + j.at("should_broadcast_rhs").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ElementBinaryAttrs const &v) { +void adl_serializer<::FlexFlow::ElementBinaryAttrs>::to_json( + json &j, ::FlexFlow::ElementBinaryAttrs const &v) { j["__type"] = "ElementBinaryAttrs"; j["type"] = v.type; j["compute_type"] = v.compute_type; @@ -118,9 +119,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ElementBinaryAttrs> + Arbitrary<::FlexFlow::ElementBinaryAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ElementBinaryAttrs>( gen::arbitrary<::FlexFlow::OperatorType>(), gen::arbitrary<::FlexFlow::DataType>(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc index ee85474caf..55c7b4f38f 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc @@ -50,7 +50,7 @@ bool ElementScalarUnaryAttrs::operator>=( namespace std { size_t hash::operator()( - FlexFlow::ElementScalarUnaryAttrs const &x) const { + ::FlexFlow::ElementScalarUnaryAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -61,14 +61,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ElementScalarUnaryAttrs - adl_serializer::from_json( +::FlexFlow::ElementScalarUnaryAttrs + adl_serializer<::FlexFlow::ElementScalarUnaryAttrs>::from_json( json const &j) { - return {j.at("op_type").template get<::FlexFlow::OperatorType>(), - j.at("scalar").template get()}; + return ::FlexFlow::ElementScalarUnaryAttrs{ + j.at("op_type").template get<::FlexFlow::OperatorType>(), + j.at("scalar").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ElementScalarUnaryAttrs const &v) { +void adl_serializer<::FlexFlow::ElementScalarUnaryAttrs>::to_json( + json &j, ::FlexFlow::ElementScalarUnaryAttrs const &v) { j["__type"] = "ElementScalarUnaryAttrs"; j["op_type"] = v.op_type; j["scalar"] = v.scalar; @@ -76,9 +77,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ElementScalarUnaryAttrs> + Arbitrary<::FlexFlow::ElementScalarUnaryAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ElementScalarUnaryAttrs>( gen::arbitrary<::FlexFlow::OperatorType>(), gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc index bf90a3db7d..6f158f9de4 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc @@ -37,7 +37,7 @@ bool ElementUnaryAttrs::operator>=(ElementUnaryAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ElementUnaryAttrs const &x) const { + ::FlexFlow::ElementUnaryAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,21 +46,22 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ElementUnaryAttrs - adl_serializer::from_json(json const &j) { - return {j.at("op_type").template get<::FlexFlow::OperatorType>()}; +::FlexFlow::ElementUnaryAttrs + adl_serializer<::FlexFlow::ElementUnaryAttrs>::from_json(json const &j) { + return ::FlexFlow::ElementUnaryAttrs{ + j.at("op_type").template get<::FlexFlow::OperatorType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ElementUnaryAttrs const &v) { +void adl_serializer<::FlexFlow::ElementUnaryAttrs>::to_json( + json &j, ::FlexFlow::ElementUnaryAttrs const &v) { j["__type"] = "ElementUnaryAttrs"; j["op_type"] = v.op_type; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ElementUnaryAttrs> + Arbitrary<::FlexFlow::ElementUnaryAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ElementUnaryAttrs>( gen::arbitrary<::FlexFlow::OperatorType>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index 9e9ad3a194..be7b91c24f 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -71,8 +71,9 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = shard_dim_at_idx(input, ff_dim_t{-1}).degree; - DiscardCopyDegree discard_copy_degree = 1; + SumDegree sum_degree = + SumDegree{shard_dim_at_idx(input, ff_dim_t{-1}).degree}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1}; FFOrdered shard_degrees = transform(input.dims.shard_dims, [](ShardParallelDim const &d) { return d.degree; }); @@ -94,7 +95,7 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = 1; + SumDegree sum_degree = SumDegree{1}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product( transform(ff_ordered_shard_dims(input.dims), [](ShardParallelDim const &d) -> int { return d.degree; }))}; diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc index b4d4657e08..8f5778d794 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -80,7 +80,7 @@ bool EmbeddingAttrs::operator>=(EmbeddingAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::EmbeddingAttrs const &x) const { + ::FlexFlow::EmbeddingAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.num_entries) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -95,15 +95,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::EmbeddingAttrs - adl_serializer::from_json(json const &j) { - return {j.at("num_entries").template get(), - j.at("out_channels").template get(), - j.at("aggr").template get>(), - j.at("data_type").template get<::FlexFlow::DataType>()}; +::FlexFlow::EmbeddingAttrs + adl_serializer<::FlexFlow::EmbeddingAttrs>::from_json(json const &j) { + return ::FlexFlow::EmbeddingAttrs{ + j.at("num_entries").template get(), + j.at("out_channels").template get(), + j.at("aggr").template get>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::EmbeddingAttrs const &v) { +void adl_serializer<::FlexFlow::EmbeddingAttrs>::to_json( + json &j, ::FlexFlow::EmbeddingAttrs const &v) { j["__type"] = "EmbeddingAttrs"; j["num_entries"] = v.num_entries; j["out_channels"] = v.out_channels; @@ -113,8 +114,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::EmbeddingAttrs> + Arbitrary<::FlexFlow::EmbeddingAttrs>::arbitrary() { + return gen::construct<::FlexFlow::EmbeddingAttrs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary>(), diff --git a/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc index ef34d97a89..ff2cdcace5 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc @@ -33,27 +33,27 @@ bool FlatAttrs::operator>=(FlatAttrs const &other) const { } // namespace FlexFlow namespace std { -size_t - hash::operator()(FlexFlow::FlatAttrs const &x) const { +size_t hash::operator()( + ::FlexFlow::FlatAttrs const &x) const { size_t result = 0; return result; } } // namespace std namespace nlohmann { -FlexFlow::FlatAttrs - adl_serializer::from_json(json const &j) { - return {}; +::FlexFlow::FlatAttrs + adl_serializer<::FlexFlow::FlatAttrs>::from_json(json const &j) { + return ::FlexFlow::FlatAttrs{}; } -void adl_serializer::to_json( - json &j, FlexFlow::FlatAttrs const &v) { +void adl_serializer<::FlexFlow::FlatAttrs>::to_json( + json &j, ::FlexFlow::FlatAttrs const &v) { j["__type"] = "FlatAttrs"; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(); +Gen<::FlexFlow::FlatAttrs> Arbitrary<::FlexFlow::FlatAttrs>::arbitrary() { + return gen::construct<::FlexFlow::FlatAttrs>(); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc index 713c0f391e..a056d812ca 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc @@ -37,7 +37,7 @@ bool GatherAttrs::operator>=(GatherAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::GatherAttrs const &x) const { + ::FlexFlow::GatherAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,20 +46,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::GatherAttrs - adl_serializer::from_json(json const &j) { - return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +::FlexFlow::GatherAttrs + adl_serializer<::FlexFlow::GatherAttrs>::from_json(json const &j) { + return ::FlexFlow::GatherAttrs{ + j.at("dim").template get<::FlexFlow::ff_dim_t>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::GatherAttrs const &v) { +void adl_serializer<::FlexFlow::GatherAttrs>::to_json( + json &j, ::FlexFlow::GatherAttrs const &v) { j["__type"] = "GatherAttrs"; j["dim"] = v.dim; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::GatherAttrs> Arbitrary<::FlexFlow::GatherAttrs>::arbitrary() { + return gen::construct<::FlexFlow::GatherAttrs>( gen::arbitrary<::FlexFlow::ff_dim_t>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc index 35544402f7..b3b092bcfd 100644 --- a/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc @@ -34,26 +34,26 @@ bool InputAttrs::operator>=(InputAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::InputAttrs const &x) const { + ::FlexFlow::InputAttrs const &x) const { size_t result = 0; return result; } } // namespace std namespace nlohmann { -FlexFlow::InputAttrs - adl_serializer::from_json(json const &j) { - return {}; +::FlexFlow::InputAttrs + adl_serializer<::FlexFlow::InputAttrs>::from_json(json const &j) { + return ::FlexFlow::InputAttrs{}; } -void adl_serializer::to_json( - json &j, FlexFlow::InputAttrs const &v) { +void adl_serializer<::FlexFlow::InputAttrs>::to_json( + json &j, ::FlexFlow::InputAttrs const &v) { j["__type"] = "InputAttrs"; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(); +Gen<::FlexFlow::InputAttrs> Arbitrary<::FlexFlow::InputAttrs>::arbitrary() { + return gen::construct<::FlexFlow::InputAttrs>(); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc index 163f2e2f91..66db8e278a 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc @@ -48,7 +48,7 @@ bool LayerNormAttrs::operator>=(LayerNormAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::LayerNormAttrs const &x) const { + ::FlexFlow::LayerNormAttrs const &x) const { size_t result = 0; result ^= std::hash< @@ -64,17 +64,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::LayerNormAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::LayerNormAttrs + adl_serializer<::FlexFlow::LayerNormAttrs>::from_json(json const &j) { + return ::FlexFlow::LayerNormAttrs{ j.at("axes") .template get< ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), j.at("elementwise_affine").template get(), j.at("eps").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::LayerNormAttrs const &v) { +void adl_serializer<::FlexFlow::LayerNormAttrs>::to_json( + json &j, ::FlexFlow::LayerNormAttrs const &v) { j["__type"] = "LayerNormAttrs"; j["axes"] = v.axes; j["elementwise_affine"] = v.elementwise_affine; @@ -83,8 +83,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::LayerNormAttrs> + Arbitrary<::FlexFlow::LayerNormAttrs>::arbitrary() { + return gen::construct<::FlexFlow::LayerNormAttrs>( gen::arbitrary< ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 8283673378..2bd0cea950 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -50,7 +50,7 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = 1; + SumDegree sum_degree = SumDegree{1}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ get_sum_degree(input) * product( @@ -75,10 +75,10 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = - get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; - DiscardCopyDegree discard_copy_degree = product( - slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})); + SumDegree sum_degree = SumDegree{ + get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product( + slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; FFOrdered shard_degrees = FFOrdered{get_discard_copy_degree(input)}; return lift_to_parallel_with_degrees( @@ -97,9 +97,9 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = - get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; - DiscardCopyDegree discard_copy_degree = 1; + SumDegree sum_degree = SumDegree{ + get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1}; FFOrdered shard_degrees = ff_ordered_shard_degrees(input); shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc index f3359da219..3099a6c7e4 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc @@ -94,7 +94,7 @@ bool LinearAttrs::operator>=(LinearAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::LinearAttrs const &x) const { + ::FlexFlow::LinearAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -112,9 +112,9 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::LinearAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::LinearAttrs + adl_serializer<::FlexFlow::LinearAttrs>::from_json(json const &j) { + return ::FlexFlow::LinearAttrs{ j.at("out_channels").template get(), j.at("use_bias").template get(), j.at("data_type").template get<::FlexFlow::DataType>(), @@ -122,8 +122,8 @@ FlexFlow::LinearAttrs j.at("regularizer") .template get>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::LinearAttrs const &v) { +void adl_serializer<::FlexFlow::LinearAttrs>::to_json( + json &j, ::FlexFlow::LinearAttrs const &v) { j["__type"] = "LinearAttrs"; j["out_channels"] = v.out_channels; j["use_bias"] = v.use_bias; @@ -134,8 +134,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::LinearAttrs> Arbitrary<::FlexFlow::LinearAttrs>::arbitrary() { + return gen::construct<::FlexFlow::LinearAttrs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary<::FlexFlow::DataType>(), diff --git a/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc index 3ef3a0119b..9600011c06 100644 --- a/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc @@ -33,27 +33,27 @@ bool NoopAttrs::operator>=(NoopAttrs const &other) const { } // namespace FlexFlow namespace std { -size_t - hash::operator()(FlexFlow::NoopAttrs const &x) const { +size_t hash::operator()( + ::FlexFlow::NoopAttrs const &x) const { size_t result = 0; return result; } } // namespace std namespace nlohmann { -FlexFlow::NoopAttrs - adl_serializer::from_json(json const &j) { - return {}; +::FlexFlow::NoopAttrs + adl_serializer<::FlexFlow::NoopAttrs>::from_json(json const &j) { + return ::FlexFlow::NoopAttrs{}; } -void adl_serializer::to_json( - json &j, FlexFlow::NoopAttrs const &v) { +void adl_serializer<::FlexFlow::NoopAttrs>::to_json( + json &j, ::FlexFlow::NoopAttrs const &v) { j["__type"] = "NoopAttrs"; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(); +Gen<::FlexFlow::NoopAttrs> Arbitrary<::FlexFlow::NoopAttrs>::arbitrary() { + return gen::construct<::FlexFlow::NoopAttrs>(); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc index ac8da6d2d7..67a46ef5fb 100644 --- a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -32,7 +32,7 @@ bool ParallelMultiHeadAttentionInputs::operator!=( namespace std { size_t hash::operator()( - FlexFlow::ParallelMultiHeadAttentionInputs const &x) const { + ::FlexFlow::ParallelMultiHeadAttentionInputs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.query) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -45,15 +45,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ParallelMultiHeadAttentionInputs - adl_serializer::from_json( +::FlexFlow::ParallelMultiHeadAttentionInputs + adl_serializer<::FlexFlow::ParallelMultiHeadAttentionInputs>::from_json( json const &j) { - return {j.at("query").template get<::FlexFlow::ParallelTensorShape>(), - j.at("key").template get<::FlexFlow::ParallelTensorShape>(), - j.at("value").template get<::FlexFlow::ParallelTensorShape>()}; + return ::FlexFlow::ParallelMultiHeadAttentionInputs{ + j.at("query").template get<::FlexFlow::ParallelTensorShape>(), + j.at("key").template get<::FlexFlow::ParallelTensorShape>(), + j.at("value").template get<::FlexFlow::ParallelTensorShape>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ParallelMultiHeadAttentionInputs const &v) { +void adl_serializer<::FlexFlow::ParallelMultiHeadAttentionInputs>::to_json( + json &j, ::FlexFlow::ParallelMultiHeadAttentionInputs const &v) { j["__type"] = "ParallelMultiHeadAttentionInputs"; j["query"] = v.query; j["key"] = v.key; @@ -62,9 +63,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ParallelMultiHeadAttentionInputs> + Arbitrary<::FlexFlow::ParallelMultiHeadAttentionInputs>::arbitrary() { + return gen::construct<::FlexFlow::ParallelMultiHeadAttentionInputs>( gen::arbitrary<::FlexFlow::ParallelTensorShape>(), gen::arbitrary<::FlexFlow::ParallelTensorShape>(), gen::arbitrary<::FlexFlow::ParallelTensorShape>()); diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc index 8c445d8b84..057b030a96 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc @@ -131,7 +131,7 @@ bool Pool2DAttrs::operator>=(Pool2DAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::Pool2DAttrs const &x) const { + ::FlexFlow::Pool2DAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -154,19 +154,20 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::Pool2DAttrs - adl_serializer::from_json(json const &j) { - return {j.at("kernel_h").template get(), - j.at("kernel_w").template get(), - j.at("stride_h").template get(), - j.at("stride_w").template get(), - j.at("padding_h").template get(), - j.at("padding_w").template get(), - j.at("pool_type").template get<::FlexFlow::PoolOp>(), - j.at("activation").template get<::FlexFlow::Activation>()}; +::FlexFlow::Pool2DAttrs + adl_serializer<::FlexFlow::Pool2DAttrs>::from_json(json const &j) { + return ::FlexFlow::Pool2DAttrs{ + j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("pool_type").template get<::FlexFlow::PoolOp>(), + j.at("activation").template get<::FlexFlow::Activation>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::Pool2DAttrs const &v) { +void adl_serializer<::FlexFlow::Pool2DAttrs>::to_json( + json &j, ::FlexFlow::Pool2DAttrs const &v) { j["__type"] = "Pool2DAttrs"; j["kernel_h"] = v.kernel_h; j["kernel_w"] = v.kernel_w; @@ -180,8 +181,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::Pool2DAttrs> Arbitrary<::FlexFlow::Pool2DAttrs>::arbitrary() { + return gen::construct<::FlexFlow::Pool2DAttrs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary(), diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc index 2aa9546956..c365819440 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc @@ -49,7 +49,7 @@ bool ReduceAttrs::operator>=(ReduceAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ReduceAttrs const &x) const { + ::FlexFlow::ReduceAttrs const &x) const { size_t result = 0; result ^= std::hash< @@ -65,17 +65,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReduceAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::ReduceAttrs + adl_serializer<::FlexFlow::ReduceAttrs>::from_json(json const &j) { + return ::FlexFlow::ReduceAttrs{ j.at("axes") .template get< ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), j.at("op_type").template get<::FlexFlow::OperatorType>(), j.at("keepdims").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReduceAttrs const &v) { +void adl_serializer<::FlexFlow::ReduceAttrs>::to_json( + json &j, ::FlexFlow::ReduceAttrs const &v) { j["__type"] = "ReduceAttrs"; j["axes"] = v.axes; j["op_type"] = v.op_type; @@ -84,8 +84,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ReduceAttrs> Arbitrary<::FlexFlow::ReduceAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ReduceAttrs>( gen::arbitrary< ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), gen::arbitrary<::FlexFlow::OperatorType>(), diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc index 2f1550bb66..b861676f2b 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc @@ -36,7 +36,7 @@ bool ReductionAttrs::operator>=(ReductionAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ReductionAttrs const &x) const { + ::FlexFlow::ReductionAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.reduction_degree) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -45,20 +45,22 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReductionAttrs - adl_serializer::from_json(json const &j) { - return {j.at("reduction_degree").template get()}; +::FlexFlow::ReductionAttrs + adl_serializer<::FlexFlow::ReductionAttrs>::from_json(json const &j) { + return ::FlexFlow::ReductionAttrs{ + j.at("reduction_degree").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReductionAttrs const &v) { +void adl_serializer<::FlexFlow::ReductionAttrs>::to_json( + json &j, ::FlexFlow::ReductionAttrs const &v) { j["__type"] = "ReductionAttrs"; j["reduction_degree"] = v.reduction_degree; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::ReductionAttrs> + Arbitrary<::FlexFlow::ReductionAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ReductionAttrs>(gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc index 6270298c87..110e16c36a 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc @@ -46,7 +46,7 @@ bool RepartitionAttrs::operator>=(RepartitionAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::RepartitionAttrs const &x) const { + ::FlexFlow::RepartitionAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.repartition_dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -57,13 +57,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::RepartitionAttrs - adl_serializer::from_json(json const &j) { - return {j.at("repartition_dim").template get<::FlexFlow::ff_dim_t>(), - j.at("repartition_degree").template get()}; +::FlexFlow::RepartitionAttrs + adl_serializer<::FlexFlow::RepartitionAttrs>::from_json(json const &j) { + return ::FlexFlow::RepartitionAttrs{ + j.at("repartition_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("repartition_degree").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::RepartitionAttrs const &v) { +void adl_serializer<::FlexFlow::RepartitionAttrs>::to_json( + json &j, ::FlexFlow::RepartitionAttrs const &v) { j["__type"] = "RepartitionAttrs"; j["repartition_dim"] = v.repartition_dim; j["repartition_degree"] = v.repartition_degree; @@ -71,9 +72,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::RepartitionAttrs> + Arbitrary<::FlexFlow::RepartitionAttrs>::arbitrary() { + return gen::construct<::FlexFlow::RepartitionAttrs>( gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc index 930c5beaf4..bdac2d8c81 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc @@ -36,7 +36,7 @@ bool ReplicateAttrs::operator>=(ReplicateAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ReplicateAttrs const &x) const { + ::FlexFlow::ReplicateAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.replicate_degree) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -45,20 +45,22 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReplicateAttrs - adl_serializer::from_json(json const &j) { - return {j.at("replicate_degree").template get()}; +::FlexFlow::ReplicateAttrs + adl_serializer<::FlexFlow::ReplicateAttrs>::from_json(json const &j) { + return ::FlexFlow::ReplicateAttrs{ + j.at("replicate_degree").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReplicateAttrs const &v) { +void adl_serializer<::FlexFlow::ReplicateAttrs>::to_json( + json &j, ::FlexFlow::ReplicateAttrs const &v) { j["__type"] = "ReplicateAttrs"; j["replicate_degree"] = v.replicate_degree; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::ReplicateAttrs> + Arbitrary<::FlexFlow::ReplicateAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ReplicateAttrs>(gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc index b1fb350b88..de18a192ff 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc @@ -37,7 +37,7 @@ bool ReshapeAttrs::operator>=(ReshapeAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ReshapeAttrs const &x) const { + ::FlexFlow::ReshapeAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,20 +46,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReshapeAttrs - adl_serializer::from_json(json const &j) { - return {j.at("shape").template get<::FlexFlow::TensorShape>()}; +::FlexFlow::ReshapeAttrs + adl_serializer<::FlexFlow::ReshapeAttrs>::from_json(json const &j) { + return ::FlexFlow::ReshapeAttrs{ + j.at("shape").template get<::FlexFlow::TensorShape>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReshapeAttrs const &v) { +void adl_serializer<::FlexFlow::ReshapeAttrs>::to_json( + json &j, ::FlexFlow::ReshapeAttrs const &v) { j["__type"] = "ReshapeAttrs"; j["shape"] = v.shape; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ReshapeAttrs> Arbitrary<::FlexFlow::ReshapeAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ReshapeAttrs>( gen::arbitrary<::FlexFlow::TensorShape>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc index 9ac9abeb82..9e8079d666 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc @@ -37,7 +37,7 @@ bool ReverseAttrs::operator>=(ReverseAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ReverseAttrs const &x) const { + ::FlexFlow::ReverseAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,20 +46,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReverseAttrs - adl_serializer::from_json(json const &j) { - return {j.at("axis").template get<::FlexFlow::ff_dim_t>()}; +::FlexFlow::ReverseAttrs + adl_serializer<::FlexFlow::ReverseAttrs>::from_json(json const &j) { + return ::FlexFlow::ReverseAttrs{ + j.at("axis").template get<::FlexFlow::ff_dim_t>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReverseAttrs const &v) { +void adl_serializer<::FlexFlow::ReverseAttrs>::to_json( + json &j, ::FlexFlow::ReverseAttrs const &v) { j["__type"] = "ReverseAttrs"; j["axis"] = v.axis; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ReverseAttrs> Arbitrary<::FlexFlow::ReverseAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ReverseAttrs>( gen::arbitrary<::FlexFlow::ff_dim_t>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc index 4941b7438a..1d4d396ef3 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc @@ -37,7 +37,7 @@ bool SoftmaxAttrs::operator>=(SoftmaxAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::SoftmaxAttrs const &x) const { + ::FlexFlow::SoftmaxAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,20 +46,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::SoftmaxAttrs - adl_serializer::from_json(json const &j) { - return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +::FlexFlow::SoftmaxAttrs + adl_serializer<::FlexFlow::SoftmaxAttrs>::from_json(json const &j) { + return ::FlexFlow::SoftmaxAttrs{ + j.at("dim").template get<::FlexFlow::ff_dim_t>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::SoftmaxAttrs const &v) { +void adl_serializer<::FlexFlow::SoftmaxAttrs>::to_json( + json &j, ::FlexFlow::SoftmaxAttrs const &v) { j["__type"] = "SoftmaxAttrs"; j["dim"] = v.dim; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::SoftmaxAttrs> Arbitrary<::FlexFlow::SoftmaxAttrs>::arbitrary() { + return gen::construct<::FlexFlow::SoftmaxAttrs>( gen::arbitrary<::FlexFlow::ff_dim_t>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc index c6f7e75dbf..bdae47681e 100644 --- a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc @@ -47,7 +47,7 @@ bool SplitAttrs::operator>=(SplitAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::SplitAttrs const &x) const { + ::FlexFlow::SplitAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::stack_vector>{}(x.splits) + @@ -59,14 +59,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::SplitAttrs - adl_serializer::from_json(json const &j) { - return {j.at("splits") - .template get<::FlexFlow::stack_vector>(), - j.at("axis").template get<::FlexFlow::ff_dim_t>()}; +::FlexFlow::SplitAttrs + adl_serializer<::FlexFlow::SplitAttrs>::from_json(json const &j) { + return ::FlexFlow::SplitAttrs{ + j.at("splits") + .template get<::FlexFlow::stack_vector>(), + j.at("axis").template get<::FlexFlow::ff_dim_t>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::SplitAttrs const &v) { +void adl_serializer<::FlexFlow::SplitAttrs>::to_json( + json &j, ::FlexFlow::SplitAttrs const &v) { j["__type"] = "SplitAttrs"; j["splits"] = v.splits; j["axis"] = v.axis; @@ -74,8 +75,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::SplitAttrs> Arbitrary<::FlexFlow::SplitAttrs>::arbitrary() { + return gen::construct<::FlexFlow::SplitAttrs>( gen::arbitrary<::FlexFlow::stack_vector>(), gen::arbitrary<::FlexFlow::ff_dim_t>()); } diff --git a/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc index 55ead7d858..9723c063a5 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc @@ -34,8 +34,8 @@ bool TopKAttrs::operator>=(TopKAttrs const &other) const { } // namespace FlexFlow namespace std { -size_t - hash::operator()(FlexFlow::TopKAttrs const &x) const { +size_t hash::operator()( + ::FlexFlow::TopKAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.k) + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= @@ -45,12 +45,13 @@ size_t } // namespace std namespace nlohmann { -FlexFlow::TopKAttrs - adl_serializer::from_json(json const &j) { - return {j.at("k").template get(), j.at("sorted").template get()}; +::FlexFlow::TopKAttrs + adl_serializer<::FlexFlow::TopKAttrs>::from_json(json const &j) { + return ::FlexFlow::TopKAttrs{j.at("k").template get(), + j.at("sorted").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TopKAttrs const &v) { +void adl_serializer<::FlexFlow::TopKAttrs>::to_json( + json &j, ::FlexFlow::TopKAttrs const &v) { j["__type"] = "TopKAttrs"; j["k"] = v.k; j["sorted"] = v.sorted; @@ -58,9 +59,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary()); +Gen<::FlexFlow::TopKAttrs> Arbitrary<::FlexFlow::TopKAttrs>::arbitrary() { + return gen::construct<::FlexFlow::TopKAttrs>(gen::arbitrary(), + gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc index 0a774b992e..23e78beb7a 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc @@ -40,7 +40,7 @@ bool TransposeAttrs::operator>=(TransposeAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::TransposeAttrs const &x) const { + ::FlexFlow::TransposeAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>{}(x.perm) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -49,21 +49,22 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TransposeAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::TransposeAttrs + adl_serializer<::FlexFlow::TransposeAttrs>::from_json(json const &j) { + return ::FlexFlow::TransposeAttrs{ j.at("perm").template get<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TransposeAttrs const &v) { +void adl_serializer<::FlexFlow::TransposeAttrs>::to_json( + json &j, ::FlexFlow::TransposeAttrs const &v) { j["__type"] = "TransposeAttrs"; j["perm"] = v.perm; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::TransposeAttrs> + Arbitrary<::FlexFlow::TransposeAttrs>::arbitrary() { + return gen::construct<::FlexFlow::TransposeAttrs>( gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc index a288161da2..03ad9f469c 100644 --- a/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc @@ -34,26 +34,26 @@ bool WeightAttrs::operator>=(WeightAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::WeightAttrs const &x) const { + ::FlexFlow::WeightAttrs const &x) const { size_t result = 0; return result; } } // namespace std namespace nlohmann { -FlexFlow::WeightAttrs - adl_serializer::from_json(json const &j) { - return {}; +::FlexFlow::WeightAttrs + adl_serializer<::FlexFlow::WeightAttrs>::from_json(json const &j) { + return ::FlexFlow::WeightAttrs{}; } -void adl_serializer::to_json( - json &j, FlexFlow::WeightAttrs const &v) { +void adl_serializer<::FlexFlow::WeightAttrs>::to_json( + json &j, ::FlexFlow::WeightAttrs const &v) { j["__type"] = "WeightAttrs"; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(); +Gen<::FlexFlow::WeightAttrs> Arbitrary<::FlexFlow::WeightAttrs>::arbitrary() { + return gen::construct<::FlexFlow::WeightAttrs>(); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc index 40be73cb9f..3cad12b4fa 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -50,7 +50,7 @@ bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ParallelTensorDims const &x) const { + ::FlexFlow::ParallelTensorDims const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>{}( x.shard_dims) + @@ -62,15 +62,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ParallelTensorDims - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::ParallelTensorDims + adl_serializer<::FlexFlow::ParallelTensorDims>::from_json(json const &j) { + return ::FlexFlow::ParallelTensorDims{ j.at("shard_dims") .template get<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(), j.at("replica_dims").template get<::FlexFlow::ReplicaParallelDimSet>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ParallelTensorDims const &v) { +void adl_serializer<::FlexFlow::ParallelTensorDims>::to_json( + json &j, ::FlexFlow::ParallelTensorDims const &v) { j["__type"] = "ParallelTensorDims"; j["shard_dims"] = v.shard_dims; j["replica_dims"] = v.replica_dims; @@ -78,9 +78,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ParallelTensorDims> + Arbitrary<::FlexFlow::ParallelTensorDims>::arbitrary() { + return gen::construct<::FlexFlow::ParallelTensorDims>( gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(), gen::arbitrary<::FlexFlow::ReplicaParallelDimSet>()); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 516cbe191f..e2ba10c7bb 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -58,7 +58,7 @@ std::optional } ParallelTensorShape lift_to_parallel(TensorShape const &s) { - return {lift_to_parallel(s.dims), s.data_type}; + return ParallelTensorShape{lift_to_parallel(s.dims), s.data_type}; } ParallelTensorShape diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc index 1fe82ce108..3a509de7f0 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -46,7 +46,7 @@ bool ParallelTensorShape::operator>=(ParallelTensorShape const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ParallelTensorShape const &x) const { + ::FlexFlow::ParallelTensorShape const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ParallelTensorDims>{}(x.dims) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -57,13 +57,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ParallelTensorShape - adl_serializer::from_json(json const &j) { - return {j.at("dims").template get<::FlexFlow::ParallelTensorDims>(), - j.at("data_type").template get<::FlexFlow::DataType>()}; +::FlexFlow::ParallelTensorShape + adl_serializer<::FlexFlow::ParallelTensorShape>::from_json(json const &j) { + return ::FlexFlow::ParallelTensorShape{ + j.at("dims").template get<::FlexFlow::ParallelTensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ParallelTensorShape const &v) { +void adl_serializer<::FlexFlow::ParallelTensorShape>::to_json( + json &j, ::FlexFlow::ParallelTensorShape const &v) { j["__type"] = "ParallelTensorShape"; j["dims"] = v.dims; j["data_type"] = v.data_type; @@ -71,9 +72,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ParallelTensorShape> + Arbitrary<::FlexFlow::ParallelTensorShape>::arbitrary() { + return gen::construct<::FlexFlow::ParallelTensorShape>( gen::arbitrary<::FlexFlow::ParallelTensorDims>(), gen::arbitrary<::FlexFlow::DataType>()); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc index 4547a5df9b..cdea7bb484 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc @@ -35,7 +35,7 @@ bool DiscardCopyDegree::operator>=(DiscardCopyDegree const &other) const { namespace std { size_t hash::operator()( - FlexFlow::DiscardCopyDegree const &x) const { + ::FlexFlow::DiscardCopyDegree const &x) const { size_t result = 0; result ^= std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,21 +44,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::DiscardCopyDegree - adl_serializer::from_json(json const &j) { - return {j.at("value").template get()}; +::FlexFlow::DiscardCopyDegree + adl_serializer<::FlexFlow::DiscardCopyDegree>::from_json(json const &j) { + return ::FlexFlow::DiscardCopyDegree{j.at("value").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::DiscardCopyDegree const &v) { +void adl_serializer<::FlexFlow::DiscardCopyDegree>::to_json( + json &j, ::FlexFlow::DiscardCopyDegree const &v) { j["__type"] = "DiscardCopyDegree"; j["value"] = v.value; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::DiscardCopyDegree> + Arbitrary<::FlexFlow::DiscardCopyDegree>::arbitrary() { + return gen::construct<::FlexFlow::DiscardCopyDegree>(gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc index cf159a1ea7..9dbc095f84 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc @@ -34,8 +34,8 @@ bool SumDegree::operator>=(SumDegree const &other) const { } // namespace FlexFlow namespace std { -size_t - hash::operator()(FlexFlow::SumDegree const &x) const { +size_t hash::operator()( + ::FlexFlow::SumDegree const &x) const { size_t result = 0; result ^= std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,20 +44,20 @@ size_t } // namespace std namespace nlohmann { -FlexFlow::SumDegree - adl_serializer::from_json(json const &j) { - return {j.at("value").template get()}; +::FlexFlow::SumDegree + adl_serializer<::FlexFlow::SumDegree>::from_json(json const &j) { + return ::FlexFlow::SumDegree{j.at("value").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::SumDegree const &v) { +void adl_serializer<::FlexFlow::SumDegree>::to_json( + json &j, ::FlexFlow::SumDegree const &v) { j["__type"] = "SumDegree"; j["value"] = v.value; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::SumDegree> Arbitrary<::FlexFlow::SumDegree>::arbitrary() { + return gen::construct<::FlexFlow::SumDegree>(gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc index a1256ad79a..ed45115c77 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc @@ -44,7 +44,7 @@ bool ReplicaParallelDim::operator>=(ReplicaParallelDim const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ReplicaParallelDim const &x) const { + ::FlexFlow::ReplicaParallelDim const &x) const { size_t result = 0; result ^= std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -55,13 +55,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReplicaParallelDim - adl_serializer::from_json(json const &j) { - return {j.at("degree").template get(), - j.at("replica_type").template get<::FlexFlow::ReplicaType>()}; +::FlexFlow::ReplicaParallelDim + adl_serializer<::FlexFlow::ReplicaParallelDim>::from_json(json const &j) { + return ::FlexFlow::ReplicaParallelDim{ + j.at("degree").template get(), + j.at("replica_type").template get<::FlexFlow::ReplicaType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReplicaParallelDim const &v) { +void adl_serializer<::FlexFlow::ReplicaParallelDim>::to_json( + json &j, ::FlexFlow::ReplicaParallelDim const &v) { j["__type"] = "ReplicaParallelDim"; j["degree"] = v.degree; j["replica_type"] = v.replica_type; @@ -69,9 +70,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ReplicaParallelDim> + Arbitrary<::FlexFlow::ReplicaParallelDim>::arbitrary() { + return gen::construct<::FlexFlow::ReplicaParallelDim>( gen::arbitrary(), gen::arbitrary<::FlexFlow::ReplicaType>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc index 7ef228e97e..20c88c77dc 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -4,7 +4,7 @@ namespace FlexFlow { ReplicaParallelDimSet empty_replica_parallel_dim_set() { - return ReplicaParallelDimSet{1, 1}; + return ReplicaParallelDimSet{SumDegree{1}, DiscardCopyDegree{1}}; } int get_order_of_replica_type(ReplicaParallelDimSet const &s, diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc index f8782be01b..1d11006523 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc @@ -52,7 +52,7 @@ bool ReplicaParallelDimSet::operator>=( namespace std { size_t hash::operator()( - FlexFlow::ReplicaParallelDimSet const &x) const { + ::FlexFlow::ReplicaParallelDimSet const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::SumDegree>{}(x.sum_degree) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -63,14 +63,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ReplicaParallelDimSet - adl_serializer::from_json(json const &j) { - return {j.at("sum_degree").template get<::FlexFlow::SumDegree>(), - j.at("discard_copy_degree") - .template get<::FlexFlow::DiscardCopyDegree>()}; +::FlexFlow::ReplicaParallelDimSet + adl_serializer<::FlexFlow::ReplicaParallelDimSet>::from_json( + json const &j) { + return ::FlexFlow::ReplicaParallelDimSet{ + j.at("sum_degree").template get<::FlexFlow::SumDegree>(), + j.at("discard_copy_degree") + .template get<::FlexFlow::DiscardCopyDegree>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ReplicaParallelDimSet const &v) { +void adl_serializer<::FlexFlow::ReplicaParallelDimSet>::to_json( + json &j, ::FlexFlow::ReplicaParallelDimSet const &v) { j["__type"] = "ReplicaParallelDimSet"; j["sum_degree"] = v.sum_degree; j["discard_copy_degree"] = v.discard_copy_degree; @@ -78,9 +80,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::ReplicaParallelDimSet> + Arbitrary<::FlexFlow::ReplicaParallelDimSet>::arbitrary() { + return gen::construct<::FlexFlow::ReplicaParallelDimSet>( gen::arbitrary<::FlexFlow::SumDegree>(), gen::arbitrary<::FlexFlow::DiscardCopyDegree>()); } diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc index 9566eb486b..fba9e1b8f7 100644 --- a/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc @@ -42,7 +42,7 @@ bool ShardParallelDim::operator>=(ShardParallelDim const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ShardParallelDim const &x) const { + ::FlexFlow::ShardParallelDim const &x) const { size_t result = 0; result ^= std::hash{}(x.size) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -53,13 +53,13 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ShardParallelDim - adl_serializer::from_json(json const &j) { - return {j.at("size").template get(), - j.at("degree").template get()}; +::FlexFlow::ShardParallelDim + adl_serializer<::FlexFlow::ShardParallelDim>::from_json(json const &j) { + return ::FlexFlow::ShardParallelDim{j.at("size").template get(), + j.at("degree").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ShardParallelDim const &v) { +void adl_serializer<::FlexFlow::ShardParallelDim>::to_json( + json &j, ::FlexFlow::ShardParallelDim const &v) { j["__type"] = "ShardParallelDim"; j["size"] = v.size; j["degree"] = v.degree; @@ -67,10 +67,10 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary()); +Gen<::FlexFlow::ShardParallelDim> + Arbitrary<::FlexFlow::ShardParallelDim>::arbitrary() { + return gen::construct<::FlexFlow::ShardParallelDim>(gen::arbitrary(), + gen::arbitrary()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index ed40f509d9..de9c3d4adb 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -26,7 +26,8 @@ size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { ParallelTensorDims lift_to_parallel(TensorDims const &dims) { std::vector shard_degrees(num_dims(dims), 1); // 1 repeated num_dims(dims) times - return lift_to_parallel_with_degrees(dims, 1, 1, shard_degrees); + return lift_to_parallel_with_degrees( + dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees); } ParallelTensorDims diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc index 909be323ac..ab78d44805 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc @@ -37,7 +37,7 @@ bool TensorDims::operator>=(TensorDims const &other) const { namespace std { size_t hash::operator()( - FlexFlow::TensorDims const &x) const { + ::FlexFlow::TensorDims const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::FFOrdered>{}(x.ff_ordered) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -46,20 +46,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorDims - adl_serializer::from_json(json const &j) { - return {j.at("ff_ordered").template get<::FlexFlow::FFOrdered>()}; +::FlexFlow::TensorDims + adl_serializer<::FlexFlow::TensorDims>::from_json(json const &j) { + return ::FlexFlow::TensorDims{ + j.at("ff_ordered").template get<::FlexFlow::FFOrdered>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorDims const &v) { +void adl_serializer<::FlexFlow::TensorDims>::to_json( + json &j, ::FlexFlow::TensorDims const &v) { j["__type"] = "TensorDims"; j["ff_ordered"] = v.ff_ordered; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::TensorDims> Arbitrary<::FlexFlow::TensorDims>::arbitrary() { + return gen::construct<::FlexFlow::TensorDims>( gen::arbitrary<::FlexFlow::FFOrdered>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc index 92b31930fa..0c725dc443 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -45,7 +45,7 @@ bool TensorShape::operator>=(TensorShape const &other) const { namespace std { size_t hash::operator()( - FlexFlow::TensorShape const &x) const { + ::FlexFlow::TensorShape const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::TensorDims>{}(x.dims) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -56,13 +56,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorShape - adl_serializer::from_json(json const &j) { - return {j.at("dims").template get<::FlexFlow::TensorDims>(), - j.at("data_type").template get<::FlexFlow::DataType>()}; +::FlexFlow::TensorShape + adl_serializer<::FlexFlow::TensorShape>::from_json(json const &j) { + return ::FlexFlow::TensorShape{ + j.at("dims").template get<::FlexFlow::TensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorShape const &v) { +void adl_serializer<::FlexFlow::TensorShape>::to_json( + json &j, ::FlexFlow::TensorShape const &v) { j["__type"] = "TensorShape"; j["dims"] = v.dims; j["data_type"] = v.data_type; @@ -70,8 +71,8 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::TensorShape> Arbitrary<::FlexFlow::TensorShape>::arbitrary() { + return gen::construct<::FlexFlow::TensorShape>( gen::arbitrary<::FlexFlow::TensorDims>(), gen::arbitrary<::FlexFlow::DataType>()); } diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/datatype.cc index 90aa0c20f6..90dd64a352 100644 --- a/lib/op-attrs/test/src/datatype.cc +++ b/lib/op-attrs/test/src/datatype.cc @@ -1,11 +1,14 @@ -#include "test/utils/doctest.h" #include "op-attrs/datatype.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("can_promote_datatype_from_to(DataType, DataType)") { - CHECK(can_strictly_promote_datatype_from_to(DataType::BOOL, DataType::INT32)); - CHECK(can_strictly_promote_datatype_from_to(DataType::INT32, DataType::INT64)); - CHECK(can_strictly_promote_datatype_from_to(DataType::FLOAT, DataType::DOUBLE)); + CHECK( + can_strictly_promote_datatype_from_to(DataType::BOOL, DataType::INT32)); + CHECK(can_strictly_promote_datatype_from_to(DataType::INT32, + DataType::INT64)); + CHECK(can_strictly_promote_datatype_from_to(DataType::FLOAT, + DataType::DOUBLE)); SUBCASE("is strict") { rc::check([](DataType d) { diff --git a/lib/op-attrs/test/src/test_attention.cc b/lib/op-attrs/test/src/ops/attention.cc similarity index 78% rename from lib/op-attrs/test/src/test_attention.cc rename to lib/op-attrs/test/src/ops/attention.cc index 74ae4565ca..2c7121e4a8 100644 --- a/lib/op-attrs/test/src/test_attention.cc +++ b/lib/op-attrs/test/src/ops/attention.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { /* Parameter meanings match those at * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html */ - MultiHeadAttentionAttrs attrs = { + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ /*embed_dim=*/embed_dim, /*num_heads=*/10, /*kdim=*/embed_dim, @@ -25,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { size_t batch_size = 40; size_t seq_len = 48; - TensorShape input_q = { + TensorShape input_q = TensorShape{ TensorDims{FFOrdered{ batch_size, seq_len, @@ -34,7 +34,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape input_k = { + TensorShape input_k = TensorShape{ TensorDims{ FFOrdered{ batch_size, @@ -45,7 +45,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape input_v = { + TensorShape input_v = TensorShape{ TensorDims{ FFOrdered{ batch_size, @@ -104,7 +104,7 @@ TEST_SUITE(FF_TEST_SUITE) { /* Parameter meanings can be found at * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html */ - MultiHeadAttentionAttrs attrs = { + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ /*embed_dim=*/embed_dim, /*num_heads=*/10, /*kdim=*/embed_dim, @@ -173,8 +173,8 @@ TEST_SUITE(FF_TEST_SUITE) { unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); }; - auto make_k = [&](int o_sum, - int o_eq, + auto make_k = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, int o_batch, int o_seq_len, int o_k) { @@ -182,8 +182,8 @@ TEST_SUITE(FF_TEST_SUITE) { unpar_k_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); }; - auto make_v = [&](int o_sum, - int o_eq, + auto make_v = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, int o_batch, int o_seq_len, int o_v) { @@ -191,8 +191,8 @@ TEST_SUITE(FF_TEST_SUITE) { unpar_v_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); }; - auto make_o = [&](int o_sum, - int o_eq, + auto make_o = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, int o_batch, int o_seq_len, int o_o) { @@ -200,49 +200,56 @@ TEST_SUITE(FF_TEST_SUITE) { unpar_o_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); }; - auto make_w = [&](int o_sum, int o_eq, int o_e, int o_h) { - return lift_to_parallel_with_degrees( - unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); - }; + auto make_w = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_e, int o_h) { + return lift_to_parallel_with_degrees( + unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); + }; SUBCASE("data parallelism") { int o_b = 4; - ParallelTensorShape q = make_q(1, 1, o_b, 1, 1); - ParallelTensorShape k = make_k(1, 1, o_b, 1, 1); - ParallelTensorShape v = make_v(1, 1, o_b, 1, 1); + ParallelTensorShape q = + make_q(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); + ParallelTensorShape k = + make_k(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); + ParallelTensorShape v = + make_v(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); tl::expected result_o = get_output_shape(attrs, q, k, v); tl::expected correct_o = - make_o(1, 1, o_b, 1, 1); + make_o(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); CHECK(result_o == correct_o); tl::expected result_w = get_weights_shape(attrs, q, k, v); tl::expected correct_w = - make_w(1, o_b, 1, 1); + make_w(SumDegree{1}, DiscardCopyDegree{o_b}, 1, 1); CHECK(result_w == correct_w); } SUBCASE("attention head parallelism") { int o_h = 2; - ParallelTensorShape q = make_q(1, o_h, 1, 1, 1); - ParallelTensorShape k = make_k(1, o_h, 1, 1, 1); - ParallelTensorShape v = make_v(1, o_h, 1, 1, 1); + ParallelTensorShape q = + make_q(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); + ParallelTensorShape k = + make_k(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); + ParallelTensorShape v = + make_v(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); tl::expected result_o = get_output_shape(attrs, q, k, v); tl::expected correct_o = - make_o(o_h, 1, 1, 1, 1); + make_o(SumDegree{o_h}, DiscardCopyDegree{1}, 1, 1, 1); CHECK(result_o == correct_o); tl::expected result_w = get_weights_shape(attrs, q, k, v); tl::expected correct_w = - make_w(1, 1, 1, o_h); + make_w(SumDegree{1}, DiscardCopyDegree{1}, 1, o_h); CHECK(result_w == correct_w); } @@ -250,21 +257,24 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("combined data & attention head parallelism") { int o_b = 4; int o_h = 2; - ParallelTensorShape q = make_q(1, o_h, o_b, 1, 1); - ParallelTensorShape k = make_k(1, o_h, o_b, 1, 1); - ParallelTensorShape v = make_v(1, o_h, o_b, 1, 1); + ParallelTensorShape q = + make_q(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); + ParallelTensorShape k = + make_k(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); + ParallelTensorShape v = + make_v(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); tl::expected result_o = get_output_shape(attrs, q, k, v); tl::expected correct_o = - make_o(o_h, 1, o_b, 1, 1); + make_o(SumDegree{o_h}, DiscardCopyDegree{1}, o_b, 1, 1); CHECK(result_o == correct_o); tl::expected result_w = get_weights_shape(attrs, q, k, v); tl::expected correct_w = - make_w(1, o_b, 1, o_h); + make_w(SumDegree{1}, DiscardCopyDegree{o_b}, 1, o_h); CHECK(result_w == correct_w); } diff --git a/lib/op-attrs/test/src/test_batch_matmul.cc b/lib/op-attrs/test/src/ops/batch_matmul.cc similarity index 59% rename from lib/op-attrs/test/src/test_batch_matmul.cc rename to lib/op-attrs/test/src/ops/batch_matmul.cc index f48478be10..3ff02ccece 100644 --- a/lib/op-attrs/test/src/test_batch_matmul.cc +++ b/lib/op-attrs/test/src/ops/batch_matmul.cc @@ -8,13 +8,13 @@ TEST_SUITE(FF_TEST_SUITE) { size_t n = 8; size_t p = 10; - BatchMatmulAttrs attrs = { + BatchMatmulAttrs attrs = BatchMatmulAttrs{ /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still // relevant /*b_seq_length_dim=*/0, }; - TensorShape input_lhs_shape = { + TensorShape input_lhs_shape = TensorShape{ TensorDims{ FFOrdered{ b, @@ -26,7 +26,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; SUBCASE("valid") { - TensorShape input_rhs_shape = { + TensorShape input_rhs_shape = TensorShape{ TensorDims{ FFOrdered{ b, @@ -55,7 +55,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("mismatched b") { - TensorShape input_rhs_shape = { + TensorShape input_rhs_shape = TensorShape{ TensorDims{ FFOrdered{ b + 1, @@ -73,7 +73,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("mismatched m") { - TensorShape input_rhs_shape = { + TensorShape input_rhs_shape = TensorShape{ TensorDims{ FFOrdered{ b, @@ -102,13 +102,17 @@ TEST_SUITE(FF_TEST_SUITE) { int o_p = 7; int o_sum = 11; - BatchMatmulAttrs attrs = { + BatchMatmulAttrs attrs = BatchMatmulAttrs{ /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still // relevant /*b_seq_length_dim=*/0, }; - auto make_lhs = [&](int o_sum, int o_eq, int o_b, int o_n, int o_m) { + auto make_lhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_b, + int o_n, + int o_m) { return ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ @@ -125,7 +129,11 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; - auto make_rhs = [&](int o_sum, int o_eq, int o_b, int o_m, int o_p) { + auto make_rhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_b, + int o_m, + int o_p) { return ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ @@ -142,7 +150,11 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; - auto make_output = [&](int o_sum, int o_eq, int o_b, int o_n, int o_p) { + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_b, + int o_n, + int o_p) { return ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ @@ -161,106 +173,121 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("data parallel") { tl::expected result = get_output_shape( - attrs, make_lhs(1, 1, o_b, 1, 1), make_rhs(1, 1, o_b, 1, 1)); + attrs, + make_lhs(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1), + make_rhs(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1)); tl::expected correct = - make_output(1, 1, o_b, 1, 1); + make_output(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); CHECK(result == correct); } SUBCASE("n parallel") { tl::expected result = get_output_shape( - attrs, make_lhs(1, 1, 1, o_n, 1), make_rhs(1, o_n, 1, 1, 1)); + attrs, + make_lhs(SumDegree{1}, DiscardCopyDegree{1}, 1, o_n, 1), + make_rhs(SumDegree{1}, DiscardCopyDegree{o_n}, 1, 1, 1)); tl::expected correct = - make_output(1, 1, 1, o_n, 1); + make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, o_n, 1); CHECK(result == correct); } SUBCASE("p parallel") { tl::expected result = get_output_shape( - attrs, make_lhs(1, o_p, 1, 1, 1), make_rhs(1, 1, 1, 1, o_p)); + attrs, + make_lhs(SumDegree{1}, DiscardCopyDegree{o_p}, 1, 1, 1), + make_rhs(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, o_p)); tl::expected correct = - make_output(1, 1, 1, 1, o_p); + make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, o_p); CHECK(result == correct); } SUBCASE("reduction parallel") { tl::expected result = get_output_shape( - attrs, make_lhs(1, 1, 1, 1, o_m), make_rhs(1, 1, 1, o_m, 1)); + attrs, + make_lhs(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, o_m), + make_rhs(SumDegree{1}, DiscardCopyDegree{1}, 1, o_m, 1)); tl::expected correct = - make_output(o_m, 1, 1, 1, 1); + make_output(SumDegree{o_m}, DiscardCopyDegree{1}, 1, 1, 1); CHECK(result == correct); } SUBCASE("propagate reduction lhs") { tl::expected result = get_output_shape( - attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(1, o_sum, 1, 1, 1)); + attrs, + make_lhs(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, 1, 1), + make_rhs(SumDegree{1}, DiscardCopyDegree{o_sum}, 1, 1, 1)); tl::expected correct = - make_output(o_sum, 1, 1, 1, 1); + make_output(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, 1, 1); CHECK(result == correct); } SUBCASE("propagate reduction rhs") { tl::expected result = get_output_shape( - attrs, make_lhs(1, o_sum, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + attrs, + make_lhs(SumDegree{1}, DiscardCopyDegree{o_sum}, 1, 1, 1), + make_rhs(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, 1, 1)); tl::expected correct = - make_output(o_sum, 1, 1, 1, 1); + make_output(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, 1, 1); CHECK(result == correct); } SUBCASE("reduction lhs & reduction rhs") { - tl::expected result = - get_output_shape(attrs, - make_lhs(o_sum, o_sum, 1, 1, 1), - make_rhs(o_sum, o_sum, 1, 1, 1)); + tl::expected result = get_output_shape( + attrs, + make_lhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1, 1, 1), + make_rhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1, 1, 1)); tl::expected correct = - make_output(o_sum * o_sum, 1, 1, 1, 1); + make_output(SumDegree{o_sum * o_sum}, DiscardCopyDegree{1}, 1, 1, 1); CHECK(result == correct); } SUBCASE("reduction lhs & rhs (invalid)") { tl::expected result = get_output_shape( - attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + attrs, + make_lhs(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, 1, 1), + make_rhs(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, 1, 1)); CHECK_MESSAGE( !result.has_value(), "Unexpected successful value: ", result); } SUBCASE("reduction lhs & n") { - tl::expected result = - get_output_shape(attrs, - make_lhs(o_sum, 1, 1, o_n, 1), - make_rhs(1, o_sum * o_n, 1, 1, 1)); + tl::expected result = get_output_shape( + attrs, + make_lhs(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, o_n, 1), + make_rhs(SumDegree{1}, DiscardCopyDegree{o_sum * o_n}, 1, 1, 1)); tl::expected correct = - make_output(o_sum, 1, 1, o_n, 1); + make_output(SumDegree{o_sum}, DiscardCopyDegree{1}, 1, o_n, 1); CHECK(result == correct); } SUBCASE("reduction lhs & reduction rhs & n") { - tl::expected result = - get_output_shape(attrs, - make_lhs(o_sum, o_sum, 1, o_n, 1), - make_rhs(o_sum, o_sum * o_n, 1, 1, 1)); - tl::expected correct = - make_output(o_sum * o_sum, 1, 1, o_n, 1); + tl::expected result = get_output_shape( + attrs, + make_lhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1, o_n, 1), + make_rhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum * o_n}, 1, 1, 1)); + tl::expected correct = make_output( + SumDegree{o_sum * o_sum}, DiscardCopyDegree{1}, 1, o_n, 1); CHECK(result == correct); } SUBCASE("reduction lhs & reduction rhs & n & m") { - tl::expected result = - get_output_shape(attrs, - make_lhs(o_sum, o_sum, 1, o_n, o_m), - make_rhs(o_sum, o_sum * o_n, 1, o_m, 1)); - tl::expected correct = - make_output(o_sum * o_sum * o_m, 1, 1, o_n, 1); + tl::expected result = get_output_shape( + attrs, + make_lhs(SumDegree{o_sum}, DiscardCopyDegree{o_sum}, 1, o_n, o_m), + make_rhs( + SumDegree{o_sum}, DiscardCopyDegree{o_sum * o_n}, 1, o_m, 1)); + tl::expected correct = make_output( + SumDegree{o_sum * o_sum * o_m}, DiscardCopyDegree{1}, 1, o_n, 1); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc index a50b3b01de..ac18bbc798 100644 --- a/lib/op-attrs/test/src/ops/combine.cc +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -4,7 +4,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Combine shape inference") { - ParallelTensorShape input = { + ParallelTensorShape input = ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ ShardParallelDim{12, 2}, @@ -21,7 +21,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; SUBCASE("valid") { - ff_dim_t dim = 2; + ff_dim_t dim = ff_dim_t{2}; int degree = 3; CombineAttrs attrs = CombineAttrs{ /*repartition_dim=*/dim, @@ -41,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("invalid") { - ff_dim_t dim = 2; + ff_dim_t dim = ff_dim_t{2}; int degree = 4; CombineAttrs attrs = CombineAttrs{ /*repartition_dim=*/dim, diff --git a/lib/op-attrs/test/src/test_conv_2d.cc b/lib/op-attrs/test/src/ops/conv_2d.cc similarity index 67% rename from lib/op-attrs/test/src/test_conv_2d.cc rename to lib/op-attrs/test/src/ops/conv_2d.cc index 85d95b42cb..6f5028cfeb 100644 --- a/lib/op-attrs/test/src/test_conv_2d.cc +++ b/lib/op-attrs/test/src/ops/conv_2d.cc @@ -1,5 +1,5 @@ -#include "doctest/doctest.h" #include "op-attrs/ops/conv_2d.h" +#include "doctest/doctest.h" #include "utils/integer_conversions.h" TEST_SUITE(FF_TEST_SUITE) { @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional activation = std::nullopt; bool use_bias = true; - Conv2DAttrs attrs = { + Conv2DAttrs attrs = Conv2DAttrs{ /*out_channels=*/out_channels, /*kernel_h=*/kernel_h, /*kernel_w=*/kernel_w, @@ -33,7 +33,7 @@ TEST_SUITE(FF_TEST_SUITE) { size_t input_height = 10; size_t input_width = 15; - TensorShape input = { + TensorShape input = TensorShape{ TensorDims{FFOrdered{ num_samples, input_channels, @@ -46,7 +46,7 @@ TEST_SUITE(FF_TEST_SUITE) { size_t output_height = 3; size_t output_width = 6; - TensorShape output = { + TensorShape output = TensorShape{ TensorDims{FFOrdered{ num_samples, size_t_from_int(out_channels), @@ -56,19 +56,19 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape kernel = { + TensorShape kernel = TensorShape{ TensorDims{FFOrdered{ - size_t_from_int(out_channels), - input_channels, - size_t_from_int(kernel_h), - size_t_from_int(kernel_w), + size_t_from_int(out_channels), + input_channels, + size_t_from_int(kernel_h), + size_t_from_int(kernel_w), }}, DataType::FLOAT, }; - TensorShape bias = { + TensorShape bias = TensorShape{ TensorDims{FFOrdered{ - size_t_from_int(out_channels), + size_t_from_int(out_channels), }}, DataType::FLOAT, }; @@ -118,104 +118,122 @@ TEST_SUITE(FF_TEST_SUITE) { int o_kernel_h, int o_kernel_w) { return lift_to_parallel_with_degrees( - kernel, o_sum, o_eq, FFOrdered{o_outchannels, o_inchannels, o_kernel_h, o_kernel_w}); + kernel, + o_sum, + o_eq, + FFOrdered{o_outchannels, o_inchannels, o_kernel_h, o_kernel_w}); }; - auto make_bias = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_outchannels) { - return lift_to_parallel_with_degrees( - bias, o_sum, o_eq, FFOrdered{o_outchannels}); - }; + auto make_bias = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_outchannels) { + return lift_to_parallel_with_degrees( + bias, o_sum, o_eq, FFOrdered{o_outchannels}); + }; SUBCASE("data parallelism") { int degree = 2; - ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); SUBCASE("get_output_shape") { ParallelTensorShape result = get_output_shape(attrs, par_input); - ParallelTensorShape correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); + ParallelTensorShape correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); CHECK(result == correct); } SUBCASE("get_kernel_shape") { ParallelTensorShape result = get_kernel_shape(attrs, par_input); - ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); + ParallelTensorShape correct = + make_kernel(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("get_bias_shape") { ParallelTensorShape result = get_bias_shape(attrs, par_input); - ParallelTensorShape correct = make_bias(SumDegree{1}, DiscardCopyDegree{degree}, 1); + ParallelTensorShape correct = + make_bias(SumDegree{1}, DiscardCopyDegree{degree}, 1); CHECK(result == correct); } } SUBCASE("input channel parallelism") { int degree = 2; - ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); - + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); + SUBCASE("get_output_shape") { ParallelTensorShape result = get_output_shape(attrs, par_input); - ParallelTensorShape correct = make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); + ParallelTensorShape correct = + make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("get_kernel_shape") { ParallelTensorShape result = get_kernel_shape(attrs, par_input); - ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); + ParallelTensorShape correct = + make_kernel(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); CHECK(result == correct); } SUBCASE("get_bias_shape") { ParallelTensorShape result = get_bias_shape(attrs, par_input); - ParallelTensorShape correct = make_bias(SumDegree{degree}, DiscardCopyDegree{1}, 1); + ParallelTensorShape correct = + make_bias(SumDegree{degree}, DiscardCopyDegree{1}, 1); CHECK(result == correct); } } SUBCASE("output channel parallelism") { int degree = 2; - ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); - + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); + SUBCASE("get_output_shape") { ParallelTensorShape result = get_output_shape(attrs, par_input); - ParallelTensorShape correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); + ParallelTensorShape correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1, 1); CHECK(result == correct); } SUBCASE("get_kernel_shape") { ParallelTensorShape result = get_kernel_shape(attrs, par_input); - ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); + ParallelTensorShape correct = + make_kernel(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1, 1); CHECK(result == correct); } SUBCASE("get_bias_shape") { ParallelTensorShape result = get_bias_shape(attrs, par_input); - ParallelTensorShape correct = make_bias(SumDegree{1}, DiscardCopyDegree{1}, degree); + ParallelTensorShape correct = + make_bias(SumDegree{1}, DiscardCopyDegree{1}, degree); CHECK(result == correct); } } SUBCASE("propagating sum degree") { int degree = 2; - ParallelTensorShape par_input = make_input(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); - + ParallelTensorShape par_input = + make_input(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); + SUBCASE("get_output_shape") { ParallelTensorShape result = get_output_shape(attrs, par_input); - ParallelTensorShape correct = make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); + ParallelTensorShape correct = + make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("get_kernel_shape") { ParallelTensorShape result = get_kernel_shape(attrs, par_input); - ParallelTensorShape correct = make_kernel(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); + ParallelTensorShape correct = + make_kernel(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("get_bias_shape") { ParallelTensorShape result = get_bias_shape(attrs, par_input); - ParallelTensorShape correct = make_bias(SumDegree{degree}, DiscardCopyDegree{1}, 1); + ParallelTensorShape correct = + make_bias(SumDegree{degree}, DiscardCopyDegree{1}, 1); CHECK(result == correct); } } diff --git a/lib/op-attrs/test/src/test_element_binary.cc b/lib/op-attrs/test/src/ops/element_binary.cc similarity index 82% rename from lib/op-attrs/test/src/test_element_binary.cc rename to lib/op-attrs/test/src/ops/element_binary.cc index b1aedbf6b5..0ed695eb89 100644 --- a/lib/op-attrs/test/src/test_element_binary.cc +++ b/lib/op-attrs/test/src/ops/element_binary.cc @@ -108,12 +108,14 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("data parallelism") { int degree = 4; - ParallelTensorShape input_lhs = make_lhs(1, 1, degree, 1, 1); - ParallelTensorShape input_rhs = make_rhs(1, 1, degree, 1, 1); + ParallelTensorShape input_lhs = + make_lhs(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1); + ParallelTensorShape input_rhs = + make_rhs(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1); tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); tl::expected correct = - make_output(1, 1, degree, 1, 1); + make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1, 1); CHECK(result == correct); } @@ -121,12 +123,14 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("reduction parallelism") { int degree = 4; - ParallelTensorShape input_lhs = make_lhs(SumDegree{degree}, 1, 1, 1, 1); - ParallelTensorShape input_rhs = make_rhs(SumDegree{degree}, 1, 1, 1, 1); + ParallelTensorShape input_lhs = + make_lhs(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1); + ParallelTensorShape input_rhs = + make_rhs(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1); tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); tl::expected correct = - make_output(SumDegree{degree}, 1, 1, 1, 1); + make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1); CHECK(result == correct); } @@ -135,9 +139,9 @@ TEST_SUITE(FF_TEST_SUITE) { int degree = 4; ParallelTensorShape input_lhs = - make_lhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + make_lhs(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1); ParallelTensorShape input_rhs = - make_rhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + make_rhs(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1); tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); @@ -149,8 +153,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("invalid mismatched parallelism degrees") { int degree = 4; - ParallelTensorShape input_lhs = make_lhs(1, 1, 1, degree, 1); - ParallelTensorShape input_rhs = make_rhs(1, 1, 1, 1, degree); + ParallelTensorShape input_lhs = + make_lhs(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1); + ParallelTensorShape input_rhs = + make_rhs(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, degree); tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); diff --git a/lib/op-attrs/test/src/test_element_unary.cc b/lib/op-attrs/test/src/ops/element_unary.cc similarity index 83% rename from lib/op-attrs/test/src/test_element_unary.cc rename to lib/op-attrs/test/src/ops/element_unary.cc index 384dbc1a53..bf94e55235 100644 --- a/lib/op-attrs/test/src/test_element_unary.cc +++ b/lib/op-attrs/test/src/ops/element_unary.cc @@ -39,7 +39,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("partition i.e., sharding parallelism") { int degree1 = 4; int degree2 = 8; - ParallelTensorShape par_input = make_i(1, 1, degree1, 1, degree2); + ParallelTensorShape par_input = + make_i(SumDegree{1}, DiscardCopyDegree{1}, degree1, 1, degree2); tl::expected result = get_output_shape(attrs, par_input); @@ -51,8 +52,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("sum degree > 1") { int degree = 2; - tl::expected result = - get_output_shape(attrs, make_i(SumDegree{degree}, 1, 1, 1, 1)); + tl::expected result = get_output_shape( + attrs, make_i(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1, 1)); CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", @@ -63,7 +64,7 @@ TEST_SUITE(FF_TEST_SUITE) { int degree = 2; tl::expected result = get_output_shape( - attrs, make_i(1, DiscardCopyDegree{degree}, 1, 1, 1)); + attrs, make_i(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1, 1)); CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", diff --git a/lib/op-attrs/test/src/test_embedding.cc b/lib/op-attrs/test/src/ops/embedding.cc similarity index 99% rename from lib/op-attrs/test/src/test_embedding.cc rename to lib/op-attrs/test/src/ops/embedding.cc index 7bce6bd4d9..9180f7055d 100644 --- a/lib/op-attrs/test/src/test_embedding.cc +++ b/lib/op-attrs/test/src/ops/embedding.cc @@ -17,7 +17,7 @@ TEST_SUITE(FF_TEST_SUITE) { size_t batch_size = 48; size_t features_dim = 56; - TensorShape input = { + TensorShape input = TensorShape{ TensorDims{FFOrdered{ batch_size, features_dim, diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc index 6f73951e00..59ed5bb5ee 100644 --- a/lib/op-attrs/test/src/ops/reduction.cc +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -4,7 +4,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Reduction shape inference") { - ParallelTensorShape input = { + ParallelTensorShape input = ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ ShardParallelDim{12, 2}, diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc index 3b3ae92b4c..af28a6d471 100644 --- a/lib/op-attrs/test/src/ops/repartition.cc +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -3,14 +3,14 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { - ff_dim_t dim = 2; + ff_dim_t dim = ff_dim_t{2}; int degree = 4; RepartitionAttrs attrs = RepartitionAttrs{ /*repartition_dim=*/dim, /*repartition_degree=*/degree, }; - ParallelTensorShape input = { + ParallelTensorShape input = ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ ShardParallelDim{12, 2}, diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc index b326038388..a0ec40cc14 100644 --- a/lib/op-attrs/test/src/ops/replicate.cc +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -7,7 +7,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*replicate_degree=*/4, }; - ParallelTensorShape input = { + ParallelTensorShape input = ParallelTensorShape{ ParallelTensorDims{ FFOrdered{ ShardParallelDim{10, 2}, @@ -26,7 +26,7 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape result = get_output_shape(attrs, input); ParallelTensorShape correct_output = input; - correct_output.dims.replica_dims.discard_copy_degree = 8; + correct_output.dims.replica_dims.discard_copy_degree = DiscardCopyDegree{8}; CHECK(result == correct_output); } diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h index 217b940ce6..028d9ecfab 100644 --- a/lib/pcg/include/pcg/computation_graph.dtg.h +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -3,21 +3,21 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "8f1f0e13d75065944f7fe307e12fe280" + "generated_from": "bf8996bea2e022265a372d692c2db8ed" } */ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H -#include "pcg/dataflow_graph.h" +#include "pcg/dataflow_graph/dataflow_graph.h" #include "pcg/layer_attrs.dtg.h" #include "pcg/tensor_attrs.dtg.h" namespace FlexFlow { struct ComputationGraph { ComputationGraph() = delete; - ComputationGraph( + explicit ComputationGraph( ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs> const &raw_graph); diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml index a270cb8fbe..39c68b8e4f 100644 --- a/lib/pcg/include/pcg/computation_graph.struct.toml +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -5,7 +5,7 @@ features = [ ] includes = [ "pcg/layer_attrs.dtg.h", "pcg/tensor_attrs.dtg.h", - "pcg/dataflow_graph.h", + "pcg/dataflow_graph/dataflow_graph.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h index 4fd78f2d44..5013871cc8 100644 --- a/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml /* proj-data { - "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" + "generated_from": "234b5c222ae4ce1da36194b4eb519145" } */ @@ -13,14 +13,16 @@ #include "fmt/format.h" #include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_guid_t.dtg.h" +#include "utils/fmt/vector.h" #include #include namespace FlexFlow { struct LayerAddedResult { LayerAddedResult() = delete; - LayerAddedResult(::FlexFlow::layer_guid_t const &layer, - std::vector<::FlexFlow::tensor_guid_t> const &outputs); + explicit LayerAddedResult( + ::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs); bool operator==(LayerAddedResult const &) const; bool operator!=(LayerAddedResult const &) const; diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml index b02e992ba1..d7b669fb3a 100644 --- a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml @@ -8,6 +8,7 @@ features = [ includes = [ "pcg/layer_guid_t.dtg.h", "pcg/tensor_guid_t.dtg.h", + "utils/fmt/vector.h" ] [[fields]] diff --git a/lib/pcg/include/pcg/cpu_id_t.dtg.h b/lib/pcg/include/pcg/cpu_id_t.dtg.h index a6c81e80b0..b5c5bdd22f 100644 --- a/lib/pcg/include/pcg/cpu_id_t.dtg.h +++ b/lib/pcg/include/pcg/cpu_id_t.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct cpu_id_t { cpu_id_t() = delete; - cpu_id_t(int const &cpu_index); + explicit cpu_id_t(int const &cpu_index); bool operator==(cpu_id_t const &) const; bool operator!=(cpu_id_t const &) const; @@ -34,23 +34,23 @@ struct cpu_id_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::cpu_id_t const &) const; +struct hash<::FlexFlow::cpu_id_t> { + size_t operator()(::FlexFlow::cpu_id_t const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::cpu_id_t from_json(json const &); - static void to_json(json &, FlexFlow::cpu_id_t const &); +struct adl_serializer<::FlexFlow::cpu_id_t> { + static ::FlexFlow::cpu_id_t from_json(json const &); + static void to_json(json &, ::FlexFlow::cpu_id_t const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::cpu_id_t> { + static Gen<::FlexFlow::cpu_id_t> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph.h deleted file mode 100644 index f649c0444c..0000000000 --- a/lib/pcg/include/pcg/dataflow_graph.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H - -#include "utils/containers/enumerate_vector.h" -#include "utils/graph.h" - -namespace FlexFlow { - -template -struct DataflowGraph { -public: - DataflowGraph() - : g(OutputLabelledMultiDiGraph::template create< - UnorderedOutputLabelledMultiDiGraph>()) {} - - std::vector - add_operator(NodeLabel const &func, - std::vector const &inputs, - std::vector const &outputs) { - Node n = this->g.add_node(func); - for (auto const &[idx, input] : enumerate_vector(inputs)) { - this->g.add_edge(MultiDiEdge{ - input.src, input.src_idx, n, this->make_port_for_idx(idx)}); - } - - std::vector result; - for (auto const &[idx, label] : enumerate_vector(outputs)) { - MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; - this->g.add_output(output, label); - result.push_back(output); - } - - return result; - } - - NodePort make_port_for_idx(int idx) { - if (!this->port_mapping.contains_l(idx)) { - this->port_mapping.equate(idx, this->g.add_node_port()); - } - return this->port_mapping.at_l(idx); - } - - NodePort port_for_idx(int idx) const { - return this->port_mapping.at_l(idx); - } - - int idx_for_port(NodePort const &p) const { - return this->port_mapping.at_r(p); - } - - OutputLabelledMultiDiGraphView const & - get_raw_graph() const { - return this->g; - } - - NodeLabel const &at(Node const &n) const { - return this->g.at(n); - } - - OutputLabel const &at(MultiDiOutput const &o) const { - return this->g.at(o); - } - -private: - OutputLabelledMultiDiGraph g; - bidict port_mapping; -}; - -template -std::unordered_set - get_nodes(DataflowGraph const &g) { - return get_nodes(g.get_raw_graph()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h new file mode 100644 index 0000000000..e90acf533d --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h @@ -0,0 +1,123 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H + +#include "pcg/dataflow_graph/operator_added_result.dtg.h" +#include "utils/containers/enumerate_vector.h" +#include "utils/graph.h" + +namespace FlexFlow { + +template +struct DataflowGraph { +public: + DataflowGraph() + : g(OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>()) {} + + OperatorAddedResult + add_operator(NodeLabel const &func, + std::vector const &inputs, + std::vector const &output_labels) { + Node node = this->g.add_node(func); + for (auto const &[idx, input] : enumerate_vector(inputs)) { + this->g.add_edge(MultiDiEdge{ + node, this->make_port_for_idx(idx), input.src, input.src_idx}); + } + + std::vector outputs; + for (auto const &[idx, label] : enumerate_vector(output_labels)) { + MultiDiOutput output = MultiDiOutput{node, this->make_port_for_idx(idx)}; + this->g.add_output(output, label); + outputs.push_back(output); + } + this->output_map[node] = outputs; + + return OperatorAddedResult{ + node, + outputs, + }; + } + + NodePort make_port_for_idx(int idx) { + if (!this->port_mapping.contains_l(idx)) { + this->port_mapping.equate(idx, this->g.add_node_port()); + } + return this->port_mapping.at_l(idx); + } + + NodePort port_for_idx(int idx) const { + return this->port_mapping.at_l(idx); + } + + int idx_for_port(NodePort const &p) const { + return this->port_mapping.at_r(p); + } + + OutputLabelledMultiDiGraphView const & + get_raw_graph() const { + return this->g; + } + + NodeLabel const &at(Node const &n) const { + return this->g.at(n); + } + + OutputLabel const &at(MultiDiOutput const &o) const { + return this->g.at(o); + } + + std::unordered_map> const &get_output_map() const { + return this->output_map; + } +private: + OutputLabelledMultiDiGraph g; + bidict port_mapping; + std::unordered_map> output_map; // NOTE(@lockshaw): temporary workaround until not tracking outputs + // independent of edges in multidigraph is resolved +}; + +template +std::unordered_set + get_nodes(DataflowGraph const &g) { + return get_nodes(g.get_raw_graph()); +} + +template +std::vector + vector_from_indexed_set(std::vector> const &s) { + std::vector> result{s.size(), std::nullopt}; + for (auto const &[idx, value] : s) { + assert(idx < s.size() && idx >= 0); + assert(!result.at(idx).has_value()); + result.at(idx) = value; + } + return transform(result, [](std::optional const &v) { + assert(v.has_value()); + return v.value(); + }); +} + +template +std::vector + get_inputs(DataflowGraph const &g, Node const &n) { + std::vector> input_edges = + transform(as_vector(get_incoming_edges(g.get_raw_graph(), + std::unordered_set{n})), + [&](MultiDiEdge const &e) { + int idx = g.idx_for_port(e.dst_idx); + MultiDiOutput val = static_cast(e); + return std::make_pair(idx, val); + }); + + return vector_from_indexed_set(input_edges); +} + +template +std::vector + get_outputs(DataflowGraph const &g, Node const &n) { + return g.get_output_map().at(n); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h new file mode 100644 index 0000000000..9e9803b8a0 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h @@ -0,0 +1,43 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml +/* proj-data +{ + "generated_from": "62224733c501773b41f1fc63a8677949" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_OPERATOR_ADDED_RESULT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_OPERATOR_ADDED_RESULT_DTG_H + +#include "fmt/format.h" +#include "utils/fmt/vector.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAddedResult { + OperatorAddedResult() = delete; + explicit OperatorAddedResult( + ::FlexFlow::Node const &node, + std::vector<::FlexFlow::MultiDiOutput> const &outputs); + + bool operator==(OperatorAddedResult const &) const; + bool operator!=(OperatorAddedResult const &) const; + bool operator<(OperatorAddedResult const &) const; + bool operator>(OperatorAddedResult const &) const; + bool operator<=(OperatorAddedResult const &) const; + bool operator>=(OperatorAddedResult const &) const; + ::FlexFlow::Node node; + std::vector<::FlexFlow::MultiDiOutput> outputs; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(OperatorAddedResult const &); +std::ostream &operator<<(std::ostream &, OperatorAddedResult const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_OPERATOR_ADDED_RESULT_DTG_H diff --git a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml new file mode 100644 index 0000000000..3c9cb87e85 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OperatorAddedResult" + +features = [ + "eq", + "ord", + "fmt", +] + +includes = [ + "", + "utils/graph.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::MultiDiOutput>" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index dad73ce142..6090d60e1a 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -4,9 +4,9 @@ #include "pcg/computation_graph.dtg.h" #include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" #include "pcg/layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph.dtg.h" -#include "pcg/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "pcg/tensor_attrs.dtg.h" #include "utils/json.h" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h index e9238301d0..3243cca010 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h @@ -19,10 +19,10 @@ namespace FlexFlow { struct V1GraphEdge { V1GraphEdge() = delete; - V1GraphEdge(size_t const &srcNode, - size_t const &srcIdx, - size_t const &dstNode, - size_t const &dstIdx); + explicit V1GraphEdge(size_t const &srcNode, + size_t const &srcIdx, + size_t const &dstNode, + size_t const &dstIdx); bool operator==(V1GraphEdge const &) const; bool operator!=(V1GraphEdge const &) const; @@ -39,16 +39,16 @@ struct V1GraphEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::V1GraphEdge const &) const; +struct hash<::FlexFlow::V1GraphEdge> { + size_t operator()(::FlexFlow::V1GraphEdge const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::V1GraphEdge from_json(json const &); - static void to_json(json &, FlexFlow::V1GraphEdge const &); +struct adl_serializer<::FlexFlow::V1GraphEdge> { + static ::FlexFlow::V1GraphEdge from_json(json const &); + static void to_json(json &, ::FlexFlow::V1GraphEdge const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h index 730282bdb9..eb9c013b36 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct V1GraphOutput { V1GraphOutput() = delete; - V1GraphOutput(size_t const &srcNode, size_t const &srcIdx); + explicit V1GraphOutput(size_t const &srcNode, size_t const &srcIdx); bool operator==(V1GraphOutput const &) const; bool operator!=(V1GraphOutput const &) const; @@ -34,16 +34,16 @@ struct V1GraphOutput { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::V1GraphOutput const &) const; +struct hash<::FlexFlow::V1GraphOutput> { + size_t operator()(::FlexFlow::V1GraphOutput const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::V1GraphOutput from_json(json const &); - static void to_json(json &, FlexFlow::V1GraphOutput const &); +struct adl_serializer<::FlexFlow::V1GraphOutput> { + static ::FlexFlow::V1GraphOutput from_json(json const &); + static void to_json(json &, ::FlexFlow::V1GraphOutput const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h index f183a14a9e..c6ffb55e3b 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { template struct V1JsonableGraph { V1JsonableGraph() = delete; - V1JsonableGraph( + explicit V1JsonableGraph( std::unordered_map const &node_labels, std::unordered_map const &outputs, std::unordered_map const &output_labels, @@ -37,10 +37,10 @@ struct V1JsonableGraph { namespace nlohmann { template -struct adl_serializer> { - static FlexFlow::V1JsonableGraph from_json(json const &); +struct adl_serializer<::FlexFlow::V1JsonableGraph> { + static ::FlexFlow::V1JsonableGraph from_json(json const &); static void to_json(json &, - FlexFlow::V1JsonableGraph const &); + ::FlexFlow::V1JsonableGraph const &); }; } // namespace nlohmann @@ -65,10 +65,10 @@ V1JsonableGraph::V1JsonableGraph( namespace nlohmann { template -FlexFlow::V1JsonableGraph - adl_serializer>::from_json( +::FlexFlow::V1JsonableGraph + adl_serializer<::FlexFlow::V1JsonableGraph>::from_json( json const &j) { - return { + return ::FlexFlow::V1JsonableGraph{ j.at("node_labels").template get>(), j.at("outputs") .template get< @@ -77,8 +77,8 @@ FlexFlow::V1JsonableGraph j.at("graph").template get<::FlexFlow::V1MultiDiGraph>()}; } template -void adl_serializer>::to_json( - json &j, FlexFlow::V1JsonableGraph const &v) { +void adl_serializer<::FlexFlow::V1JsonableGraph>::to_json( + json &j, ::FlexFlow::V1JsonableGraph const &v) { j["__type"] = "V1JsonableGraph"; j["node_labels"] = v.node_labels; j["outputs"] = v.outputs; diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h index 5d7edcf1d8..5b214d2b58 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml /* proj-data { - "generated_from": "fb1033385645e54a19c9b44cef0be04b" + "generated_from": "582054edb983c3cc31d9273ce29421eb" } */ @@ -13,7 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include #include #include @@ -21,9 +22,10 @@ namespace FlexFlow { struct V1MultiDiGraph { V1MultiDiGraph() = delete; - V1MultiDiGraph(std::vector const &nodes, - std::vector const &ports, - std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + explicit V1MultiDiGraph( + std::vector const &nodes, + std::vector const &ports, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); std::vector nodes; std::vector ports; @@ -33,9 +35,9 @@ struct V1MultiDiGraph { namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::V1MultiDiGraph from_json(json const &); - static void to_json(json &, FlexFlow::V1MultiDiGraph const &); +struct adl_serializer<::FlexFlow::V1MultiDiGraph> { + static ::FlexFlow::V1MultiDiGraph from_json(json const &); + static void to_json(json &, ::FlexFlow::V1MultiDiGraph const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml index 9650f3bd43..20ca69eed4 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml @@ -13,7 +13,8 @@ includes = [ "", "", "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", - "utils/fmt.h", + "utils/fmt/vector.h", + "utils/fmt/unordered_set.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h index 7e5554d44a..f1e9cb5a5c 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml /* proj-data { - "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" + "generated_from": "fed215ca219af1bd375801eb2e33b473" } */ @@ -13,7 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include #include #include @@ -21,8 +22,9 @@ namespace FlexFlow { struct V1OperatorGraph { V1OperatorGraph() = delete; - V1OperatorGraph(std::vector const &nodes, - std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + explicit V1OperatorGraph( + std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); std::vector nodes; std::unordered_set<::FlexFlow::V1GraphEdge> edges; @@ -31,9 +33,9 @@ struct V1OperatorGraph { namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::V1OperatorGraph from_json(json const &); - static void to_json(json &, FlexFlow::V1OperatorGraph const &); +struct adl_serializer<::FlexFlow::V1OperatorGraph> { + static ::FlexFlow::V1OperatorGraph from_json(json const &); + static void to_json(json &, ::FlexFlow::V1OperatorGraph const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml index 61dc45ae2e..2715ae176b 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml @@ -13,7 +13,8 @@ includes = [ "", "", "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", - "utils/fmt.h", + "utils/fmt/unordered_set.h", + "utils/fmt/vector.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/gpu_id_t.dtg.h b/lib/pcg/include/pcg/gpu_id_t.dtg.h index f0847848ca..e056b8e0e3 100644 --- a/lib/pcg/include/pcg/gpu_id_t.dtg.h +++ b/lib/pcg/include/pcg/gpu_id_t.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct gpu_id_t { gpu_id_t() = delete; - gpu_id_t(int const &gpu_index); + explicit gpu_id_t(int const &gpu_index); bool operator==(gpu_id_t const &) const; bool operator!=(gpu_id_t const &) const; @@ -34,23 +34,23 @@ struct gpu_id_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::gpu_id_t const &) const; +struct hash<::FlexFlow::gpu_id_t> { + size_t operator()(::FlexFlow::gpu_id_t const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::gpu_id_t from_json(json const &); - static void to_json(json &, FlexFlow::gpu_id_t const &); +struct adl_serializer<::FlexFlow::gpu_id_t> { + static ::FlexFlow::gpu_id_t from_json(json const &); + static void to_json(json &, ::FlexFlow::gpu_id_t const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::gpu_id_t> { + static Gen<::FlexFlow::gpu_id_t> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h index 1eb9eb8834..1512cb8e18 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct ConstantInitializerAttrs { ConstantInitializerAttrs() = delete; - ConstantInitializerAttrs(::FlexFlow::DataTypeValue const &value); + explicit ConstantInitializerAttrs(::FlexFlow::DataTypeValue const &value); bool operator==(ConstantInitializerAttrs const &) const; bool operator!=(ConstantInitializerAttrs const &) const; @@ -35,16 +35,16 @@ struct ConstantInitializerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ConstantInitializerAttrs const &) const; +struct hash<::FlexFlow::ConstantInitializerAttrs> { + size_t operator()(::FlexFlow::ConstantInitializerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ConstantInitializerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ConstantInitializerAttrs const &); +struct adl_serializer<::FlexFlow::ConstantInitializerAttrs> { + static ::FlexFlow::ConstantInitializerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ConstantInitializerAttrs const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h index 04851fb333..e6fe29a048 100644 --- a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct GlorotUniformAttrs { GlorotUniformAttrs() = delete; - GlorotUniformAttrs(int const &seed); + explicit GlorotUniformAttrs(int const &seed); bool operator==(GlorotUniformAttrs const &) const; bool operator!=(GlorotUniformAttrs const &) const; @@ -34,23 +34,23 @@ struct GlorotUniformAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::GlorotUniformAttrs const &) const; +struct hash<::FlexFlow::GlorotUniformAttrs> { + size_t operator()(::FlexFlow::GlorotUniformAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::GlorotUniformAttrs from_json(json const &); - static void to_json(json &, FlexFlow::GlorotUniformAttrs const &); +struct adl_serializer<::FlexFlow::GlorotUniformAttrs> { + static ::FlexFlow::GlorotUniformAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::GlorotUniformAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::GlorotUniformAttrs> { + static Gen<::FlexFlow::GlorotUniformAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h index e1d3e59ed7..602a877c30 100644 --- a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h @@ -20,7 +20,9 @@ namespace FlexFlow { struct NormInitializerAttrs { NormInitializerAttrs() = delete; - NormInitializerAttrs(int const &seed, float const &mean, float const &stddev); + explicit NormInitializerAttrs(int const &seed, + float const &mean, + float const &stddev); bool operator==(NormInitializerAttrs const &) const; bool operator!=(NormInitializerAttrs const &) const; @@ -36,23 +38,23 @@ struct NormInitializerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::NormInitializerAttrs const &) const; +struct hash<::FlexFlow::NormInitializerAttrs> { + size_t operator()(::FlexFlow::NormInitializerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::NormInitializerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::NormInitializerAttrs const &); +struct adl_serializer<::FlexFlow::NormInitializerAttrs> { + static ::FlexFlow::NormInitializerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::NormInitializerAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::NormInitializerAttrs> { + static Gen<::FlexFlow::NormInitializerAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h index 1f4deada06..9493d2ffff 100644 --- a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h @@ -19,9 +19,9 @@ namespace FlexFlow { struct UniformInitializerAttrs { UniformInitializerAttrs() = delete; - UniformInitializerAttrs(int const &seed, - float const &min_val, - float const &max_val); + explicit UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val); bool operator==(UniformInitializerAttrs const &) const; bool operator!=(UniformInitializerAttrs const &) const; @@ -37,16 +37,16 @@ struct UniformInitializerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::UniformInitializerAttrs const &) const; +struct hash<::FlexFlow::UniformInitializerAttrs> { + size_t operator()(::FlexFlow::UniformInitializerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::UniformInitializerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::UniformInitializerAttrs const &); +struct adl_serializer<::FlexFlow::UniformInitializerAttrs> { + static ::FlexFlow::UniformInitializerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::UniformInitializerAttrs const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h index f3086ea087..7a4a8ccd1f 100644 --- a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h @@ -30,23 +30,23 @@ struct ZeroInitializerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ZeroInitializerAttrs const &) const; +struct hash<::FlexFlow::ZeroInitializerAttrs> { + size_t operator()(::FlexFlow::ZeroInitializerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ZeroInitializerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ZeroInitializerAttrs const &); +struct adl_serializer<::FlexFlow::ZeroInitializerAttrs> { + static ::FlexFlow::ZeroInitializerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ZeroInitializerAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ZeroInitializerAttrs> { + static Gen<::FlexFlow::ZeroInitializerAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/layer_attrs.dtg.h b/lib/pcg/include/pcg/layer_attrs.dtg.h index 6afa1757dc..d856bc1401 100644 --- a/lib/pcg/include/pcg/layer_attrs.dtg.h +++ b/lib/pcg/include/pcg/layer_attrs.dtg.h @@ -23,8 +23,9 @@ namespace FlexFlow { struct LayerAttrs { LayerAttrs() = delete; - LayerAttrs(::FlexFlow::ComputationGraphOpAttrs const &attrs, - std::optional<::FlexFlow::stack_string> const &name); + explicit LayerAttrs( + ::FlexFlow::ComputationGraphOpAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); bool operator==(LayerAttrs const &) const; bool operator!=(LayerAttrs const &) const; @@ -39,16 +40,16 @@ struct LayerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::LayerAttrs const &) const; +struct hash<::FlexFlow::LayerAttrs> { + size_t operator()(::FlexFlow::LayerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::LayerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::LayerAttrs const &); +struct adl_serializer<::FlexFlow::LayerAttrs> { + static ::FlexFlow::LayerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::LayerAttrs const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.h b/lib/pcg/include/pcg/layer_guid_t.dtg.h index 4bbdd36fed..9b0e3338d9 100644 --- a/lib/pcg/include/pcg/layer_guid_t.dtg.h +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct layer_guid_t { layer_guid_t() = delete; - layer_guid_t(::FlexFlow::Node const &raw_node); + explicit layer_guid_t(::FlexFlow::Node const &raw_node); bool operator==(layer_guid_t const &) const; bool operator!=(layer_guid_t const &) const; @@ -33,8 +33,8 @@ struct layer_guid_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::layer_guid_t const &) const; +struct hash<::FlexFlow::layer_guid_t> { + size_t operator()(::FlexFlow::layer_guid_t const &) const; }; } // namespace std diff --git a/lib/pcg/include/pcg/machine_specification.dtg.h b/lib/pcg/include/pcg/machine_specification.dtg.h index cd6ffe6c0f..8b75a6dcb4 100644 --- a/lib/pcg/include/pcg/machine_specification.dtg.h +++ b/lib/pcg/include/pcg/machine_specification.dtg.h @@ -19,11 +19,11 @@ namespace FlexFlow { struct MachineSpecification { MachineSpecification() = delete; - MachineSpecification(int const &num_nodes, - int const &num_cpus_per_node, - int const &num_gpus_per_node, - float const &inter_node_bandwidth, - float const &intra_node_bandwidth); + explicit MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth); bool operator==(MachineSpecification const &) const; bool operator!=(MachineSpecification const &) const; @@ -41,16 +41,16 @@ struct MachineSpecification { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::MachineSpecification const &) const; +struct hash<::FlexFlow::MachineSpecification> { + size_t operator()(::FlexFlow::MachineSpecification const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::MachineSpecification from_json(json const &); - static void to_json(json &, FlexFlow::MachineSpecification const &); +struct adl_serializer<::FlexFlow::MachineSpecification> { + static ::FlexFlow::MachineSpecification from_json(json const &); + static void to_json(json &, ::FlexFlow::MachineSpecification const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/machine_view.dtg.h b/lib/pcg/include/pcg/machine_view.dtg.h index 2eae6e2c8b..2f058bacc5 100644 --- a/lib/pcg/include/pcg/machine_view.dtg.h +++ b/lib/pcg/include/pcg/machine_view.dtg.h @@ -21,8 +21,8 @@ namespace FlexFlow { struct MachineView { MachineView() = delete; - MachineView(::FlexFlow::device_id_t const &start, - ::FlexFlow::StridedRectangle const &rect); + explicit MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect); bool operator==(MachineView const &) const; bool operator!=(MachineView const &) const; @@ -37,16 +37,16 @@ struct MachineView { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::MachineView const &) const; +struct hash<::FlexFlow::MachineView> { + size_t operator()(::FlexFlow::MachineView const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::MachineView from_json(json const &); - static void to_json(json &, FlexFlow::MachineView const &); +struct adl_serializer<::FlexFlow::MachineView> { + static ::FlexFlow::MachineView from_json(json const &); + static void to_json(json &, ::FlexFlow::MachineView const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/num_points_t.dtg.h b/lib/pcg/include/pcg/num_points_t.dtg.h index 3b8e0e0c6c..52c2af8e7f 100644 --- a/lib/pcg/include/pcg/num_points_t.dtg.h +++ b/lib/pcg/include/pcg/num_points_t.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct num_points_t { num_points_t() = delete; - num_points_t(int const &unwrapped); + explicit num_points_t(int const &unwrapped); bool operator==(num_points_t const &) const; bool operator!=(num_points_t const &) const; @@ -34,23 +34,23 @@ struct num_points_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::num_points_t const &) const; +struct hash<::FlexFlow::num_points_t> { + size_t operator()(::FlexFlow::num_points_t const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::num_points_t from_json(json const &); - static void to_json(json &, FlexFlow::num_points_t const &); +struct adl_serializer<::FlexFlow::num_points_t> { + static ::FlexFlow::num_points_t from_json(json const &); + static void to_json(json &, ::FlexFlow::num_points_t const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::num_points_t> { + static Gen<::FlexFlow::num_points_t> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h index 13904f220d..f0bedc0f3d 100644 --- a/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct OperatorGraphInput { OperatorGraphInput() = delete; - OperatorGraphInput(::FlexFlow::Node const &node, int const &idx); + explicit OperatorGraphInput(::FlexFlow::Node const &node, int const &idx); bool operator==(OperatorGraphInput const &) const; bool operator!=(OperatorGraphInput const &) const; @@ -34,8 +34,8 @@ struct OperatorGraphInput { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OperatorGraphInput const &) const; +struct hash<::FlexFlow::OperatorGraphInput> { + size_t operator()(::FlexFlow::OperatorGraphInput const &) const; }; } // namespace std diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h index 40bdc245b8..4a99eba273 100644 --- a/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct OperatorGraphOutput { OperatorGraphOutput() = delete; - OperatorGraphOutput(::FlexFlow::Node const &node, int const &idx); + explicit OperatorGraphOutput(::FlexFlow::Node const &node, int const &idx); bool operator==(OperatorGraphOutput const &) const; bool operator!=(OperatorGraphOutput const &) const; @@ -34,8 +34,8 @@ struct OperatorGraphOutput { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OperatorGraphOutput const &) const; +struct hash<::FlexFlow::OperatorGraphOutput> { + size_t operator()(::FlexFlow::OperatorGraphOutput const &) const; }; } // namespace std diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h index a5a6a5ed0a..1dfbb4a4e1 100644 --- a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h @@ -20,13 +20,13 @@ namespace FlexFlow { struct AdamOptimizerAttrs { AdamOptimizerAttrs() = delete; - AdamOptimizerAttrs(double const &alpha, - double const &beta1, - double const &beta2, - double const &weight_decay, - double const &alpha_t, - double const &beta_t, - double const &beta2_t); + explicit AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t); bool operator==(AdamOptimizerAttrs const &) const; bool operator!=(AdamOptimizerAttrs const &) const; @@ -46,23 +46,23 @@ struct AdamOptimizerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::AdamOptimizerAttrs const &) const; +struct hash<::FlexFlow::AdamOptimizerAttrs> { + size_t operator()(::FlexFlow::AdamOptimizerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::AdamOptimizerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::AdamOptimizerAttrs const &); +struct adl_serializer<::FlexFlow::AdamOptimizerAttrs> { + static ::FlexFlow::AdamOptimizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::AdamOptimizerAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::AdamOptimizerAttrs> { + static Gen<::FlexFlow::AdamOptimizerAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h index f6a17f2354..5fa33bfbe7 100644 --- a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h @@ -20,10 +20,10 @@ namespace FlexFlow { struct SGDOptimizerAttrs { SGDOptimizerAttrs() = delete; - SGDOptimizerAttrs(double const &lr, - double const &momentum, - bool const &nesterov, - double const &weight_decay); + explicit SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay); bool operator==(SGDOptimizerAttrs const &) const; bool operator!=(SGDOptimizerAttrs const &) const; @@ -40,23 +40,23 @@ struct SGDOptimizerAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::SGDOptimizerAttrs const &) const; +struct hash<::FlexFlow::SGDOptimizerAttrs> { + size_t operator()(::FlexFlow::SGDOptimizerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::SGDOptimizerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::SGDOptimizerAttrs const &); +struct adl_serializer<::FlexFlow::SGDOptimizerAttrs> { + static ::FlexFlow::SGDOptimizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::SGDOptimizerAttrs const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::SGDOptimizerAttrs> { + static Gen<::FlexFlow::SGDOptimizerAttrs> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h deleted file mode 100644 index aae5122671..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H - -#include "pcg/parallel_computation_graph.dtg.h" -#include "pcg/parallel_layer_guid_t.dtg.h" -#include "pcg/parallel_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -ParallelComputationGraph empty_parallel_computation_graph(); - -std::unordered_set get_parallel_layers(ParallelComputationGraph const &); - -ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); - -} - -#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h similarity index 61% rename from lib/pcg/include/pcg/parallel_computation_graph.dtg.h rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h index 01fbb7d30c..a6f9f9455e 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h @@ -1,23 +1,23 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" + "generated_from": "1339be6e86e9818c36d6ecf5475e2d4b" } */ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_DTG_H -#include "pcg/dataflow_graph.h" -#include "pcg/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/dataflow_graph/dataflow_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" namespace FlexFlow { struct ParallelComputationGraph { ParallelComputationGraph() = delete; - ParallelComputationGraph( + explicit ParallelComputationGraph( ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs> const &raw_graph); @@ -28,4 +28,4 @@ struct ParallelComputationGraph { }; } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h new file mode 100644 index 0000000000..6dda689d35 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +ParallelComputationGraph empty_parallel_computation_graph(); + +std::unordered_set + get_parallel_layers(ParallelComputationGraph const &); + +std::vector + get_layer_inputs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); +std::vector + get_layer_outputs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + +parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); +ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml similarity index 56% rename from lib/pcg/include/pcg/parallel_computation_graph.struct.toml rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml index d4e305abe5..759a8424d5 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml @@ -3,9 +3,9 @@ name = "ParallelComputationGraph" features = [ ] includes = [ - "pcg/dataflow_graph.h", - "pcg/parallel_tensor_attrs.dtg.h", - "pcg/parallel_layer_attrs.dtg.h", + "pcg/dataflow_graph/dataflow_graph.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h new file mode 100644 index 0000000000..cdeb846af3 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -0,0 +1,144 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_BUILDER_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_BUILDER_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include + +namespace FlexFlow { + +struct ParallelComputationGraphBuilder { +public: + ParallelComputationGraphBuilder(); + + parallel_tensor_guid_t create_input_tensor( + ParallelTensorShape const &shape, + bool create_grad = true, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + add(parallel_tensor_guid_t const &lhs, + parallel_tensor_guid_t const &rhs, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + batch_matmul(parallel_tensor_guid_t const &a, + parallel_tensor_guid_t const &b, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + cast(parallel_tensor_guid_t const &input, + DataType result_type, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t conv2d( + parallel_tensor_guid_t const &input, + int outChannels, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + std::optional const &activation = std::nullopt, + int groups = 1, + bool use_bias = true, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &kernel_regularizer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t dense( + parallel_tensor_guid_t const &input, + int outDim, + std::optional activation = std::nullopt, + bool use_bias = true, + DataType data_type = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t embedding( + parallel_tensor_guid_t const &input, + int num_entries, + int outDim, + AggregateOp aggr, + DataType dtype = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t multihead_attention( + parallel_tensor_guid_t const &query, + parallel_tensor_guid_t const &key, + parallel_tensor_guid_t const &value, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = true, + bool add_bias_kv = false, + bool add_zero_attn = false, + std::optional initializer = std::nullopt, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + relu(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + parallel_partition(parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t + parallel_combine(parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t + parallel_replicate(parallel_tensor_guid_t const &x, + int degree, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t + parallel_reduce(parallel_tensor_guid_t const &x, + int degree, + std::optional const &name = std::nullopt); + +private: + parallel_tensor_guid_t as_type(parallel_tensor_guid_t const &, + DataType, + std::string const &name); + +private: + ParallelTensorShape get_shape(parallel_tensor_guid_t const &) const; + +private: + std::vector + add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + std::vector + add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &output); + parallel_tensor_guid_t + add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorAttrs const &output); + parallel_tensor_guid_t + add_layer(ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorShape const &output); + +public: + ParallelComputationGraph pcg; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h similarity index 58% rename from lib/pcg/include/pcg/parallel_layer_attrs.dtg.h rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h index 4c7fce4038..cf0011d4ba 100644 --- a/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml /* proj-data { - "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" + "generated_from": "9bb6e3cb7b0e523fae8f33bd8ad80d6d" } */ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -22,8 +22,8 @@ namespace FlexFlow { struct ParallelLayerAttrs { ParallelLayerAttrs() = delete; - ParallelLayerAttrs( - ::FlexFlow::PCGOperatorAttrs const &attrs, + explicit ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &op_attrs, std::optional<::FlexFlow::stack_string> const &name); bool operator==(ParallelLayerAttrs const &) const; @@ -32,23 +32,23 @@ struct ParallelLayerAttrs { bool operator>(ParallelLayerAttrs const &) const; bool operator<=(ParallelLayerAttrs const &) const; bool operator>=(ParallelLayerAttrs const &) const; - ::FlexFlow::PCGOperatorAttrs attrs; + ::FlexFlow::PCGOperatorAttrs op_attrs; std::optional<::FlexFlow::stack_string> name; }; } // namespace FlexFlow namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ParallelLayerAttrs const &) const; +struct hash<::FlexFlow::ParallelLayerAttrs> { + size_t operator()(::FlexFlow::ParallelLayerAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ParallelLayerAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ParallelLayerAttrs const &); +struct adl_serializer<::FlexFlow::ParallelLayerAttrs> { + static ::FlexFlow::ParallelLayerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ParallelLayerAttrs const &); }; } // namespace nlohmann @@ -57,4 +57,4 @@ std::string format_as(ParallelLayerAttrs const &); std::ostream &operator<<(std::ostream &, ParallelLayerAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.h new file mode 100644 index 0000000000..2b1a082a85 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_H + +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" + +namespace FlexFlow { + +OperatorType get_op_type(ParallelLayerAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml similarity index 95% rename from lib/pcg/include/pcg/parallel_layer_attrs.struct.toml rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 9b1f8f47aa..f3f3c6a8bb 100644 --- a/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -16,7 +16,7 @@ includes = [ ] [[fields]] -name = "attrs" +name = "op_attrs" type = "::FlexFlow::PCGOperatorAttrs" [[fields]] diff --git a/lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h similarity index 64% rename from lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h index 8fc81cee05..c204a5f95c 100644 --- a/lib/pcg/include/pcg/parallel_layer_guid_t.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml /* proj-data { "generated_from": "c31301efeb92e151b04943786aa7bec1" } */ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_GUID_T_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_GUID_T_DTG_H +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_GUID_T_DTG_H #include "fmt/format.h" #include "utils/graph.h" @@ -19,7 +19,7 @@ namespace FlexFlow { struct parallel_layer_guid_t { parallel_layer_guid_t() = delete; - parallel_layer_guid_t(::FlexFlow::Node const &raw_graph_node); + explicit parallel_layer_guid_t(::FlexFlow::Node const &raw_graph_node); bool operator==(parallel_layer_guid_t const &) const; bool operator!=(parallel_layer_guid_t const &) const; @@ -33,8 +33,8 @@ struct parallel_layer_guid_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::parallel_layer_guid_t const &) const; +struct hash<::FlexFlow::parallel_layer_guid_t> { + size_t operator()(::FlexFlow::parallel_layer_guid_t const &) const; }; } // namespace std @@ -43,4 +43,4 @@ std::string format_as(parallel_layer_guid_t const &); std::ostream &operator<<(std::ostream &, parallel_layer_guid_t const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_GUID_T_DTG_H +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml similarity index 100% rename from lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h similarity index 69% rename from lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h index fa6b153b0a..a9dfb1d163 100644 --- a/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml /* proj-data { "generated_from": "b3e086b380bbc41d99332e1463a34b28" } */ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -24,7 +24,7 @@ namespace FlexFlow { struct ParallelTensorAttrs { ParallelTensorAttrs() = delete; - ParallelTensorAttrs( + explicit ParallelTensorAttrs( ::FlexFlow::ParallelTensorShape const &shape, std::optional<::FlexFlow::ParamSync> const &sync_type, std::optional<::FlexFlow::InitializerAttrs> const &initializer, @@ -45,16 +45,16 @@ struct ParallelTensorAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ParallelTensorAttrs const &) const; +struct hash<::FlexFlow::ParallelTensorAttrs> { + size_t operator()(::FlexFlow::ParallelTensorAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ParallelTensorAttrs from_json(json const &); - static void to_json(json &, FlexFlow::ParallelTensorAttrs const &); +struct adl_serializer<::FlexFlow::ParallelTensorAttrs> { + static ::FlexFlow::ParallelTensorAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ParallelTensorAttrs const &); }; } // namespace nlohmann @@ -63,4 +63,4 @@ std::string format_as(ParallelTensorAttrs const &); std::ostream &operator<<(std::ostream &, ParallelTensorAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml similarity index 100% rename from lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml diff --git a/lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h similarity index 64% rename from lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h index 4041544903..55a1ebcc75 100644 --- a/lib/pcg/include/pcg/parallel_tensor_guid_t.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml /* proj-data { "generated_from": "de2c2d33bfa5cd72f0e51954d6879f38" } */ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_DTG_H +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_DTG_H #include "fmt/format.h" #include "utils/graph/multidiedge.h" @@ -19,7 +19,8 @@ namespace FlexFlow { struct parallel_tensor_guid_t { parallel_tensor_guid_t() = delete; - parallel_tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + explicit parallel_tensor_guid_t( + ::FlexFlow::MultiDiOutput const &raw_graph_output); bool operator==(parallel_tensor_guid_t const &) const; bool operator!=(parallel_tensor_guid_t const &) const; @@ -33,8 +34,8 @@ struct parallel_tensor_guid_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::parallel_tensor_guid_t const &) const; +struct hash<::FlexFlow::parallel_tensor_guid_t> { + size_t operator()(::FlexFlow::parallel_tensor_guid_t const &) const; }; } // namespace std @@ -43,4 +44,4 @@ std::string format_as(parallel_tensor_guid_t const &); std::ostream &operator<<(std::ostream &, parallel_tensor_guid_t const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_GUID_T_DTG_H +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml similarity index 100% rename from lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml rename to lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml diff --git a/lib/pcg/include/pcg/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph_builder.h deleted file mode 100644 index 6e21110e0e..0000000000 --- a/lib/pcg/include/pcg/parallel_computation_graph_builder.h +++ /dev/null @@ -1,127 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_BUILDER_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_BUILDER_H - -#include "pcg/parallel_computation_graph.dtg.h" -#include "pcg/parallel_tensor_guid_t.dtg.h" -#include - -namespace FlexFlow { - -struct ParallelComputationGraphBuilder { -public: - ParallelComputationGraphBuilder(); - - parallel_tensor_guid_t create_input_tensor(ParallelTensorShape const &shape, - bool create_grad = true, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t add(parallel_tensor_guid_t const &lhs, - parallel_tensor_guid_t const &rhs, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t - batch_matmul(parallel_tensor_guid_t const &a, - parallel_tensor_guid_t const &b, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t cast(parallel_tensor_guid_t const &input, - DataType result_type, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t conv2d( - parallel_tensor_guid_t const &input, - int outChannels, - int kernelH, - int kernelW, - int strideH, - int strideW, - int paddingH, - int paddingW, - std::optional const &activation = std::nullopt, - int groups = 1, - bool use_bias = true, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, - std::optional const &kernel_regularizer = std::nullopt, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t dense( - parallel_tensor_guid_t const &input, - int outDim, - std::optional activation = std::nullopt, - bool use_bias = true, - DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t embedding( - parallel_tensor_guid_t const &input, - int num_entries, - int outDim, - AggregateOp aggr, - DataType dtype = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t multihead_attention( - parallel_tensor_guid_t const &query, - parallel_tensor_guid_t const &key, - parallel_tensor_guid_t const &value, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = true, - bool add_bias_kv = false, - bool add_zero_attn = false, - std::optional initializer = std::nullopt, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t relu(parallel_tensor_guid_t const &x, - std::optional const &name = std::nullopt); - - parallel_tensor_guid_t parallel_partition(parallel_tensor_guid_t const &x, - ff_dim_t dim, - int degree, - std::optional const &name = std::nullopt); - parallel_tensor_guid_t parallel_combine(parallel_tensor_guid_t const &x, - ff_dim_t dim, - int degree, - std::optional const &name = std::nullopt); - parallel_tensor_guid_t parallel_replicate(parallel_tensor_guid_t const &x, - int degree, - std::optional const &name = std::nullopt); - parallel_tensor_guid_t parallel_reduce(parallel_tensor_guid_t const &x, - int degree, - std::optional const &name = std::nullopt); - -private: - parallel_tensor_guid_t as_type(parallel_tensor_guid_t const &, DataType, std::string const &name); -private: - ParallelTensorShape get_shape(parallel_tensor_guid_t const &) const; -private: - std::vector add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - std::vector add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &output); - parallel_tensor_guid_t add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - ParallelTensorAttrs const &output); - parallel_tensor_guid_t add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - ParallelTensorShape const &output); -public: - ParallelComputationGraph pcg; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h deleted file mode 100644 index de41e0fb21..0000000000 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * @file parallel_tensor.h - * @brief Parallel Tensor Representation - * - * @copyright Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford - * (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H - -#include "pcg/parallel_tensor_attrs.h" - -namespace FlexFlow {} // namespace FlexFlow - -namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); -} - -#endif diff --git a/lib/pcg/include/pcg/side_size_t.dtg.h b/lib/pcg/include/pcg/side_size_t.dtg.h index fce31b1c9d..a0d65a0e6b 100644 --- a/lib/pcg/include/pcg/side_size_t.dtg.h +++ b/lib/pcg/include/pcg/side_size_t.dtg.h @@ -20,7 +20,7 @@ namespace FlexFlow { struct side_size_t { side_size_t() = delete; - side_size_t(int const &unwrapped); + explicit side_size_t(int const &unwrapped); bool operator==(side_size_t const &) const; bool operator!=(side_size_t const &) const; @@ -34,23 +34,23 @@ struct side_size_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::side_size_t const &) const; +struct hash<::FlexFlow::side_size_t> { + size_t operator()(::FlexFlow::side_size_t const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::side_size_t from_json(json const &); - static void to_json(json &, FlexFlow::side_size_t const &); +struct adl_serializer<::FlexFlow::side_size_t> { + static ::FlexFlow::side_size_t from_json(json const &); + static void to_json(json &, ::FlexFlow::side_size_t const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::side_size_t> { + static Gen<::FlexFlow::side_size_t> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/strided_rectangle.dtg.h b/lib/pcg/include/pcg/strided_rectangle.dtg.h index df6a16a0ad..932c139f91 100644 --- a/lib/pcg/include/pcg/strided_rectangle.dtg.h +++ b/lib/pcg/include/pcg/strided_rectangle.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct StridedRectangle { StridedRectangle() = delete; - StridedRectangle( + explicit StridedRectangle( ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides); bool operator==(StridedRectangle const &) const; @@ -37,23 +37,23 @@ struct StridedRectangle { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::StridedRectangle const &) const; +struct hash<::FlexFlow::StridedRectangle> { + size_t operator()(::FlexFlow::StridedRectangle const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::StridedRectangle from_json(json const &); - static void to_json(json &, FlexFlow::StridedRectangle const &); +struct adl_serializer<::FlexFlow::StridedRectangle> { + static ::FlexFlow::StridedRectangle from_json(json const &); + static void to_json(json &, ::FlexFlow::StridedRectangle const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::StridedRectangle> { + static Gen<::FlexFlow::StridedRectangle> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/strided_rectangle_side.dtg.h b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h index 3e4365c24d..9b9347a7aa 100644 --- a/lib/pcg/include/pcg/strided_rectangle_side.dtg.h +++ b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h @@ -21,8 +21,8 @@ namespace FlexFlow { struct StridedRectangleSide { StridedRectangleSide() = delete; - StridedRectangleSide(::FlexFlow::num_points_t const &num_points, - int const &stride); + explicit StridedRectangleSide(::FlexFlow::num_points_t const &num_points, + int const &stride); bool operator==(StridedRectangleSide const &) const; bool operator!=(StridedRectangleSide const &) const; @@ -37,23 +37,23 @@ struct StridedRectangleSide { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::StridedRectangleSide const &) const; +struct hash<::FlexFlow::StridedRectangleSide> { + size_t operator()(::FlexFlow::StridedRectangleSide const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::StridedRectangleSide from_json(json const &); - static void to_json(json &, FlexFlow::StridedRectangleSide const &); +struct adl_serializer<::FlexFlow::StridedRectangleSide> { + static ::FlexFlow::StridedRectangleSide from_json(json const &); + static void to_json(json &, ::FlexFlow::StridedRectangleSide const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::StridedRectangleSide> { + static Gen<::FlexFlow::StridedRectangleSide> arbitrary(); }; } // namespace rc diff --git a/lib/pcg/include/pcg/tensor_attrs.dtg.h b/lib/pcg/include/pcg/tensor_attrs.dtg.h index 8bc9d3ce9d..38b18c9885 100644 --- a/lib/pcg/include/pcg/tensor_attrs.dtg.h +++ b/lib/pcg/include/pcg/tensor_attrs.dtg.h @@ -23,10 +23,11 @@ namespace FlexFlow { struct TensorAttrs { TensorAttrs() = delete; - TensorAttrs(::FlexFlow::TensorShape const &shape, - std::optional<::FlexFlow::InitializerAttrs> const &initializer, - bool const &create_gradients, - std::optional<::FlexFlow::ParamSync> const &sync_type); + explicit TensorAttrs( + ::FlexFlow::TensorShape const &shape, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + bool const &create_gradients, + std::optional<::FlexFlow::ParamSync> const &sync_type); bool operator==(TensorAttrs const &) const; bool operator!=(TensorAttrs const &) const; @@ -43,16 +44,16 @@ struct TensorAttrs { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorAttrs const &) const; +struct hash<::FlexFlow::TensorAttrs> { + size_t operator()(::FlexFlow::TensorAttrs const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorAttrs from_json(json const &); - static void to_json(json &, FlexFlow::TensorAttrs const &); +struct adl_serializer<::FlexFlow::TensorAttrs> { + static ::FlexFlow::TensorAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttrs const &); }; } // namespace nlohmann diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.h b/lib/pcg/include/pcg/tensor_guid_t.dtg.h index f9841a4d06..3026c2169e 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.dtg.h +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct tensor_guid_t { tensor_guid_t() = delete; - tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + explicit tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); bool operator==(tensor_guid_t const &) const; bool operator!=(tensor_guid_t const &) const; @@ -33,8 +33,8 @@ struct tensor_guid_t { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::tensor_guid_t const &) const; +struct hash<::FlexFlow::tensor_guid_t> { + size_t operator()(::FlexFlow::tensor_guid_t const &) const; }; } // namespace std diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index eabd266e25..8317c9ec6e 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -1,5 +1,5 @@ #include "pcg/file_format/v1/graphs.h" -#include "pcg/dataflow_graph.h" +#include "pcg/dataflow_graph/dataflow_graph.h" #include "pcg/file_format/v1/graphs/v1_multidigraph.h" #include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" #include "utils/graph/algorithms.h" @@ -30,10 +30,10 @@ static V1MultiDiGraph to_v1(MultiDiGraphView const &g, bidict const &node_ports) { std::unordered_set edges; for (MultiDiEdge const &e : get_edges(g)) { - edges.insert({nodes.at_l(e.src), - node_ports.at_l(e.src_idx), - nodes.at_l(e.dst), - node_ports.at_l(e.dst_idx)}); + edges.insert(V1GraphEdge{nodes.at_l(e.src), + node_ports.at_l(e.src_idx), + nodes.at_l(e.dst), + node_ports.at_l(e.dst_idx)}); } return V1MultiDiGraph{ @@ -107,7 +107,8 @@ static V1JsonableGraph std::unordered_map output_labels = map_values( outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); - return {node_labels, outputs, output_labels, unlabelled}; + return V1JsonableGraph{ + node_labels, outputs, output_labels, unlabelled}; } template @@ -129,7 +130,8 @@ static V1JsonableGraph std::unordered_map output_labels = map_values( outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); - return {node_labels, outputs, output_labels, unlabelled}; + return V1JsonableGraph{ + node_labels, outputs, output_labels, unlabelled}; } V1ComputationGraph to_v1(ComputationGraph const &g) { diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc index bb6233a910..799cf55908 100644 --- a/lib/pcg/src/pcg/computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -3,13 +3,13 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "8f1f0e13d75065944f7fe307e12fe280" + "generated_from": "bf8996bea2e022265a372d692c2db8ed" } */ #include "pcg/computation_graph.dtg.h" -#include "pcg/dataflow_graph.h" +#include "pcg/dataflow_graph/dataflow_graph.h" #include "pcg/layer_attrs.dtg.h" #include "pcg/tensor_attrs.dtg.h" diff --git a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc index 18b394f6d0..1d00b4f32e 100644 --- a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml /* proj-data { - "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" + "generated_from": "234b5c222ae4ce1da36194b4eb519145" } */ @@ -11,6 +11,7 @@ #include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_guid_t.dtg.h" +#include "utils/fmt/vector.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 8c69b3a724..780e4b7cc6 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -23,7 +23,8 @@ TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, bool create_grad) { - TensorAttrs tensor_attrs = {shape, std::nullopt, create_grad, std::nullopt}; + TensorAttrs tensor_attrs = + TensorAttrs{shape, std::nullopt, create_grad, std::nullopt}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, std::nullopt, @@ -52,16 +53,20 @@ std::vector ComputationGraphBuilder::add_layer( }; std::vector weight_layer_inputs = {}; std::vector weight_output_attrs = {weight_tensor_attrs}; - raw_weight_tensors.push_back( - get_only(this->computation_graph.raw_graph.add_operator( - weight_layer_attrs, weight_layer_inputs, weight_output_attrs))); + raw_weight_tensors.push_back(get_only(this->computation_graph.raw_graph + .add_operator(weight_layer_attrs, + weight_layer_inputs, + weight_output_attrs) + .outputs)); } std::vector raw_inputs = transform( inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); std::vector raw_outputs = - this->computation_graph.raw_graph.add_operator( - layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); + this->computation_graph.raw_graph + .add_operator( + layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) + .outputs; return transform(raw_outputs, [](MultiDiOutput const &o) { return tensor_guid_t{o}; }); } @@ -160,7 +165,7 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); @@ -172,7 +177,7 @@ tensor_guid_t ComputationGraphBuilder::element_unary( OperatorType op_type, tensor_guid_t const &input, std::optional const &name) { - ElementUnaryAttrs attrs = {op_type}; + ElementUnaryAttrs attrs = ElementUnaryAttrs{op_type}; return this->element_unary(attrs, input, name); } @@ -181,7 +186,7 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( tensor_guid_t const &input, float scalar, std::optional const &name) { - ElementScalarUnaryAttrs attrs = {op_type, scalar}; + ElementScalarUnaryAttrs attrs = ElementScalarUnaryAttrs{op_type, scalar}; return this->element_scalar_unary(attrs, input, name); } @@ -203,9 +208,10 @@ tensor_guid_t ComputationGraphBuilder::element_binary( compute_type, name + "_inputr_pre_cast"); - ElementBinaryAttrs attrs = {op_type, compute_type, false, false}; + ElementBinaryAttrs attrs = + ElementBinaryAttrs{op_type, compute_type, false, false}; - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); @@ -375,16 +381,16 @@ tensor_guid_t ComputationGraphBuilder::conv2d( std::optional const &bias_initializer, std::optional const &kernel_regularizer, std::optional const &maybe_name) { - Conv2DAttrs attrs = {outChannels, - kernelH, - kernelW, - strideH, - strideW, - paddingH, - paddingW, - groups, - activation, - use_bias}; + Conv2DAttrs attrs = Conv2DAttrs{outChannels, + kernelH, + kernelW, + strideH, + strideW, + paddingH, + paddingW, + groups, + activation, + use_bias}; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); @@ -392,7 +398,7 @@ tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape input_shape = this->get_shape(input); TensorShape output_shape = get_output_shape(attrs, input_shape); @@ -415,11 +421,11 @@ tensor_guid_t ComputationGraphBuilder::dropout( float rate, unsigned long long seed, std::optional const &maybe_name) { - DropoutAttrs attrs = {rate, seed}; + DropoutAttrs attrs = DropoutAttrs{rate, seed}; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -436,11 +442,11 @@ tensor_guid_t ComputationGraphBuilder::embedding( DataType dtype, std::optional const &kernel_initializer, std::optional const &maybe_name) { - EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; + EmbeddingAttrs attrs = EmbeddingAttrs{num_entries, outDim, aggr, dtype}; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -461,11 +467,11 @@ std::vector ComputationGraphBuilder::gather( tensor_guid_t const &index, ff_dim_t dim, std::optional const &maybe_name) { - GatherAttrs attrs = {dim}; + GatherAttrs attrs = GatherAttrs{dim}; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " @@ -521,7 +527,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); diff --git a/lib/pcg/src/pcg/cpu_id_t.dtg.cc b/lib/pcg/src/pcg/cpu_id_t.dtg.cc index f865442eb0..ba8f8cc164 100644 --- a/lib/pcg/src/pcg/cpu_id_t.dtg.cc +++ b/lib/pcg/src/pcg/cpu_id_t.dtg.cc @@ -34,7 +34,8 @@ bool cpu_id_t::operator>=(cpu_id_t const &other) const { } // namespace FlexFlow namespace std { -size_t hash::operator()(FlexFlow::cpu_id_t const &x) const { +size_t + hash::operator()(::FlexFlow::cpu_id_t const &x) const { size_t result = 0; result ^= std::hash{}(x.cpu_index) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -43,20 +44,20 @@ size_t hash::operator()(FlexFlow::cpu_id_t const &x) const { } // namespace std namespace nlohmann { -FlexFlow::cpu_id_t - adl_serializer::from_json(json const &j) { - return {j.at("cpu_index").template get()}; +::FlexFlow::cpu_id_t + adl_serializer<::FlexFlow::cpu_id_t>::from_json(json const &j) { + return ::FlexFlow::cpu_id_t{j.at("cpu_index").template get()}; } -void adl_serializer::to_json(json &j, - FlexFlow::cpu_id_t const &v) { +void adl_serializer<::FlexFlow::cpu_id_t>::to_json( + json &j, ::FlexFlow::cpu_id_t const &v) { j["__type"] = "cpu_id_t"; j["cpu_index"] = v.cpu_index; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::cpu_id_t> Arbitrary<::FlexFlow::cpu_id_t>::arbitrary() { + return gen::construct<::FlexFlow::cpu_id_t>(gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc b/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc new file mode 100644 index 0000000000..d4b926c0a6 --- /dev/null +++ b/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml +/* proj-data +{ + "generated_from": "62224733c501773b41f1fc63a8677949" +} +*/ + +#include "pcg/dataflow_graph/operator_added_result.dtg.h" + +#include "utils/fmt/vector.h" +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +OperatorAddedResult::OperatorAddedResult( + ::FlexFlow::Node const &node, + std::vector<::FlexFlow::MultiDiOutput> const &outputs) + : node(node), outputs(outputs) {} +bool OperatorAddedResult::operator==(OperatorAddedResult const &other) const { + return std::tie(this->node, this->outputs) == + std::tie(other.node, other.outputs); +} +bool OperatorAddedResult::operator!=(OperatorAddedResult const &other) const { + return std::tie(this->node, this->outputs) != + std::tie(other.node, other.outputs); +} +bool OperatorAddedResult::operator<(OperatorAddedResult const &other) const { + return std::tie(this->node, this->outputs) < + std::tie(other.node, other.outputs); +} +bool OperatorAddedResult::operator>(OperatorAddedResult const &other) const { + return std::tie(this->node, this->outputs) > + std::tie(other.node, other.outputs); +} +bool OperatorAddedResult::operator<=(OperatorAddedResult const &other) const { + return std::tie(this->node, this->outputs) <= + std::tie(other.node, other.outputs); +} +bool OperatorAddedResult::operator>=(OperatorAddedResult const &other) const { + return std::tie(this->node, this->outputs) >= + std::tie(other.node, other.outputs); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(OperatorAddedResult const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAddedResult const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc index 713aa941d2..28a0a2e861 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc @@ -45,7 +45,7 @@ bool V1GraphEdge::operator>=(V1GraphEdge const &other) const { namespace std { size_t hash::operator()( - FlexFlow::V1GraphEdge const &x) const { + ::FlexFlow::V1GraphEdge const &x) const { size_t result = 0; result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -60,15 +60,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::V1GraphEdge - adl_serializer::from_json(json const &j) { - return {j.at("srcNode").template get(), - j.at("srcIdx").template get(), - j.at("dstNode").template get(), - j.at("dstIdx").template get()}; +::FlexFlow::V1GraphEdge + adl_serializer<::FlexFlow::V1GraphEdge>::from_json(json const &j) { + return ::FlexFlow::V1GraphEdge{j.at("srcNode").template get(), + j.at("srcIdx").template get(), + j.at("dstNode").template get(), + j.at("dstIdx").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::V1GraphEdge const &v) { +void adl_serializer<::FlexFlow::V1GraphEdge>::to_json( + json &j, ::FlexFlow::V1GraphEdge const &v) { j["__type"] = "V1GraphEdge"; j["srcNode"] = v.srcNode; j["srcIdx"] = v.srcIdx; diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc index fa0b792a37..f4e2ecf0e1 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc @@ -42,7 +42,7 @@ bool V1GraphOutput::operator>=(V1GraphOutput const &other) const { namespace std { size_t hash::operator()( - FlexFlow::V1GraphOutput const &x) const { + ::FlexFlow::V1GraphOutput const &x) const { size_t result = 0; result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -53,13 +53,13 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::V1GraphOutput - adl_serializer::from_json(json const &j) { - return {j.at("srcNode").template get(), - j.at("srcIdx").template get()}; +::FlexFlow::V1GraphOutput + adl_serializer<::FlexFlow::V1GraphOutput>::from_json(json const &j) { + return ::FlexFlow::V1GraphOutput{j.at("srcNode").template get(), + j.at("srcIdx").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::V1GraphOutput const &v) { +void adl_serializer<::FlexFlow::V1GraphOutput>::to_json( + json &j, ::FlexFlow::V1GraphOutput const &v) { j["__type"] = "V1GraphOutput"; j["srcNode"] = v.srcNode; j["srcIdx"] = v.srcIdx; diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc index 0f5a83b02f..41ad9e4e63 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc @@ -3,14 +3,15 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml /* proj-data { - "generated_from": "fb1033385645e54a19c9b44cef0be04b" + "generated_from": "582054edb983c3cc31d9273ce29421eb" } */ #include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" #include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include #include #include @@ -24,15 +25,16 @@ V1MultiDiGraph::V1MultiDiGraph( } // namespace FlexFlow namespace nlohmann { -FlexFlow::V1MultiDiGraph - adl_serializer::from_json(json const &j) { - return {j.at("nodes").template get>(), - j.at("ports").template get>(), - j.at("edges") - .template get>()}; +::FlexFlow::V1MultiDiGraph + adl_serializer<::FlexFlow::V1MultiDiGraph>::from_json(json const &j) { + return ::FlexFlow::V1MultiDiGraph{ + j.at("nodes").template get>(), + j.at("ports").template get>(), + j.at("edges") + .template get>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::V1MultiDiGraph const &v) { +void adl_serializer<::FlexFlow::V1MultiDiGraph>::to_json( + json &j, ::FlexFlow::V1MultiDiGraph const &v) { j["__type"] = "V1MultiDiGraph"; j["nodes"] = v.nodes; j["ports"] = v.ports; diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc index 19f1e09d07..4c908ae2f1 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc @@ -3,14 +3,15 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml /* proj-data { - "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" + "generated_from": "fed215ca219af1bd375801eb2e33b473" } */ #include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" #include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include #include #include @@ -23,14 +24,15 @@ V1OperatorGraph::V1OperatorGraph( } // namespace FlexFlow namespace nlohmann { -FlexFlow::V1OperatorGraph - adl_serializer::from_json(json const &j) { - return {j.at("nodes").template get>(), - j.at("edges") - .template get>()}; +::FlexFlow::V1OperatorGraph + adl_serializer<::FlexFlow::V1OperatorGraph>::from_json(json const &j) { + return ::FlexFlow::V1OperatorGraph{ + j.at("nodes").template get>(), + j.at("edges") + .template get>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::V1OperatorGraph const &v) { +void adl_serializer<::FlexFlow::V1OperatorGraph>::to_json( + json &j, ::FlexFlow::V1OperatorGraph const &v) { j["__type"] = "V1OperatorGraph"; j["nodes"] = v.nodes; j["edges"] = v.edges; diff --git a/lib/pcg/src/pcg/gpu_id_t.dtg.cc b/lib/pcg/src/pcg/gpu_id_t.dtg.cc index e2385a83ce..f82e5c747e 100644 --- a/lib/pcg/src/pcg/gpu_id_t.dtg.cc +++ b/lib/pcg/src/pcg/gpu_id_t.dtg.cc @@ -34,7 +34,8 @@ bool gpu_id_t::operator>=(gpu_id_t const &other) const { } // namespace FlexFlow namespace std { -size_t hash::operator()(FlexFlow::gpu_id_t const &x) const { +size_t + hash::operator()(::FlexFlow::gpu_id_t const &x) const { size_t result = 0; result ^= std::hash{}(x.gpu_index) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -43,20 +44,20 @@ size_t hash::operator()(FlexFlow::gpu_id_t const &x) const { } // namespace std namespace nlohmann { -FlexFlow::gpu_id_t - adl_serializer::from_json(json const &j) { - return {j.at("gpu_index").template get()}; +::FlexFlow::gpu_id_t + adl_serializer<::FlexFlow::gpu_id_t>::from_json(json const &j) { + return ::FlexFlow::gpu_id_t{j.at("gpu_index").template get()}; } -void adl_serializer::to_json(json &j, - FlexFlow::gpu_id_t const &v) { +void adl_serializer<::FlexFlow::gpu_id_t>::to_json( + json &j, ::FlexFlow::gpu_id_t const &v) { j["__type"] = "gpu_id_t"; j["gpu_index"] = v.gpu_index; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::gpu_id_t> Arbitrary<::FlexFlow::gpu_id_t>::arbitrary() { + return gen::construct<::FlexFlow::gpu_id_t>(gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc index 9770c35248..2848d420b7 100644 --- a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc @@ -45,7 +45,7 @@ bool ConstantInitializerAttrs::operator>=( namespace std { size_t hash::operator()( - FlexFlow::ConstantInitializerAttrs const &x) const { + ::FlexFlow::ConstantInitializerAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::DataTypeValue>{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -54,13 +54,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ConstantInitializerAttrs - adl_serializer::from_json( +::FlexFlow::ConstantInitializerAttrs + adl_serializer<::FlexFlow::ConstantInitializerAttrs>::from_json( json const &j) { - return {j.at("value").template get<::FlexFlow::DataTypeValue>()}; + return ::FlexFlow::ConstantInitializerAttrs{ + j.at("value").template get<::FlexFlow::DataTypeValue>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ConstantInitializerAttrs const &v) { +void adl_serializer<::FlexFlow::ConstantInitializerAttrs>::to_json( + json &j, ::FlexFlow::ConstantInitializerAttrs const &v) { j["__type"] = "ConstantInitializerAttrs"; j["value"] = v.value; } diff --git a/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc index 0c8ae6e60c..cf2164ed97 100644 --- a/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc @@ -35,7 +35,7 @@ bool GlorotUniformAttrs::operator>=(GlorotUniformAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::GlorotUniformAttrs const &x) const { + ::FlexFlow::GlorotUniformAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,21 +44,21 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::GlorotUniformAttrs - adl_serializer::from_json(json const &j) { - return {j.at("seed").template get()}; +::FlexFlow::GlorotUniformAttrs + adl_serializer<::FlexFlow::GlorotUniformAttrs>::from_json(json const &j) { + return ::FlexFlow::GlorotUniformAttrs{j.at("seed").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::GlorotUniformAttrs const &v) { +void adl_serializer<::FlexFlow::GlorotUniformAttrs>::to_json( + json &j, ::FlexFlow::GlorotUniformAttrs const &v) { j["__type"] = "GlorotUniformAttrs"; j["seed"] = v.seed; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::GlorotUniformAttrs> + Arbitrary<::FlexFlow::GlorotUniformAttrs>::arbitrary() { + return gen::construct<::FlexFlow::GlorotUniformAttrs>(gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc index aceac12212..5d8c2fa02b 100644 --- a/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc @@ -44,7 +44,7 @@ bool NormInitializerAttrs::operator>=(NormInitializerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::NormInitializerAttrs const &x) const { + ::FlexFlow::NormInitializerAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -57,14 +57,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::NormInitializerAttrs - adl_serializer::from_json(json const &j) { - return {j.at("seed").template get(), - j.at("mean").template get(), - j.at("stddev").template get()}; +::FlexFlow::NormInitializerAttrs + adl_serializer<::FlexFlow::NormInitializerAttrs>::from_json(json const &j) { + return ::FlexFlow::NormInitializerAttrs{j.at("seed").template get(), + j.at("mean").template get(), + j.at("stddev").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::NormInitializerAttrs const &v) { +void adl_serializer<::FlexFlow::NormInitializerAttrs>::to_json( + json &j, ::FlexFlow::NormInitializerAttrs const &v) { j["__type"] = "NormInitializerAttrs"; j["seed"] = v.seed; j["mean"] = v.mean; @@ -73,9 +73,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::NormInitializerAttrs> + Arbitrary<::FlexFlow::NormInitializerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::NormInitializerAttrs>( gen::arbitrary(), gen::arbitrary(), gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc index a9c62675d0..4eb3bdc015 100644 --- a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc @@ -50,7 +50,7 @@ bool UniformInitializerAttrs::operator>=( namespace std { size_t hash::operator()( - FlexFlow::UniformInitializerAttrs const &x) const { + ::FlexFlow::UniformInitializerAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -63,15 +63,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::UniformInitializerAttrs - adl_serializer::from_json( +::FlexFlow::UniformInitializerAttrs + adl_serializer<::FlexFlow::UniformInitializerAttrs>::from_json( json const &j) { - return {j.at("seed").template get(), - j.at("min_val").template get(), - j.at("max_val").template get()}; + return ::FlexFlow::UniformInitializerAttrs{ + j.at("seed").template get(), + j.at("min_val").template get(), + j.at("max_val").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::UniformInitializerAttrs const &v) { +void adl_serializer<::FlexFlow::UniformInitializerAttrs>::to_json( + json &j, ::FlexFlow::UniformInitializerAttrs const &v) { j["__type"] = "UniformInitializerAttrs"; j["seed"] = v.seed; j["min_val"] = v.min_val; diff --git a/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc index 933501a734..eb88f4c8ff 100644 --- a/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc @@ -34,27 +34,27 @@ bool ZeroInitializerAttrs::operator>=(ZeroInitializerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ZeroInitializerAttrs const &x) const { + ::FlexFlow::ZeroInitializerAttrs const &x) const { size_t result = 0; return result; } } // namespace std namespace nlohmann { -FlexFlow::ZeroInitializerAttrs - adl_serializer::from_json(json const &j) { - return {}; +::FlexFlow::ZeroInitializerAttrs + adl_serializer<::FlexFlow::ZeroInitializerAttrs>::from_json(json const &j) { + return ::FlexFlow::ZeroInitializerAttrs{}; } -void adl_serializer::to_json( - json &j, FlexFlow::ZeroInitializerAttrs const &v) { +void adl_serializer<::FlexFlow::ZeroInitializerAttrs>::to_json( + json &j, ::FlexFlow::ZeroInitializerAttrs const &v) { j["__type"] = "ZeroInitializerAttrs"; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(); +Gen<::FlexFlow::ZeroInitializerAttrs> + Arbitrary<::FlexFlow::ZeroInitializerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ZeroInitializerAttrs>(); } } // namespace rc diff --git a/lib/pcg/src/pcg/layer_attrs.dtg.cc b/lib/pcg/src/pcg/layer_attrs.dtg.cc index 21c53ad4e8..4497d849e6 100644 --- a/lib/pcg/src/pcg/layer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/layer_attrs.dtg.cc @@ -42,7 +42,7 @@ bool LayerAttrs::operator>=(LayerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::LayerAttrs const &x) const { + ::FlexFlow::LayerAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ComputationGraphOpAttrs>{}(x.attrs) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -54,15 +54,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::LayerAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::LayerAttrs + adl_serializer<::FlexFlow::LayerAttrs>::from_json(json const &j) { + return ::FlexFlow::LayerAttrs{ j.at("attrs").template get<::FlexFlow::ComputationGraphOpAttrs>(), j.at("name") .template get>>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::LayerAttrs const &v) { +void adl_serializer<::FlexFlow::LayerAttrs>::to_json( + json &j, ::FlexFlow::LayerAttrs const &v) { j["__type"] = "LayerAttrs"; j["attrs"] = v.attrs; j["name"] = v.name; diff --git a/lib/pcg/src/pcg/layer_guid_t.dtg.cc b/lib/pcg/src/pcg/layer_guid_t.dtg.cc index 9d92608569..706de4e376 100644 --- a/lib/pcg/src/pcg/layer_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/layer_guid_t.dtg.cc @@ -37,7 +37,7 @@ bool layer_guid_t::operator>=(layer_guid_t const &other) const { namespace std { size_t hash::operator()( - FlexFlow::layer_guid_t const &x) const { + ::FlexFlow::layer_guid_t const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/pcg/src/pcg/machine_specification.dtg.cc b/lib/pcg/src/pcg/machine_specification.dtg.cc index 238c61a014..f893b135bb 100644 --- a/lib/pcg/src/pcg/machine_specification.dtg.cc +++ b/lib/pcg/src/pcg/machine_specification.dtg.cc @@ -97,7 +97,7 @@ bool MachineSpecification::operator>=(MachineSpecification const &other) const { namespace std { size_t hash::operator()( - FlexFlow::MachineSpecification const &x) const { + ::FlexFlow::MachineSpecification const &x) const { size_t result = 0; result ^= std::hash{}(x.num_nodes) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -114,16 +114,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::MachineSpecification - adl_serializer::from_json(json const &j) { - return {j.at("num_nodes").template get(), - j.at("num_cpus_per_node").template get(), - j.at("num_gpus_per_node").template get(), - j.at("inter_node_bandwidth").template get(), - j.at("intra_node_bandwidth").template get()}; +::FlexFlow::MachineSpecification + adl_serializer<::FlexFlow::MachineSpecification>::from_json(json const &j) { + return ::FlexFlow::MachineSpecification{ + j.at("num_nodes").template get(), + j.at("num_cpus_per_node").template get(), + j.at("num_gpus_per_node").template get(), + j.at("inter_node_bandwidth").template get(), + j.at("intra_node_bandwidth").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::MachineSpecification const &v) { +void adl_serializer<::FlexFlow::MachineSpecification>::to_json( + json &j, ::FlexFlow::MachineSpecification const &v) { j["__type"] = "MachineSpecification"; j["num_nodes"] = v.num_nodes; j["num_cpus_per_node"] = v.num_cpus_per_node; diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index ff1d34852b..00bf1296fe 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -25,18 +25,19 @@ static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stride > 0); StridedRectangleSide side = strided_side_from_size_and_stride(side_size_t{stop - start}, stride); - StridedRectangle rect = {{side}}; + StridedRectangle rect = + StridedRectangle{std::vector{side}}; return rect; } MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { StridedRectangle rect = make_1d_rect(start.gpu_index, stop.gpu_index, stride); - return {device_id_t{start}, rect}; + return MachineView{device_id_t{start}, rect}; } MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { StridedRectangle rect = make_1d_rect(start.cpu_index, stop.cpu_index, stride); - return {device_id_t{start}, rect}; + return MachineView{device_id_t{start}, rect}; } MachineView make_1d_machine_view(device_id_t start, diff --git a/lib/pcg/src/pcg/machine_view.dtg.cc b/lib/pcg/src/pcg/machine_view.dtg.cc index edab125e3d..de577fe409 100644 --- a/lib/pcg/src/pcg/machine_view.dtg.cc +++ b/lib/pcg/src/pcg/machine_view.dtg.cc @@ -39,7 +39,7 @@ bool MachineView::operator>=(MachineView const &other) const { namespace std { size_t hash::operator()( - FlexFlow::MachineView const &x) const { + ::FlexFlow::MachineView const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::device_id_t>{}(x.start) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -50,13 +50,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::MachineView - adl_serializer::from_json(json const &j) { - return {j.at("start").template get<::FlexFlow::device_id_t>(), - j.at("rect").template get<::FlexFlow::StridedRectangle>()}; +::FlexFlow::MachineView + adl_serializer<::FlexFlow::MachineView>::from_json(json const &j) { + return ::FlexFlow::MachineView{ + j.at("start").template get<::FlexFlow::device_id_t>(), + j.at("rect").template get<::FlexFlow::StridedRectangle>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::MachineView const &v) { +void adl_serializer<::FlexFlow::MachineView>::to_json( + json &j, ::FlexFlow::MachineView const &v) { j["__type"] = "MachineView"; j["start"] = v.start; j["rect"] = v.rect; diff --git a/lib/pcg/src/pcg/num_points_t.dtg.cc b/lib/pcg/src/pcg/num_points_t.dtg.cc index 7a0a849814..e7c54dcfbe 100644 --- a/lib/pcg/src/pcg/num_points_t.dtg.cc +++ b/lib/pcg/src/pcg/num_points_t.dtg.cc @@ -35,7 +35,7 @@ bool num_points_t::operator>=(num_points_t const &other) const { namespace std { size_t hash::operator()( - FlexFlow::num_points_t const &x) const { + ::FlexFlow::num_points_t const &x) const { size_t result = 0; result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,20 +44,20 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::num_points_t - adl_serializer::from_json(json const &j) { - return {j.at("unwrapped").template get()}; +::FlexFlow::num_points_t + adl_serializer<::FlexFlow::num_points_t>::from_json(json const &j) { + return ::FlexFlow::num_points_t{j.at("unwrapped").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::num_points_t const &v) { +void adl_serializer<::FlexFlow::num_points_t>::to_json( + json &j, ::FlexFlow::num_points_t const &v) { j["__type"] = "num_points_t"; j["unwrapped"] = v.unwrapped; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::num_points_t> Arbitrary<::FlexFlow::num_points_t>::arbitrary() { + return gen::construct<::FlexFlow::num_points_t>(gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc index 381c948ad0..7d31197f9d 100644 --- a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc @@ -38,7 +38,7 @@ bool OperatorGraphInput::operator>=(OperatorGraphInput const &other) const { namespace std { size_t hash::operator()( - FlexFlow::OperatorGraphInput const &x) const { + ::FlexFlow::OperatorGraphInput const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc index 88c23c0c67..2b5a2abbcd 100644 --- a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc @@ -38,7 +38,7 @@ bool OperatorGraphOutput::operator>=(OperatorGraphOutput const &other) const { namespace std { size_t hash::operator()( - FlexFlow::OperatorGraphOutput const &x) const { + ::FlexFlow::OperatorGraphOutput const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc index d362459cc3..7ec6876c8b 100644 --- a/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc @@ -115,7 +115,7 @@ bool AdamOptimizerAttrs::operator>=(AdamOptimizerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::AdamOptimizerAttrs const &x) const { + ::FlexFlow::AdamOptimizerAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.alpha) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -136,18 +136,19 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::AdamOptimizerAttrs - adl_serializer::from_json(json const &j) { - return {j.at("alpha").template get(), - j.at("beta1").template get(), - j.at("beta2").template get(), - j.at("weight_decay").template get(), - j.at("alpha_t").template get(), - j.at("beta_t").template get(), - j.at("beta2_t").template get()}; +::FlexFlow::AdamOptimizerAttrs + adl_serializer<::FlexFlow::AdamOptimizerAttrs>::from_json(json const &j) { + return ::FlexFlow::AdamOptimizerAttrs{ + j.at("alpha").template get(), + j.at("beta1").template get(), + j.at("beta2").template get(), + j.at("weight_decay").template get(), + j.at("alpha_t").template get(), + j.at("beta_t").template get(), + j.at("beta2_t").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::AdamOptimizerAttrs const &v) { +void adl_serializer<::FlexFlow::AdamOptimizerAttrs>::to_json( + json &j, ::FlexFlow::AdamOptimizerAttrs const &v) { j["__type"] = "AdamOptimizerAttrs"; j["alpha"] = v.alpha; j["beta1"] = v.beta1; @@ -160,15 +161,16 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary()); +Gen<::FlexFlow::AdamOptimizerAttrs> + Arbitrary<::FlexFlow::AdamOptimizerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::AdamOptimizerAttrs>( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc index d5e668917b..de1c5a4e6b 100644 --- a/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc @@ -52,7 +52,7 @@ bool SGDOptimizerAttrs::operator>=(SGDOptimizerAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::SGDOptimizerAttrs const &x) const { + ::FlexFlow::SGDOptimizerAttrs const &x) const { size_t result = 0; result ^= std::hash{}(x.lr) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -67,15 +67,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::SGDOptimizerAttrs - adl_serializer::from_json(json const &j) { - return {j.at("lr").template get(), - j.at("momentum").template get(), - j.at("nesterov").template get(), - j.at("weight_decay").template get()}; +::FlexFlow::SGDOptimizerAttrs + adl_serializer<::FlexFlow::SGDOptimizerAttrs>::from_json(json const &j) { + return ::FlexFlow::SGDOptimizerAttrs{ + j.at("lr").template get(), + j.at("momentum").template get(), + j.at("nesterov").template get(), + j.at("weight_decay").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::SGDOptimizerAttrs const &v) { +void adl_serializer<::FlexFlow::SGDOptimizerAttrs>::to_json( + json &j, ::FlexFlow::SGDOptimizerAttrs const &v) { j["__type"] = "SGDOptimizerAttrs"; j["lr"] = v.lr; j["momentum"] = v.momentum; @@ -85,12 +86,13 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary(), - gen::arbitrary()); +Gen<::FlexFlow::SGDOptimizerAttrs> + Arbitrary<::FlexFlow::SGDOptimizerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::SGDOptimizerAttrs>( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph.cc deleted file mode 100644 index c5557488b8..0000000000 --- a/lib/pcg/src/pcg/parallel_computation_graph.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "pcg/parallel_computation_graph.h" -#include "utils/containers.h" - -namespace FlexFlow { - -ParallelComputationGraph empty_parallel_computation_graph() { - return ParallelComputationGraph{DataflowGraph{}}; -} - -std::unordered_set get_parallel_layers(ParallelComputationGraph const &pcg) { - return transform(get_nodes(pcg.raw_graph), - [&](Node const &n) { return parallel_layer_guid_t{n}; }); -} - -ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, - parallel_tensor_guid_t const &t) { - return pcg.raw_graph.at(t.raw_graph_output); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc new file mode 100644 index 0000000000..491ac67708 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -0,0 +1,49 @@ +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers.h" + +namespace FlexFlow { + +ParallelComputationGraph empty_parallel_computation_graph() { + return ParallelComputationGraph{ + DataflowGraph{}}; +} + +std::unordered_set + get_parallel_layers(ParallelComputationGraph const &pcg) { + return transform(get_nodes(pcg.raw_graph), + [&](Node const &n) { return parallel_layer_guid_t{n}; }); +} + +std::vector + get_layer_inputs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return transform( + get_inputs(pcg.raw_graph, l.raw_graph_node), + [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); +} + +std::vector + get_layer_outputs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return transform( + get_outputs(pcg.raw_graph, l.raw_graph_node), + [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); +} + +parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, + parallel_tensor_guid_t const &t) { + return parallel_layer_guid_t{t.raw_graph_output.src}; +} + +ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return pcg.raw_graph.at(l.raw_graph_node); +} + +ParallelTensorAttrs + get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + return pcg.raw_graph.at(t.raw_graph_output); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc similarity index 50% rename from lib/pcg/src/pcg/parallel_computation_graph.dtg.cc rename to lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc index e4e1555b4a..cdc9130979 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc @@ -1,17 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" + "generated_from": "1339be6e86e9818c36d6ecf5475e2d4b" } */ -#include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/dataflow_graph.h" -#include "pcg/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/dataflow_graph/dataflow_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" namespace FlexFlow { ParallelComputationGraph::ParallelComputationGraph( diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc new file mode 100644 index 0000000000..90bc327a9a --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -0,0 +1,302 @@ +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers.h" +#include "utils/containers/concat_vectors.h" + +namespace FlexFlow { + +static std::string get_default_name(OperatorType op_type) { + return get_operator_type_name(op_type); +} + +static std::string get_default_name(PCGOperatorAttrs const &attrs) { + return get_default_name(get_op_type(attrs)); +} + +static ParallelTensorAttrs make_weight_attrs( + ParallelTensorShape const &shape, + std::optional const &initializer_attrs) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/initializer_attrs, + /*create_gradients=*/CreateGrad::YES, + }; +} + +ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() + : pcg(empty_parallel_computation_graph()) {} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor( + ParallelTensorShape const &shape, + bool create_grad, + std::optional const &name) { + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/(create_grad ? CreateGrad::YES : CreateGrad::NO), + }; + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{InputAttrs{}}, + name, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::add( + parallel_tensor_guid_t const &lhs, + parallel_tensor_guid_t const &rhs, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_matmul( + parallel_tensor_guid_t const &a, + parallel_tensor_guid_t const &b, + /* int a_seq_length_dim = -1, */ + /* int b_seq_length_dim = -1, */ + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::cast( + parallel_tensor_guid_t const &input, + DataType result_type, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::conv2d( + parallel_tensor_guid_t const &raw_input, + int outChannels, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + std::optional const &activation, + int groups, + bool use_bias, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, + std::optional const &kernel_regularizer, + std::optional const &maybe_name) { + Conv2DAttrs attrs = Conv2DAttrs{outChannels, + kernelH, + kernelW, + strideH, + strideW, + paddingH, + paddingW, + groups, + activation, + use_bias}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + parallel_tensor_guid_t input = + this->as_type(raw_input, DataType::FLOAT, name + "input_pre_cast"); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + + std::vector weights; + + weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), + kernel_initializer)); + + if (use_bias) { + weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), + bias_initializer)); + } + + return this->add_layer(layer, {input}, weights, output_shape); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( + parallel_tensor_guid_t const &input, + int outDim, + std::optional activation, + bool use_bias, + DataType data_type, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( + parallel_tensor_guid_t const &input, + int num_entries, + int outDim, + AggregateOp aggr, + DataType dtype, + std::optional const &kernel_initializer, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( + parallel_tensor_guid_t const &query, + parallel_tensor_guid_t const &key, + parallel_tensor_guid_t const &value, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool bias, + bool add_bias_kv, + bool add_zero_attn, + std::optional initializer, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( + parallel_tensor_guid_t const &input, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( + parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_combine( + parallel_tensor_guid_t const &x, + ff_dim_t dim, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_replicate( + parallel_tensor_guid_t const &x, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_reduce( + parallel_tensor_guid_t const &x, + int degree, + std::optional const &name) { + NOT_IMPLEMENTED(); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::as_type( + parallel_tensor_guid_t const &input, + DataType goal_datatype, + std::string const &name) { + DataType input_datatype = this->get_shape(input).data_type; + if (input_datatype == goal_datatype) { + return input; + } else if (can_strictly_promote_datatype_from_to(input_datatype, + goal_datatype)) { + return this->cast(input, goal_datatype, name); + } else { + throw mk_runtime_error( + fmt::format("Could not convert provided tensor data type {} to " + "desired data type {}", + input_datatype, + goal_datatype)); + } +} + +ParallelTensorShape ParallelComputationGraphBuilder::get_shape( + parallel_tensor_guid_t const &t) const { + return get_parallel_tensor_attrs(this->pcg, t).shape; +} + +std::vector ParallelComputationGraphBuilder::add_layer( + ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + std::vector raw_weight_tensors; + for (auto const &kv : enumerate_vector(weights)) { + int weight_idx = kv.first; + ParallelTensorAttrs weight_tensor_attrs = kv.second; + + std::optional weight_name = + transform(layer.name, [&](std::string const &layer_name) { + return fmt::format("{}.weights[{}]", layer_name, weight_idx); + }); + ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{}}, + weight_name, + }; + std::vector weight_layer_inputs = {}; + std::vector weight_output_attrs = { + weight_tensor_attrs}; + raw_weight_tensors.push_back(get_only(this->pcg.raw_graph + .add_operator(weight_layer_attrs, + weight_layer_inputs, + weight_output_attrs) + .outputs)); + } + + std::vector raw_inputs = + transform(inputs, [](parallel_tensor_guid_t const &t) { + return t.raw_graph_output; + }); + std::vector raw_outputs = + this->pcg.raw_graph + .add_operator( + layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) + .outputs; + return transform(raw_outputs, [](MultiDiOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + +std::vector ParallelComputationGraphBuilder::add_layer( + ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + return this->add_layer(layer, + inputs, + weights, + transform(outputs, [](ParallelTensorShape const &s) { + return ParallelTensorAttrs{ + /*shape=*/s, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + })); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::add_layer( + ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorAttrs const &output) { + std::vector outputs = {output}; + return get_only(this->add_layer(layer, inputs, weights, outputs)); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::add_layer( + ParallelLayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + ParallelTensorShape const &output) { + std::vector outputs = {output}; + return get_only(this->add_layer(layer, inputs, weights, outputs)); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc new file mode 100644 index 0000000000..5995e4ee01 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.cc @@ -0,0 +1,10 @@ +#include "pcg/parallel_computation_graph/parallel_layer_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" + +namespace FlexFlow { + +OperatorType get_op_type(ParallelLayerAttrs const &a) { + return get_op_type(a.op_attrs); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc similarity index 55% rename from lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc rename to lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc index 455fb22baf..a16998c698 100644 --- a/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc @@ -1,13 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml /* proj-data { - "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" + "generated_from": "9bb6e3cb7b0e523fae8f33bd8ad80d6d" } */ -#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "op-attrs/operator_attrs.h" #include "utils/stack_string.h" @@ -16,34 +16,40 @@ namespace FlexFlow { ParallelLayerAttrs::ParallelLayerAttrs( - ::FlexFlow::PCGOperatorAttrs const &attrs, + ::FlexFlow::PCGOperatorAttrs const &op_attrs, std::optional<::FlexFlow::stack_string> const &name) - : attrs(attrs), name(name) {} + : op_attrs(op_attrs), name(name) {} bool ParallelLayerAttrs::operator==(ParallelLayerAttrs const &other) const { - return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); + return std::tie(this->op_attrs, this->name) == + std::tie(other.op_attrs, other.name); } bool ParallelLayerAttrs::operator!=(ParallelLayerAttrs const &other) const { - return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); + return std::tie(this->op_attrs, this->name) != + std::tie(other.op_attrs, other.name); } bool ParallelLayerAttrs::operator<(ParallelLayerAttrs const &other) const { - return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); + return std::tie(this->op_attrs, this->name) < + std::tie(other.op_attrs, other.name); } bool ParallelLayerAttrs::operator>(ParallelLayerAttrs const &other) const { - return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); + return std::tie(this->op_attrs, this->name) > + std::tie(other.op_attrs, other.name); } bool ParallelLayerAttrs::operator<=(ParallelLayerAttrs const &other) const { - return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); + return std::tie(this->op_attrs, this->name) <= + std::tie(other.op_attrs, other.name); } bool ParallelLayerAttrs::operator>=(ParallelLayerAttrs const &other) const { - return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); + return std::tie(this->op_attrs, this->name) >= + std::tie(other.op_attrs, other.name); } } // namespace FlexFlow namespace std { size_t hash::operator()( - FlexFlow::ParallelLayerAttrs const &x) const { + ::FlexFlow::ParallelLayerAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::PCGOperatorAttrs>{}(x.attrs) + 0x9e3779b9 + + result ^= std::hash<::FlexFlow::PCGOperatorAttrs>{}(x.op_attrs) + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash>>{}(x.name) + @@ -53,17 +59,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ParallelLayerAttrs - adl_serializer::from_json(json const &j) { - return { - j.at("attrs").template get<::FlexFlow::PCGOperatorAttrs>(), +::FlexFlow::ParallelLayerAttrs + adl_serializer<::FlexFlow::ParallelLayerAttrs>::from_json(json const &j) { + return ::FlexFlow::ParallelLayerAttrs{ + j.at("op_attrs").template get<::FlexFlow::PCGOperatorAttrs>(), j.at("name") .template get>>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ParallelLayerAttrs const &v) { +void adl_serializer<::FlexFlow::ParallelLayerAttrs>::to_json( + json &j, ::FlexFlow::ParallelLayerAttrs const &v) { j["__type"] = "ParallelLayerAttrs"; - j["attrs"] = v.attrs; + j["op_attrs"] = v.op_attrs; j["name"] = v.name; } } // namespace nlohmann @@ -72,7 +78,7 @@ namespace FlexFlow { std::string format_as(ParallelLayerAttrs const &x) { std::ostringstream oss; oss << ""; return oss.str(); diff --git a/lib/pcg/src/pcg/parallel_layer_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc similarity index 90% rename from lib/pcg/src/pcg/parallel_layer_guid_t.dtg.cc rename to lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc index 876a735b14..df575ebc98 100644 --- a/lib/pcg/src/pcg/parallel_layer_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc @@ -1,13 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_layer_guid_t.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml /* proj-data { "generated_from": "c31301efeb92e151b04943786aa7bec1" } */ -#include "pcg/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "utils/graph.h" #include @@ -44,7 +44,7 @@ bool parallel_layer_guid_t::operator>=( namespace std { size_t hash::operator()( - FlexFlow::parallel_layer_guid_t const &x) const { + ::FlexFlow::parallel_layer_guid_t const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.raw_graph_node) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc similarity index 91% rename from lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc rename to lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc index ae5d618172..13be5e839f 100644 --- a/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc @@ -1,13 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml /* proj-data { "generated_from": "b3e086b380bbc41d99332e1463a34b28" } */ -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/param_sync.dtg.h" @@ -82,7 +82,7 @@ bool ParallelTensorAttrs::operator>=(ParallelTensorAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ParallelTensorAttrs const &x) const { + ::FlexFlow::ParallelTensorAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.shape) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -98,17 +98,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::ParallelTensorAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::ParallelTensorAttrs + adl_serializer<::FlexFlow::ParallelTensorAttrs>::from_json(json const &j) { + return ::FlexFlow::ParallelTensorAttrs{ j.at("shape").template get<::FlexFlow::ParallelTensorShape>(), j.at("sync_type").template get>(), j.at("initializer") .template get>(), j.at("create_gradients").template get<::FlexFlow::CreateGrad>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::ParallelTensorAttrs const &v) { +void adl_serializer<::FlexFlow::ParallelTensorAttrs>::to_json( + json &j, ::FlexFlow::ParallelTensorAttrs const &v) { j["__type"] = "ParallelTensorAttrs"; j["shape"] = v.shape; j["sync_type"] = v.sync_type; diff --git a/lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc similarity index 90% rename from lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc rename to lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc index b64cf2901f..38c2970225 100644 --- a/lib/pcg/src/pcg/parallel_tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc @@ -1,13 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/parallel_tensor_guid_t.struct.toml +// lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml /* proj-data { "generated_from": "de2c2d33bfa5cd72f0e51954d6879f38" } */ -#include "pcg/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "utils/graph/multidiedge.h" #include @@ -44,7 +44,7 @@ bool parallel_tensor_guid_t::operator>=( namespace std { size_t hash::operator()( - FlexFlow::parallel_tensor_guid_t const &x) const { + ::FlexFlow::parallel_tensor_guid_t const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/pcg/src/pcg/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph_builder.cc deleted file mode 100644 index 81f479e8f6..0000000000 --- a/lib/pcg/src/pcg/parallel_computation_graph_builder.cc +++ /dev/null @@ -1,293 +0,0 @@ -#include "pcg/parallel_computation_graph_builder.h" -#include "op-attrs/ops/weight_attrs.dtg.h" -#include "op-attrs/pcg_operator_attrs.h" -#include "pcg/parallel_computation_graph.h" -#include "utils/containers/concat_vectors.h" -#include "utils/containers.h" - -namespace FlexFlow { - -static std::string get_default_name(OperatorType op_type) { - return get_operator_type_name(op_type); -} - -static std::string get_default_name(PCGOperatorAttrs const &attrs) { - return get_default_name(get_op_type(attrs)); -} - -static ParallelTensorAttrs make_weight_attrs( - ParallelTensorShape const &shape, - std::optional const &initializer_attrs) { - return ParallelTensorAttrs{ - /*shape=*/shape, - /*sync_type=*/std::nullopt, - /*initializer=*/initializer_attrs, - /*create_gradients=*/CreateGrad::YES, - }; -} - - -ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() - : pcg(empty_parallel_computation_graph()) { } - -parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor(ParallelTensorShape const &shape, - bool create_grad, - std::optional const &name) { - ParallelTensorAttrs tensor_attrs = { - /*shape=*/shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/(create_grad ? CreateGrad::YES : CreateGrad::NO), - }; - ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ - PCGOperatorAttrs{InputAttrs{}}, - name, - }; - - return this->add_layer(layer_attrs, {}, {}, tensor_attrs); -} - -parallel_tensor_guid_t ParallelComputationGraphBuilder::add(parallel_tensor_guid_t const &lhs, - parallel_tensor_guid_t const &rhs, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::batch_matmul(parallel_tensor_guid_t const &a, - parallel_tensor_guid_t const &b, - /* int a_seq_length_dim = -1, */ - /* int b_seq_length_dim = -1, */ - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::cast(parallel_tensor_guid_t const &input, - DataType result_type, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::conv2d(parallel_tensor_guid_t const &raw_input, - int outChannels, - int kernelH, - int kernelW, - int strideH, - int strideW, - int paddingH, - int paddingW, - std::optional const &activation, - int groups, - bool use_bias, - std::optional const &kernel_initializer, - std::optional const &bias_initializer, - std::optional const &kernel_regularizer, - std::optional const &maybe_name) { - Conv2DAttrs attrs = {outChannels, - kernelH, - kernelW, - strideH, - strideW, - paddingH, - paddingW, - groups, - activation, - use_bias}; - - std::string name = - maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); - - parallel_tensor_guid_t input = - this->as_type(raw_input, DataType::FLOAT, name + "input_pre_cast"); - - ParallelLayerAttrs layer = {PCGOperatorAttrs{attrs}, name}; - - ParallelTensorShape input_shape = this->get_shape(input); - ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); - - std::vector weights; - - weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), - kernel_initializer)); - - if (use_bias) { - weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), - bias_initializer)); - } - - return this->add_layer(layer, {input}, weights, output_shape); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::dense(parallel_tensor_guid_t const &input, - int outDim, - std::optional activation, - bool use_bias, - DataType data_type, - std::optional const &kernel_initializer, - std::optional const &bias_initializer, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::embedding( - parallel_tensor_guid_t const &input, - int num_entries, - int outDim, - AggregateOp aggr, - DataType dtype, - std::optional const &kernel_initializer, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::multihead_attention( - parallel_tensor_guid_t const &query, - parallel_tensor_guid_t const &key, - parallel_tensor_guid_t const &value, - int embed_dim, - int num_heads, - int kdim, - int vdim, - float dropout, - bool bias, - bool add_bias_kv, - bool add_zero_attn, - std::optional initializer, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::relu(parallel_tensor_guid_t const &input, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::parallel_partition(parallel_tensor_guid_t const &x, - ff_dim_t dim, - int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::parallel_combine(parallel_tensor_guid_t const &x, - ff_dim_t dim, - int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::parallel_replicate(parallel_tensor_guid_t const &x, - int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::parallel_reduce(parallel_tensor_guid_t const &x, - int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::as_type(parallel_tensor_guid_t const &input, DataType goal_datatype, std::string const &name) { - DataType input_datatype = this->get_shape(input).data_type; - if (input_datatype == goal_datatype) { - return input; - } else if (can_strictly_promote_datatype_from_to(input_datatype, goal_datatype)) { - return this->cast(input, goal_datatype, name); - } else { - throw mk_runtime_error( - fmt::format("Could not convert provided tensor data type {} to " - "desired data type {}", - input_datatype, - goal_datatype)); - } -} - -ParallelTensorShape -ParallelComputationGraphBuilder::get_shape(parallel_tensor_guid_t const &t) const { - return get_parallel_tensor_attrs(this->pcg, t).shape; -} - -std::vector -ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - std::vector raw_weight_tensors; - for (auto const &kv : enumerate_vector(weights)) { - int weight_idx = kv.first; - ParallelTensorAttrs weight_tensor_attrs = kv.second; - - std::optional weight_name = - transform(layer.name, [&](std::string const &layer_name) { - return fmt::format("{}.weights[{}]", layer_name, weight_idx); - }); - ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ - PCGOperatorAttrs{WeightAttrs{}}, - weight_name, - }; - std::vector weight_layer_inputs = {}; - std::vector weight_output_attrs = {weight_tensor_attrs}; - raw_weight_tensors.push_back( - get_only(this->pcg.raw_graph.add_operator( - weight_layer_attrs, weight_layer_inputs, weight_output_attrs))); - } - - std::vector raw_inputs = transform( - inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = - this->pcg.raw_graph.add_operator( - layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); - return transform(raw_outputs, - [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); -} - -std::vector -ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - return this->add_layer( - layer, inputs, weights, transform(outputs, [](ParallelTensorShape const &s) { - return ParallelTensorAttrs{ - /*shape=*/s, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - /*create_gradients=*/CreateGrad::YES, - }; - })); -} - - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - ParallelTensorAttrs const &output) { - std::vector outputs = {output}; - return get_only(this->add_layer(layer, inputs, weights, outputs)); -} - -parallel_tensor_guid_t -ParallelComputationGraphBuilder::add_layer(ParallelLayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - ParallelTensorShape const &output) { - std::vector outputs = {output}; - return get_only(this->add_layer(layer, inputs, weights, outputs)); -} - - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/side_size_t.dtg.cc b/lib/pcg/src/pcg/side_size_t.dtg.cc index 54db2974fe..0d13091cc8 100644 --- a/lib/pcg/src/pcg/side_size_t.dtg.cc +++ b/lib/pcg/src/pcg/side_size_t.dtg.cc @@ -35,7 +35,7 @@ bool side_size_t::operator>=(side_size_t const &other) const { namespace std { size_t hash::operator()( - FlexFlow::side_size_t const &x) const { + ::FlexFlow::side_size_t const &x) const { size_t result = 0; result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -44,20 +44,20 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::side_size_t - adl_serializer::from_json(json const &j) { - return {j.at("unwrapped").template get()}; +::FlexFlow::side_size_t + adl_serializer<::FlexFlow::side_size_t>::from_json(json const &j) { + return ::FlexFlow::side_size_t{j.at("unwrapped").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::side_size_t const &v) { +void adl_serializer<::FlexFlow::side_size_t>::to_json( + json &j, ::FlexFlow::side_size_t const &v) { j["__type"] = "side_size_t"; j["unwrapped"] = v.unwrapped; } } // namespace nlohmann namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary()); +Gen<::FlexFlow::side_size_t> Arbitrary<::FlexFlow::side_size_t>::arbitrary() { + return gen::construct<::FlexFlow::side_size_t>(gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc index e743a2722a..d50c5861ea 100644 --- a/lib/pcg/src/pcg/strided_rectangle.dtg.cc +++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc @@ -39,7 +39,7 @@ bool StridedRectangle::operator>=(StridedRectangle const &other) const { namespace std { size_t hash::operator()( - FlexFlow::StridedRectangle const &x) const { + ::FlexFlow::StridedRectangle const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>{}( @@ -50,23 +50,24 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::StridedRectangle - adl_serializer::from_json(json const &j) { - return {j.at("sides") - .template get< - ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()}; +::FlexFlow::StridedRectangle + adl_serializer<::FlexFlow::StridedRectangle>::from_json(json const &j) { + return ::FlexFlow::StridedRectangle{ + j.at("sides") + .template get< + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::StridedRectangle const &v) { +void adl_serializer<::FlexFlow::StridedRectangle>::to_json( + json &j, ::FlexFlow::StridedRectangle const &v) { j["__type"] = "StridedRectangle"; j["sides"] = v.sides; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::StridedRectangle> + Arbitrary<::FlexFlow::StridedRectangle>::arbitrary() { + return gen::construct<::FlexFlow::StridedRectangle>( gen::arbitrary< ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()); } diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc index 80258886d7..5e7274141d 100644 --- a/lib/pcg/src/pcg/strided_rectangle_side.cc +++ b/lib/pcg/src/pcg/strided_rectangle_side.cc @@ -9,7 +9,7 @@ StridedRectangleSide strided_side_from_size_and_stride(side_size_t, } side_size_t get_side_size(StridedRectangleSide const &s) { - return s.num_points.unwrapped * s.stride; + return side_size_t{s.num_points.unwrapped * s.stride}; } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc index 0bb31b0496..e2533f7a21 100644 --- a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc +++ b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc @@ -44,7 +44,7 @@ bool StridedRectangleSide::operator>=(StridedRectangleSide const &other) const { namespace std { size_t hash::operator()( - FlexFlow::StridedRectangleSide const &x) const { + ::FlexFlow::StridedRectangleSide const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::num_points_t>{}(x.num_points) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -55,13 +55,14 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::StridedRectangleSide - adl_serializer::from_json(json const &j) { - return {j.at("num_points").template get<::FlexFlow::num_points_t>(), - j.at("stride").template get()}; +::FlexFlow::StridedRectangleSide + adl_serializer<::FlexFlow::StridedRectangleSide>::from_json(json const &j) { + return ::FlexFlow::StridedRectangleSide{ + j.at("num_points").template get<::FlexFlow::num_points_t>(), + j.at("stride").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::StridedRectangleSide const &v) { +void adl_serializer<::FlexFlow::StridedRectangleSide>::to_json( + json &j, ::FlexFlow::StridedRectangleSide const &v) { j["__type"] = "StridedRectangleSide"; j["num_points"] = v.num_points; j["stride"] = v.stride; @@ -69,9 +70,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::StridedRectangleSide> + Arbitrary<::FlexFlow::StridedRectangleSide>::arbitrary() { + return gen::construct<::FlexFlow::StridedRectangleSide>( gen::arbitrary<::FlexFlow::num_points_t>(), gen::arbitrary()); } } // namespace rc diff --git a/lib/pcg/src/pcg/tensor_attrs.dtg.cc b/lib/pcg/src/pcg/tensor_attrs.dtg.cc index 46a6fb8d50..e75fe506f6 100644 --- a/lib/pcg/src/pcg/tensor_attrs.dtg.cc +++ b/lib/pcg/src/pcg/tensor_attrs.dtg.cc @@ -81,7 +81,7 @@ bool TensorAttrs::operator>=(TensorAttrs const &other) const { namespace std { size_t hash::operator()( - FlexFlow::TensorAttrs const &x) const { + ::FlexFlow::TensorAttrs const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -97,17 +97,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorAttrs - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::TensorAttrs + adl_serializer<::FlexFlow::TensorAttrs>::from_json(json const &j) { + return ::FlexFlow::TensorAttrs{ j.at("shape").template get<::FlexFlow::TensorShape>(), j.at("initializer") .template get>(), j.at("create_gradients").template get(), j.at("sync_type").template get>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorAttrs const &v) { +void adl_serializer<::FlexFlow::TensorAttrs>::to_json( + json &j, ::FlexFlow::TensorAttrs const &v) { j["__type"] = "TensorAttrs"; j["shape"] = v.shape; j["initializer"] = v.initializer; diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc index 779018296d..c8fbb7299b 100644 --- a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc @@ -37,7 +37,7 @@ bool tensor_guid_t::operator>=(tensor_guid_t const &other) const { namespace std { size_t hash::operator()( - FlexFlow::tensor_guid_t const &x) const { + ::FlexFlow::tensor_guid_t const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/pcg/test/src/pcg/dataflow_graph.cc b/lib/pcg/test/src/pcg/dataflow_graph.cc new file mode 100644 index 0000000000..0b4b31512b --- /dev/null +++ b/lib/pcg/test/src/pcg/dataflow_graph.cc @@ -0,0 +1,48 @@ +#include "pcg/dataflow_graph/dataflow_graph.h" +#include "test/utils/doctest.h" +#include "utils/fmt/unordered_set.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DataflowGraph") { + DataflowGraph g; + + int n1_label = 1; + int n2_label = 2; + int n3_label = 3; + int n4_label = 4; + + std::string o1_label = "o1"; + std::string o2_label = "o2"; + std::string o3_label = "o3"; + std::string o4_label = "o4"; + + OperatorAddedResult n1_added = g.add_operator(n1_label, {}, {o1_label}); + Node n1 = n1_added.node; + MultiDiOutput o1 = get_only(n1_added.outputs); + + OperatorAddedResult n2_added = g.add_operator(n2_label, {}, {o2_label}); + Node n2 = n2_added.node; + MultiDiOutput o2 = get_only(n2_added.outputs); + + OperatorAddedResult n3_added = g.add_operator(n3_label, {}, {o3_label}); + Node n3 = n3_added.node; + MultiDiOutput o3 = get_only(n3_added.outputs); + + OperatorAddedResult n4_added = + g.add_operator(n4_label, {o1, o2, o3}, {o4_label}); + Node n4 = n4_added.node; + MultiDiOutput o4 = get_only(n4_added.outputs); + + SUBCASE("get_inputs") { + std::vector result = get_inputs(g, n4); + std::vector correct = {o1, o2, o3}; + CHECK(result == correct); + } + + SUBCASE("get_outputs") { + std::vector result = get_outputs(g, n4); + std::vector correct = {o4}; + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc new file mode 100644 index 0000000000..0eaf78966f --- /dev/null +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -0,0 +1,125 @@ +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.h" +#include "test/utils/doctest.h" +#include "utils/containers.h" +#include "utils/containers/without_nullopts.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ParallelComputationGraphBuilder") { + ParallelComputationGraphBuilder b; + + size_t batch_size = 2; + + TensorShape unpar_input_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, 3, 10, 10}}, + DataType::FLOAT, + }; + + ParallelTensorShape input_shape = + lift_to_parallel_with_degrees(unpar_input_shape, + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 1}); + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + + int outChannels = 6; + int kernelH = 5; + int kernelW = 4; + int strideH = 3; + int strideW = 2; + int paddingH = 1; + int paddingW = 0; + parallel_tensor_guid_t output = b.conv2d(input, + /*outChannels=*/outChannels, + /*kernelH=*/kernelH, + /*kernelW=*/kernelW, + /*strideH=*/strideH, + /*strideW=*/strideW, + /*paddingH=*/paddingH, + /*paddingW=*/paddingW); + + std::unordered_map layers = + generate_map(get_parallel_layers(b.pcg), + [&](parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(b.pcg, l); + }); + CHECK_MESSAGE(layers.size() == 4, "Incorrect layers ", layers); + + auto num_attrs_of_type = [&](OperatorType op_type) -> int { + return count(values(layers), [&](ParallelLayerAttrs const &l) { + return get_op_type(l) == op_type; + }); + }; + + int num_weight_attrs = num_attrs_of_type(OperatorType::WEIGHT); + CHECK(num_weight_attrs == 2); + + int num_input_attrs = num_attrs_of_type(OperatorType::INPUT); + CHECK(num_input_attrs == 1); + + int num_conv_attrs = num_attrs_of_type(OperatorType::CONV2D); + CHECK(num_conv_attrs == 1); + + parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( + as_vector(items(layers)), + [](std::pair const &kv) + -> std::optional { + if (get_op_type(kv.second) == OperatorType::CONV2D) { + return kv.first; + } else { + return std::nullopt; + } + }))); + Conv2DAttrs conv_attrs = layers.at(conv_guid).op_attrs.get(); + Conv2DAttrs correct_attrs = Conv2DAttrs{ + outChannels, + kernelH, + kernelW, + strideH, + strideW, + paddingH, + paddingW, + /*groups=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/true, + }; + CHECK(conv_attrs == correct_attrs); + + ParallelTensorShape correct_output_shape = + get_output_shape(correct_attrs, input_shape); + ParallelTensorShape correct_kernel_shape = + get_kernel_shape(correct_attrs, input_shape); + ParallelTensorShape correct_bias_shape = + get_bias_shape(correct_attrs, input_shape); + + std::vector conv_inputs = + get_layer_inputs(b.pcg, conv_guid); + + parallel_tensor_guid_t conv_input = conv_inputs.at(0); + ParallelTensorShape conv_input_shape = + get_parallel_tensor_attrs(b.pcg, conv_input).shape; + CHECK(conv_input_shape == input_shape); + + parallel_tensor_guid_t conv_kernel = conv_inputs.at(1); + ParallelTensorShape conv_kernel_shape = + get_parallel_tensor_attrs(b.pcg, conv_kernel).shape; + CHECK(conv_kernel_shape == correct_kernel_shape); + + parallel_tensor_guid_t conv_bias = conv_inputs.at(2); + ParallelTensorShape conv_bias_shape = + get_parallel_tensor_attrs(b.pcg, conv_bias).shape; + CHECK(conv_bias_shape == correct_bias_shape); + + std::vector conv_outputs = + get_layer_outputs(b.pcg, conv_guid); + CHECK(conv_outputs.size() == 1); + + parallel_tensor_guid_t conv_output = get_only(conv_outputs); + ParallelTensorShape conv_output_shape = + get_parallel_tensor_attrs(b.pcg, conv_output).shape; + CHECK(conv_output_shape == correct_output_shape); + }; +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc deleted file mode 100644 index b11cb504e2..0000000000 --- a/lib/pcg/test/src/pcg/parallel_computation_graph_builder.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "test/utils/doctest.h" -#include "pcg/parallel_computation_graph_builder.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/parallel_computation_graph.h" - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("ParallelComputationGraphBuilder") { - ParallelComputationGraphBuilder b; - - size_t batch_size = 2; - - TensorShape unpar_input_shape = { - TensorDims{FFOrdered{batch_size, 3, 10, 10}}, - DataType::FLOAT, - }; - - ParallelTensorShape input_shape = lift_to_parallel_with_degrees(unpar_input_shape, SumDegree{1}, DiscardCopyDegree{1}, FFOrdered{2, 1, 1, 1}); - - parallel_tensor_guid_t input = b.create_input_tensor(input_shape); - - parallel_tensor_guid_t output = b.conv2d(input, - /*outChannels=*/5, - /*kernelH=*/3, - /*kernelW=*/3, - /*strideH=*/1, - /*strideW=*/1, - /*paddingH=*/0, - /*paddingW=*/0); - - CHECK(get_parallel_layers(b.pcg).size() == 1); - }; -} diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc index e88e231bd0..34be83c281 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -8,7 +8,7 @@ TEST_SUITE(FF_TEST_SUITE) { size_t batch_size = 2; - TensorShape input_shape = { + TensorShape input_shape = TensorShape{ TensorDims{FFOrdered{batch_size, 3, 10, 10}}, DataType::FLOAT, }; diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h index 35ec9e499f..38e0b66f78 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct OperatorAttributeConstraint { OperatorAttributeConstraint() = delete; - OperatorAttributeConstraint( + explicit OperatorAttributeConstraint( ::FlexFlow::ConstraintType const &constraint_type, ::FlexFlow::OperatorAttributeExpr const &attribute_expr, ::FlexFlow::OperatorAttributeValue const &attribute_value); @@ -41,16 +41,16 @@ struct OperatorAttributeConstraint { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OperatorAttributeConstraint const &) const; +struct hash<::FlexFlow::OperatorAttributeConstraint> { + size_t operator()(::FlexFlow::OperatorAttributeConstraint const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::OperatorAttributeConstraint from_json(json const &); - static void to_json(json &, FlexFlow::OperatorAttributeConstraint const &); +struct adl_serializer<::FlexFlow::OperatorAttributeConstraint> { + static ::FlexFlow::OperatorAttributeConstraint from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeConstraint const &); }; } // namespace nlohmann diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h index 5a30c40f8d..559352de40 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct OperatorAttributeListIndexAccess { OperatorAttributeListIndexAccess() = delete; - OperatorAttributeListIndexAccess( + explicit OperatorAttributeListIndexAccess( ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index); bool operator==(OperatorAttributeListIndexAccess const &) const; @@ -37,24 +37,24 @@ struct OperatorAttributeListIndexAccess { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OperatorAttributeListIndexAccess const &) const; +struct hash<::FlexFlow::OperatorAttributeListIndexAccess> { + size_t operator()(::FlexFlow::OperatorAttributeListIndexAccess const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::OperatorAttributeListIndexAccess from_json(json const &); +struct adl_serializer<::FlexFlow::OperatorAttributeListIndexAccess> { + static ::FlexFlow::OperatorAttributeListIndexAccess from_json(json const &); static void to_json(json &, - FlexFlow::OperatorAttributeListIndexAccess const &); + ::FlexFlow::OperatorAttributeListIndexAccess const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::OperatorAttributeListIndexAccess> { + static Gen<::FlexFlow::OperatorAttributeListIndexAccess> arbitrary(); }; } // namespace rc diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h index 17d76a08f1..23779f9d3e 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct OperatorAttributeListSize { OperatorAttributeListSize() = delete; - OperatorAttributeListSize( + explicit OperatorAttributeListSize( ::FlexFlow::OperatorAttributeKey const &attribute_key); bool operator==(OperatorAttributeListSize const &) const; @@ -36,23 +36,23 @@ struct OperatorAttributeListSize { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OperatorAttributeListSize const &) const; +struct hash<::FlexFlow::OperatorAttributeListSize> { + size_t operator()(::FlexFlow::OperatorAttributeListSize const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::OperatorAttributeListSize from_json(json const &); - static void to_json(json &, FlexFlow::OperatorAttributeListSize const &); +struct adl_serializer<::FlexFlow::OperatorAttributeListSize> { + static ::FlexFlow::OperatorAttributeListSize from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeListSize const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::OperatorAttributeListSize> { + static Gen<::FlexFlow::OperatorAttributeListSize> arbitrary(); }; } // namespace rc diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h index 7bce198f3d..4a491af2f6 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct OperatorAttributePattern { OperatorAttributePattern() = delete; - OperatorAttributePattern( + explicit OperatorAttributePattern( std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const &attribute_constraints); @@ -35,16 +35,16 @@ struct OperatorAttributePattern { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OperatorAttributePattern const &) const; +struct hash<::FlexFlow::OperatorAttributePattern> { + size_t operator()(::FlexFlow::OperatorAttributePattern const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::OperatorAttributePattern from_json(json const &); - static void to_json(json &, FlexFlow::OperatorAttributePattern const &); +struct adl_serializer<::FlexFlow::OperatorAttributePattern> { + static ::FlexFlow::OperatorAttributePattern from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributePattern const &); }; } // namespace nlohmann diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h index 9dd20bb10e..bc76f68c4d 100644 --- a/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h +++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h @@ -19,7 +19,7 @@ namespace FlexFlow { struct AttrConstant { AttrConstant() = delete; - AttrConstant(::FlexFlow::OperatorAttributeValue const &value); + explicit AttrConstant(::FlexFlow::OperatorAttributeValue const &value); bool operator==(AttrConstant const &) const; bool operator!=(AttrConstant const &) const; @@ -33,8 +33,8 @@ struct AttrConstant { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::AttrConstant const &) const; +struct hash<::FlexFlow::AttrConstant> { + size_t operator()(::FlexFlow::AttrConstant const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h index 3d6fb21574..1e78d76777 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h @@ -16,8 +16,9 @@ namespace FlexFlow { struct OutputGraphExpr { OutputGraphExpr() = delete; - OutputGraphExpr(::FlexFlow::NodeLabelledOpenMultiDiGraph< - ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph); + explicit OutputGraphExpr( + ::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph); ::FlexFlow::NodeLabelledOpenMultiDiGraph< ::FlexFlow::OutputOperatorAttrsAssignment> diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h index 0d585f0aa0..d7137c90a6 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h @@ -20,8 +20,9 @@ namespace FlexFlow { struct OutputOperatorAttrAccess { OutputOperatorAttrAccess() = delete; - OutputOperatorAttrAccess(::FlexFlow::Node const &node, - ::FlexFlow::OperatorAttributeExpr const &attr_expr); + explicit OutputOperatorAttrAccess( + ::FlexFlow::Node const &node, + ::FlexFlow::OperatorAttributeExpr const &attr_expr); bool operator==(OutputOperatorAttrAccess const &) const; bool operator!=(OutputOperatorAttrAccess const &) const; @@ -36,8 +37,8 @@ struct OutputOperatorAttrAccess { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OutputOperatorAttrAccess const &) const; +struct hash<::FlexFlow::OutputOperatorAttrAccess> { + size_t operator()(::FlexFlow::OutputOperatorAttrAccess const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h index 5586a90a08..5718965c27 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct OutputOperatorAttrsAssignment { OutputOperatorAttrsAssignment() = delete; - OutputOperatorAttrsAssignment( + explicit OutputOperatorAttrsAssignment( std::unordered_map<::FlexFlow::OperatorAttributeKey, ::FlexFlow::OutputOperatorAttributeExpr> const &assignments); @@ -36,8 +36,8 @@ struct OutputOperatorAttrsAssignment { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OutputOperatorAttrsAssignment const &) const; +struct hash<::FlexFlow::OutputOperatorAttrsAssignment> { + size_t operator()(::FlexFlow::OutputOperatorAttrsAssignment const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h index 0c0cc41891..98aec04e61 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h @@ -17,9 +17,9 @@ namespace FlexFlow { struct PCGPattern { PCGPattern() = delete; - PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< - ::FlexFlow::OperatorAttributePattern, - ::FlexFlow::TensorAttributePattern> const &raw_graph); + explicit PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> const &raw_graph); ::FlexFlow::OutputLabelledOpenMultiDiGraph< ::FlexFlow::OperatorAttributePattern, diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h index d31d65d83b..f0d6882dc9 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h @@ -17,7 +17,7 @@ namespace FlexFlow { struct SubParallelComputationGraph { SubParallelComputationGraph() = delete; - SubParallelComputationGraph( + explicit SubParallelComputationGraph( ::FlexFlow::OutputLabelledOpenMultiDiGraph< ::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs> const &raw_graph); diff --git a/lib/substitutions/include/substitutions/substitution.dtg.h b/lib/substitutions/include/substitutions/substitution.dtg.h index 5f50d9bafc..3515299acb 100644 --- a/lib/substitutions/include/substitutions/substitution.dtg.h +++ b/lib/substitutions/include/substitutions/substitution.dtg.h @@ -16,14 +16,14 @@ namespace FlexFlow { struct Substitution { Substitution() = delete; - Substitution(::FlexFlow::PCGPattern const &pcg_pattern, - ::FlexFlow::OutputGraphExpr const &output_graph_expr, - ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, - ::FlexFlow::InputMultiDiEdge> const - &input_edge_match_to_output, - ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, - ::FlexFlow::OutputMultiDiEdge> const - &output_edge_match_to_output); + explicit Substitution(::FlexFlow::PCGPattern const &pcg_pattern, + ::FlexFlow::OutputGraphExpr const &output_graph_expr, + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge> const + &input_edge_match_to_output, + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> const + &output_edge_match_to_output); ::FlexFlow::PCGPattern pcg_pattern; ::FlexFlow::OutputGraphExpr output_graph_expr; diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h index ba705a5d35..16807ff37c 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct TensorAttributeConstraint { TensorAttributeConstraint() = delete; - TensorAttributeConstraint( + explicit TensorAttributeConstraint( ::FlexFlow::ConstraintType const &constraint_type, ::FlexFlow::TensorAttributeExpr const &attribute_expr, ::FlexFlow::TensorAttributeValue const &attribute_value); @@ -41,16 +41,16 @@ struct TensorAttributeConstraint { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorAttributeConstraint const &) const; +struct hash<::FlexFlow::TensorAttributeConstraint> { + size_t operator()(::FlexFlow::TensorAttributeConstraint const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorAttributeConstraint from_json(json const &); - static void to_json(json &, FlexFlow::TensorAttributeConstraint const &); +struct adl_serializer<::FlexFlow::TensorAttributeConstraint> { + static ::FlexFlow::TensorAttributeConstraint from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributeConstraint const &); }; } // namespace nlohmann diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h index 473f4e1698..e81d2fcc04 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h @@ -21,7 +21,7 @@ namespace FlexFlow { struct TensorAttributeListIndexAccess { TensorAttributeListIndexAccess() = delete; - TensorAttributeListIndexAccess( + explicit TensorAttributeListIndexAccess( ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index); bool operator==(TensorAttributeListIndexAccess const &) const; @@ -37,23 +37,24 @@ struct TensorAttributeListIndexAccess { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorAttributeListIndexAccess const &) const; +struct hash<::FlexFlow::TensorAttributeListIndexAccess> { + size_t operator()(::FlexFlow::TensorAttributeListIndexAccess const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorAttributeListIndexAccess from_json(json const &); - static void to_json(json &, FlexFlow::TensorAttributeListIndexAccess const &); +struct adl_serializer<::FlexFlow::TensorAttributeListIndexAccess> { + static ::FlexFlow::TensorAttributeListIndexAccess from_json(json const &); + static void to_json(json &, + ::FlexFlow::TensorAttributeListIndexAccess const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::TensorAttributeListIndexAccess> { + static Gen<::FlexFlow::TensorAttributeListIndexAccess> arbitrary(); }; } // namespace rc diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h index 1630014bdf..5516a4b07b 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h @@ -21,7 +21,8 @@ namespace FlexFlow { struct TensorAttributeListSize { TensorAttributeListSize() = delete; - TensorAttributeListSize(::FlexFlow::TensorAttributeKey const &attribute_key); + explicit TensorAttributeListSize( + ::FlexFlow::TensorAttributeKey const &attribute_key); bool operator==(TensorAttributeListSize const &) const; bool operator!=(TensorAttributeListSize const &) const; @@ -35,23 +36,23 @@ struct TensorAttributeListSize { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorAttributeListSize const &) const; +struct hash<::FlexFlow::TensorAttributeListSize> { + size_t operator()(::FlexFlow::TensorAttributeListSize const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorAttributeListSize from_json(json const &); - static void to_json(json &, FlexFlow::TensorAttributeListSize const &); +struct adl_serializer<::FlexFlow::TensorAttributeListSize> { + static ::FlexFlow::TensorAttributeListSize from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributeListSize const &); }; } // namespace nlohmann namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::TensorAttributeListSize> { + static Gen<::FlexFlow::TensorAttributeListSize> arbitrary(); }; } // namespace rc diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h index ecc4bc7da0..a106b59073 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h @@ -22,7 +22,7 @@ namespace FlexFlow { struct TensorAttributePattern { TensorAttributePattern() = delete; - TensorAttributePattern( + explicit TensorAttributePattern( std::unordered_set<::FlexFlow::TensorAttributeConstraint> const &attribute_constraints); @@ -35,16 +35,16 @@ struct TensorAttributePattern { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::TensorAttributePattern const &) const; +struct hash<::FlexFlow::TensorAttributePattern> { + size_t operator()(::FlexFlow::TensorAttributePattern const &) const; }; } // namespace std namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::TensorAttributePattern from_json(json const &); - static void to_json(json &, FlexFlow::TensorAttributePattern const &); +struct adl_serializer<::FlexFlow::TensorAttributePattern> { + static ::FlexFlow::TensorAttributePattern from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributePattern const &); }; } // namespace nlohmann diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h index 6bf815791d..c67b508928 100644 --- a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h @@ -17,7 +17,7 @@ namespace FlexFlow { struct ClosedPatternEdge { ClosedPatternEdge() = delete; - ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge); + explicit ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge); bool operator==(ClosedPatternEdge const &) const; bool operator!=(ClosedPatternEdge const &) const; @@ -31,8 +31,8 @@ struct ClosedPatternEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ClosedPatternEdge const &) const; +struct hash<::FlexFlow::ClosedPatternEdge> { + size_t operator()(::FlexFlow::ClosedPatternEdge const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h index 5ce0e63073..4eb6cbee7a 100644 --- a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h @@ -17,7 +17,8 @@ namespace FlexFlow { struct DownwardOpenPatternEdge { DownwardOpenPatternEdge() = delete; - DownwardOpenPatternEdge(::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge); + explicit DownwardOpenPatternEdge( + ::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge); bool operator==(DownwardOpenPatternEdge const &) const; bool operator!=(DownwardOpenPatternEdge const &) const; @@ -31,8 +32,8 @@ struct DownwardOpenPatternEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::DownwardOpenPatternEdge const &) const; +struct hash<::FlexFlow::DownwardOpenPatternEdge> { + size_t operator()(::FlexFlow::DownwardOpenPatternEdge const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h index e92fe547b1..a69a5b5f6b 100644 --- a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h @@ -18,7 +18,7 @@ namespace FlexFlow { struct UnlabelledPatternEdgeSplits { UnlabelledPatternEdgeSplits() = delete; - UnlabelledPatternEdgeSplits( + explicit UnlabelledPatternEdgeSplits( ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>> const diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h index f292acba14..1240244762 100644 --- a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h @@ -17,7 +17,7 @@ namespace FlexFlow { struct InputPatternEdge { InputPatternEdge() = delete; - InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge); + explicit InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge); bool operator==(InputPatternEdge const &) const; bool operator!=(InputPatternEdge const &) const; @@ -31,8 +31,8 @@ struct InputPatternEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::InputPatternEdge const &) const; +struct hash<::FlexFlow::InputPatternEdge> { + size_t operator()(::FlexFlow::InputPatternEdge const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h index e910be21ba..f6c1df278a 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h @@ -18,7 +18,7 @@ namespace FlexFlow { struct MatchAdditionalCriterion { MatchAdditionalCriterion() = delete; - MatchAdditionalCriterion( + explicit MatchAdditionalCriterion( std::function const &node_criterion, std::function const &node_assignment, ::FlexFlow::bidict<::FlexFlow::PatternEdge, diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h index 04ec8c656d..0b8994fbff 100644 --- a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h @@ -17,7 +17,7 @@ namespace FlexFlow { struct OutputPatternEdge { OutputPatternEdge() = delete; - OutputPatternEdge(::FlexFlow::OutputMultiDiEdge const &raw_edge); + explicit OutputPatternEdge(::FlexFlow::OutputMultiDiEdge const &raw_edge); bool operator==(OutputPatternEdge const &) const; bool operator!=(OutputPatternEdge const &) const; @@ -31,8 +31,8 @@ struct OutputPatternEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::OutputPatternEdge const &) const; +struct hash<::FlexFlow::OutputPatternEdge> { + size_t operator()(::FlexFlow::OutputPatternEdge const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h index 4883590130..8303cd8c9c 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h @@ -17,7 +17,7 @@ namespace FlexFlow { struct PatternEdge { PatternEdge() = delete; - PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge); + explicit PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge); bool operator==(PatternEdge const &) const; bool operator!=(PatternEdge const &) const; @@ -31,8 +31,8 @@ struct PatternEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::PatternEdge const &) const; +struct hash<::FlexFlow::PatternEdge> { + size_t operator()(::FlexFlow::PatternEdge const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h index 56471c2e08..a8e473382c 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h @@ -17,7 +17,7 @@ namespace FlexFlow { struct PatternNode { PatternNode() = delete; - PatternNode(::FlexFlow::Node const &raw_node); + explicit PatternNode(::FlexFlow::Node const &raw_node); bool operator==(PatternNode const &) const; bool operator!=(PatternNode const &) const; @@ -31,8 +31,8 @@ struct PatternNode { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::PatternNode const &) const; +struct hash<::FlexFlow::PatternNode> { + size_t operator()(::FlexFlow::PatternNode const &) const; }; } // namespace std diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h index 453c4020a8..fb5c1d9b25 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h @@ -21,8 +21,9 @@ namespace FlexFlow { struct PatternSplit { PatternSplit() = delete; - PatternSplit(std::unordered_set<::FlexFlow::PatternNode> const &first, - std::unordered_set<::FlexFlow::PatternNode> const &second); + explicit PatternSplit( + std::unordered_set<::FlexFlow::PatternNode> const &first, + std::unordered_set<::FlexFlow::PatternNode> const &second); bool operator==(PatternSplit const &) const; bool operator!=(PatternSplit const &) const; @@ -33,9 +34,9 @@ struct PatternSplit { namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::PatternSplit from_json(json const &); - static void to_json(json &, FlexFlow::PatternSplit const &); +struct adl_serializer<::FlexFlow::PatternSplit> { + static ::FlexFlow::PatternSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::PatternSplit const &); }; } // namespace nlohmann diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h index a2ba6c26d2..972dda4200 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h @@ -15,7 +15,8 @@ namespace FlexFlow { struct UnlabelledGraphPattern { UnlabelledGraphPattern() = delete; - UnlabelledGraphPattern(::FlexFlow::OpenMultiDiGraphView const &raw_graph); + explicit UnlabelledGraphPattern( + ::FlexFlow::OpenMultiDiGraphView const &raw_graph); ::FlexFlow::OpenMultiDiGraphView raw_graph; }; diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h index 82440b5820..e94403feb4 100644 --- a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h @@ -17,7 +17,8 @@ namespace FlexFlow { struct UpwardOpenPatternEdge { UpwardOpenPatternEdge() = delete; - UpwardOpenPatternEdge(::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge); + explicit UpwardOpenPatternEdge( + ::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge); bool operator==(UpwardOpenPatternEdge const &) const; bool operator!=(UpwardOpenPatternEdge const &) const; @@ -31,8 +32,8 @@ struct UpwardOpenPatternEdge { namespace std { template <> -struct hash { - size_t operator()(FlexFlow::UpwardOpenPatternEdge const &) const; +struct hash<::FlexFlow::UpwardOpenPatternEdge> { + size_t operator()(::FlexFlow::UpwardOpenPatternEdge const &) const; }; } // namespace std diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc index bc913b7c1a..2956dad2c4 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc @@ -73,7 +73,7 @@ bool OperatorAttributeConstraint::operator>=( namespace std { size_t hash::operator()( - FlexFlow::OperatorAttributeConstraint const &x) const { + ::FlexFlow::OperatorAttributeConstraint const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -86,17 +86,17 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::OperatorAttributeConstraint - adl_serializer::from_json( +::FlexFlow::OperatorAttributeConstraint + adl_serializer<::FlexFlow::OperatorAttributeConstraint>::from_json( json const &j) { - return { + return ::FlexFlow::OperatorAttributeConstraint{ j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), j.at("attribute_expr").template get<::FlexFlow::OperatorAttributeExpr>(), j.at("attribute_value") .template get<::FlexFlow::OperatorAttributeValue>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::OperatorAttributeConstraint const &v) { +void adl_serializer<::FlexFlow::OperatorAttributeConstraint>::to_json( + json &j, ::FlexFlow::OperatorAttributeConstraint const &v) { j["__type"] = "OperatorAttributeConstraint"; j["constraint_type"] = v.constraint_type; j["attribute_expr"] = v.attribute_expr; diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc index 71b71d4a51..67e3761515 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc @@ -50,7 +50,7 @@ bool OperatorAttributeListIndexAccess::operator>=( namespace std { size_t hash::operator()( - FlexFlow::OperatorAttributeListIndexAccess const &x) const { + ::FlexFlow::OperatorAttributeListIndexAccess const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -61,15 +61,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::OperatorAttributeListIndexAccess - adl_serializer::from_json( +::FlexFlow::OperatorAttributeListIndexAccess + adl_serializer<::FlexFlow::OperatorAttributeListIndexAccess>::from_json( json const &j) { - return { + return ::FlexFlow::OperatorAttributeListIndexAccess{ j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>(), j.at("index").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::OperatorAttributeListIndexAccess const &v) { +void adl_serializer<::FlexFlow::OperatorAttributeListIndexAccess>::to_json( + json &j, ::FlexFlow::OperatorAttributeListIndexAccess const &v) { j["__type"] = "OperatorAttributeListIndexAccess"; j["attribute_key"] = v.attribute_key; j["index"] = v.index; @@ -77,9 +77,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::OperatorAttributeListIndexAccess> + Arbitrary<::FlexFlow::OperatorAttributeListIndexAccess>::arbitrary() { + return gen::construct<::FlexFlow::OperatorAttributeListIndexAccess>( gen::arbitrary<::FlexFlow::OperatorAttributeKey>(), gen::arbitrary()); } diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc index eb7ae28131..2879aca911 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc @@ -44,7 +44,7 @@ bool OperatorAttributeListSize::operator>=( namespace std { size_t hash::operator()( - FlexFlow::OperatorAttributeListSize const &x) const { + ::FlexFlow::OperatorAttributeListSize const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -53,23 +53,23 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::OperatorAttributeListSize - adl_serializer::from_json( +::FlexFlow::OperatorAttributeListSize + adl_serializer<::FlexFlow::OperatorAttributeListSize>::from_json( json const &j) { - return { + return ::FlexFlow::OperatorAttributeListSize{ j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::OperatorAttributeListSize const &v) { +void adl_serializer<::FlexFlow::OperatorAttributeListSize>::to_json( + json &j, ::FlexFlow::OperatorAttributeListSize const &v) { j["__type"] = "OperatorAttributeListSize"; j["attribute_key"] = v.attribute_key; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::OperatorAttributeListSize> + Arbitrary<::FlexFlow::OperatorAttributeListSize>::arbitrary() { + return gen::construct<::FlexFlow::OperatorAttributeListSize>( gen::arbitrary<::FlexFlow::OperatorAttributeKey>()); } } // namespace rc diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc index 5eaf54bb5f..7aca1e75fc 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc @@ -33,7 +33,7 @@ bool OperatorAttributePattern::operator!=( namespace std { size_t hash::operator()( - FlexFlow::OperatorAttributePattern const &x) const { + ::FlexFlow::OperatorAttributePattern const &x) const { size_t result = 0; result ^= std::hash>{}( @@ -44,16 +44,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::OperatorAttributePattern - adl_serializer::from_json( +::FlexFlow::OperatorAttributePattern + adl_serializer<::FlexFlow::OperatorAttributePattern>::from_json( json const &j) { - return { + return ::FlexFlow::OperatorAttributePattern{ j.at("attribute_constraints") .template get< std::unordered_set<::FlexFlow::OperatorAttributeConstraint>>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::OperatorAttributePattern const &v) { +void adl_serializer<::FlexFlow::OperatorAttributePattern>::to_json( + json &j, ::FlexFlow::OperatorAttributePattern const &v) { j["__type"] = "OperatorAttributePattern"; j["attribute_constraints"] = v.attribute_constraints; } diff --git a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc index f20afc1164..c0dc667822 100644 --- a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc @@ -37,7 +37,7 @@ bool AttrConstant::operator>=(AttrConstant const &other) const { namespace std { size_t hash::operator()( - FlexFlow::AttrConstant const &x) const { + ::FlexFlow::AttrConstant const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc index 0c6abc925d..2864ccbfac 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc @@ -52,7 +52,7 @@ bool OutputOperatorAttrAccess::operator>=( namespace std { size_t hash::operator()( - FlexFlow::OutputOperatorAttrAccess const &x) const { + ::FlexFlow::OutputOperatorAttrAccess const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc index 7a1950482a..98183c9a14 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc @@ -32,7 +32,7 @@ bool OutputOperatorAttrsAssignment::operator!=( namespace std { size_t hash::operator()( - FlexFlow::OutputOperatorAttrsAssignment const &x) const { + ::FlexFlow::OutputOperatorAttrsAssignment const &x) const { size_t result = 0; result ^= std::hash=( namespace std { size_t hash::operator()( - FlexFlow::TensorAttributeConstraint const &x) const { + ::FlexFlow::TensorAttributeConstraint const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -86,16 +86,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorAttributeConstraint - adl_serializer::from_json( +::FlexFlow::TensorAttributeConstraint + adl_serializer<::FlexFlow::TensorAttributeConstraint>::from_json( json const &j) { - return { + return ::FlexFlow::TensorAttributeConstraint{ j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), j.at("attribute_expr").template get<::FlexFlow::TensorAttributeExpr>(), j.at("attribute_value").template get<::FlexFlow::TensorAttributeValue>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorAttributeConstraint const &v) { +void adl_serializer<::FlexFlow::TensorAttributeConstraint>::to_json( + json &j, ::FlexFlow::TensorAttributeConstraint const &v) { j["__type"] = "TensorAttributeConstraint"; j["constraint_type"] = v.constraint_type; j["attribute_expr"] = v.attribute_expr; diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc index 4e28de2c28..c7e81718ed 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc @@ -50,7 +50,7 @@ bool TensorAttributeListIndexAccess::operator>=( namespace std { size_t hash::operator()( - FlexFlow::TensorAttributeListIndexAccess const &x) const { + ::FlexFlow::TensorAttributeListIndexAccess const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -61,14 +61,15 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorAttributeListIndexAccess - adl_serializer::from_json( +::FlexFlow::TensorAttributeListIndexAccess + adl_serializer<::FlexFlow::TensorAttributeListIndexAccess>::from_json( json const &j) { - return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>(), - j.at("index").template get()}; + return ::FlexFlow::TensorAttributeListIndexAccess{ + j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>(), + j.at("index").template get()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorAttributeListIndexAccess const &v) { +void adl_serializer<::FlexFlow::TensorAttributeListIndexAccess>::to_json( + json &j, ::FlexFlow::TensorAttributeListIndexAccess const &v) { j["__type"] = "TensorAttributeListIndexAccess"; j["attribute_key"] = v.attribute_key; j["index"] = v.index; @@ -76,9 +77,9 @@ void adl_serializer::to_json( } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::TensorAttributeListIndexAccess> + Arbitrary<::FlexFlow::TensorAttributeListIndexAccess>::arbitrary() { + return gen::construct<::FlexFlow::TensorAttributeListIndexAccess>( gen::arbitrary<::FlexFlow::TensorAttributeKey>(), gen::arbitrary()); } } // namespace rc diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc index 24d8b6c025..52a61a8a87 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc @@ -44,7 +44,7 @@ bool TensorAttributeListSize::operator>=( namespace std { size_t hash::operator()( - FlexFlow::TensorAttributeListSize const &x) const { + ::FlexFlow::TensorAttributeListSize const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -53,22 +53,23 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorAttributeListSize - adl_serializer::from_json( +::FlexFlow::TensorAttributeListSize + adl_serializer<::FlexFlow::TensorAttributeListSize>::from_json( json const &j) { - return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>()}; + return ::FlexFlow::TensorAttributeListSize{ + j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorAttributeListSize const &v) { +void adl_serializer<::FlexFlow::TensorAttributeListSize>::to_json( + json &j, ::FlexFlow::TensorAttributeListSize const &v) { j["__type"] = "TensorAttributeListSize"; j["attribute_key"] = v.attribute_key; } } // namespace nlohmann namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::TensorAttributeListSize> + Arbitrary<::FlexFlow::TensorAttributeListSize>::arbitrary() { + return gen::construct<::FlexFlow::TensorAttributeListSize>( gen::arbitrary<::FlexFlow::TensorAttributeKey>()); } } // namespace rc diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc index 121549d4dc..8f96fd49b8 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc @@ -33,7 +33,7 @@ bool TensorAttributePattern::operator!=( namespace std { size_t hash::operator()( - FlexFlow::TensorAttributePattern const &x) const { + ::FlexFlow::TensorAttributePattern const &x) const { size_t result = 0; result ^= std::hash>{}( @@ -44,14 +44,16 @@ size_t hash::operator()( } // namespace std namespace nlohmann { -FlexFlow::TensorAttributePattern - adl_serializer::from_json(json const &j) { - return {j.at("attribute_constraints") - .template get< - std::unordered_set<::FlexFlow::TensorAttributeConstraint>>()}; +::FlexFlow::TensorAttributePattern + adl_serializer<::FlexFlow::TensorAttributePattern>::from_json( + json const &j) { + return ::FlexFlow::TensorAttributePattern{ + j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::TensorAttributeConstraint>>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::TensorAttributePattern const &v) { +void adl_serializer<::FlexFlow::TensorAttributePattern>::to_json( + json &j, ::FlexFlow::TensorAttributePattern const &v) { j["__type"] = "TensorAttributePattern"; j["attribute_constraints"] = v.attribute_constraints; } diff --git a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc index fbefc6f01a..401c738d88 100644 --- a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc @@ -36,7 +36,7 @@ bool ClosedPatternEdge::operator>=(ClosedPatternEdge const &other) const { namespace std { size_t hash::operator()( - FlexFlow::ClosedPatternEdge const &x) const { + ::FlexFlow::ClosedPatternEdge const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::MultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc index 30c52fbbb2..65c87db0e4 100644 --- a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc @@ -43,7 +43,7 @@ bool DownwardOpenPatternEdge::operator>=( namespace std { size_t hash::operator()( - FlexFlow::DownwardOpenPatternEdge const &x) const { + ::FlexFlow::DownwardOpenPatternEdge const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::DownwardOpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc index f3f5a8ce45..e46becf4be 100644 --- a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc @@ -36,7 +36,7 @@ bool InputPatternEdge::operator>=(InputPatternEdge const &other) const { namespace std { size_t hash::operator()( - FlexFlow::InputPatternEdge const &x) const { + ::FlexFlow::InputPatternEdge const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::InputMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc index fb9de06135..152115d52a 100644 --- a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc @@ -37,7 +37,7 @@ bool OutputPatternEdge::operator>=(OutputPatternEdge const &other) const { namespace std { size_t hash::operator()( - FlexFlow::OutputPatternEdge const &x) const { + ::FlexFlow::OutputPatternEdge const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OutputMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc index e4d11d0d7e..a19e5bb6d1 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc @@ -36,7 +36,7 @@ bool PatternEdge::operator>=(PatternEdge const &other) const { namespace std { size_t hash::operator()( - FlexFlow::PatternEdge const &x) const { + ::FlexFlow::PatternEdge const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::OpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc index 6ea64de69e..b2cd557c06 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc @@ -36,7 +36,7 @@ bool PatternNode::operator>=(PatternNode const &other) const { namespace std { size_t hash::operator()( - FlexFlow::PatternNode const &x) const { + ::FlexFlow::PatternNode const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc index bbcd4c3902..d678a1edfe 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc @@ -30,15 +30,15 @@ bool PatternSplit::operator!=(PatternSplit const &other) const { } // namespace FlexFlow namespace nlohmann { -FlexFlow::PatternSplit - adl_serializer::from_json(json const &j) { - return { +::FlexFlow::PatternSplit + adl_serializer<::FlexFlow::PatternSplit>::from_json(json const &j) { + return ::FlexFlow::PatternSplit{ j.at("first").template get>(), j.at("second") .template get>()}; } -void adl_serializer::to_json( - json &j, FlexFlow::PatternSplit const &v) { +void adl_serializer<::FlexFlow::PatternSplit>::to_json( + json &j, ::FlexFlow::PatternSplit const &v) { j["__type"] = "PatternSplit"; j["first"] = v.first; j["second"] = v.second; diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc index ca8dd6c020..1fe34ed778 100644 --- a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc @@ -43,7 +43,7 @@ bool UpwardOpenPatternEdge::operator>=( namespace std { size_t hash::operator()( - FlexFlow::UpwardOpenPatternEdge const &x) const { + ::FlexFlow::UpwardOpenPatternEdge const &x) const { size_t result = 0; result ^= std::hash<::FlexFlow::UpwardOpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + (result << 6) + (result >> 2); diff --git a/lib/utils/include/utils/containers/without_nullopts.h b/lib/utils/include/utils/containers/without_nullopts.h new file mode 100644 index 0000000000..f888654b60 --- /dev/null +++ b/lib/utils/include/utils/containers/without_nullopts.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_WITHOUT_NULLOPTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_WITHOUT_NULLOPTS_H + +#include +#include + +namespace FlexFlow { + +template +std::vector without_nullopts(std::vector> const &v) { + std::vector result; + for (std::optional const &t : v) { + if (t.has_value()) { + result.push_back(t.value()); + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 04902c8240..5b8d474025 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -26,31 +26,6 @@ typename std::enable_if>::value, namespace fmt { -template -struct formatter< - ::std::unordered_set, - Char, - std::enable_if_t>::value>> - : formatter<::std::string, Char> { - template - auto format(::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()); -}; - -/* template */ -/* std::string format_as(::std::unordered_set const &); */ - -template -struct formatter< - ::std::vector, - Char, - std::enable_if_t>::value>> - : formatter<::std::string> { - template - auto format(::std::vector const &m, FormatContext &ctx) - -> decltype(ctx.out()); -}; - template struct formatter<::std::variant> : formatter<::std::string> { template diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 967a41f22b..72fca552d8 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -12,38 +12,6 @@ namespace fmt { -template -template -auto formatter< - ::std::unordered_set, - Char, - std::enable_if_t>::value>>:: - format(::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - /* CHECK_FMTABLE(T); */ - - /* std::string result = ::FlexFlow::join_strings( */ - /* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); - * }); */ - std::string result = ""; - return formatter::format(result, ctx); -} - -template -template -auto formatter< - ::std::vector, - Char, - std::enable_if_t>::value>>:: - format(::std::vector const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - CHECK_FMTABLE(T); - - std::string result = ::FlexFlow::join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - return formatter::format("[" + result + "]", ctx); -} - template template auto formatter<::std::variant>::format(::std::variant const &m, @@ -58,15 +26,6 @@ auto formatter<::std::variant>::format(::std::variant const &m, namespace FlexFlow { -template -struct delegate_ostream_operator> : std::true_type {}; - -template -struct delegate_ostream_operator> : std::true_type {}; - -template -struct delegate_ostream_operator> : std::true_type {}; - template struct delegate_ostream_operator> : std::true_type {}; diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h new file mode 100644 index 0000000000..1ce36fa97a --- /dev/null +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_SET_H + +#include +#include "utils/join_strings.h" +#include +#include "utils/check_fmtable.h" + +namespace fmt { + +template +struct formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::unordered_set const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = ::FlexFlow::join_strings( + m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + return formatter::format("{" + result + "}", ctx); + } +}; + +} + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::unordered_set const &x) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h new file mode 100644 index 0000000000..82cdcfdb3c --- /dev/null +++ b/lib/utils/include/utils/fmt/vector.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VECTOR_H + +#include +#include "utils/join_strings.h" +#include +#include "utils/check_fmtable.h" + +namespace fmt { + +template +struct formatter< + ::std::vector, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::vector const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = ::FlexFlow::join_strings( + m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + return formatter::format("[" + result + "]", ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::vector const &v) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(v); +} + +} // namespace FlexFlow + + +#endif diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h index d7c2c1590b..2a9a417e7e 100644 --- a/lib/utils/include/utils/graph/multidiedge.h +++ b/lib/utils/include/utils/graph/multidiedge.h @@ -6,6 +6,7 @@ #include "node_port.h" #include "utils/strong_typedef.h" #include "utils/visitable.h" +#include "utils/fmt/pair.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index d47886b055..142d4fe7b5 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -12,6 +12,7 @@ #include #include #include +#include "utils/fmt/vector.h" namespace FlexFlow { @@ -294,7 +295,7 @@ struct stack_vector { } friend std::vector format_as(stack_vector const &v) { - // CHECK_FMTABLE(std::vector); + CHECK_FMTABLE(std::vector); return static_cast>(v); } @@ -314,9 +315,9 @@ struct stack_vector { }; template -struct delegate_ostream_operator> : std::true_type {}; - -// CHECK_FMTABLE(stack_vector); +std::ostream &operator<<(std::ostream &s, stack_vector const &v) { + return s << fmt::to_string(v); +} template void to_json(json &j, stack_vector const &v) { From 35fa653aef1bbce8f0f38fd36407f2440ce30afb Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 10 Jun 2024 00:37:19 -0700 Subject: [PATCH 04/71] Add remainder of PCG tests --- .../include/op-attrs/get_output_shapes.h | 6 - lib/op-attrs/include/op-attrs/ops/attention.h | 32 +- lib/op-attrs/include/op-attrs/ops/cast.h | 9 +- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 2 +- lib/op-attrs/src/op-attrs/ops/attention.cc | 123 ++++- lib/op-attrs/src/op-attrs/ops/cast.cc | 29 ++ lib/op-attrs/test/src/ops/attention.cc | 467 +++++++++-------- lib/op-attrs/test/src/ops/cast.cc | 58 +++ .../parallel_computation_graph_builder.h | 6 +- .../parallel_computation_graph_builder.cc | 252 ++++++++-- .../parallel_computation_graph_builder.cc | 470 +++++++++++++++++- 11 files changed, 1189 insertions(+), 265 deletions(-) create mode 100644 lib/op-attrs/test/src/ops/cast.cc diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index a826e1cb54..25b1092ed3 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -112,12 +112,8 @@ std::vector get_output_shapes(Attrs const &attrs, ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, std::vector const &); -ParallelTensorShape get_output_shape(CastAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ConcatAttrs const &, std::vector const &); -ParallelTensorShape get_output_shape(Conv2DAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, @@ -131,8 +127,6 @@ ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReplicateAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(SplitAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 8233775e63..e126c425dc 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -42,17 +42,37 @@ tl::expected TensorShape const &input_q, TensorShape const &input_k, TensorShape const &input_v); -tl::expected - get_weights_shape(MultiHeadAttentionAttrs const &, - ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v); - +tl::expected + get_input_bias_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected + get_output_bias_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); tl::expected get_output_shape(MultiHeadAttentionAttrs const &, TensorShape const &input_q, TensorShape const &input_k, TensorShape const &input_v); + +tl::expected + get_weights_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_input_bias_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_output_bias_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); tl::expected get_output_shape(MultiHeadAttentionAttrs const &, ParallelTensorShape const &input_q, diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 117dcb1e01..8a97bbafe6 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -1,12 +1,19 @@ #ifndef _FLEXFLOW_CAST_ATTRS_H #define _FLEXFLOW_CAST_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(CastAttrs); + +tl::expected get_output_shape(CastAttrs const &, TensorShape const &); +tl::expected get_output_shape(CastAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 7759380088..d80e4b5862 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_CONV_2D_ATTRS_H #define _FLEXFLOW_CONV_2D_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 14ab2b9b00..834c3b7330 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -114,7 +114,7 @@ tl::expected // W^O in "Attention Is All You Need" top of page 5, with num_heads factored // out - size_t outWeightSize = parsed.value_size * attrs.embed_dim; + size_t outWeightSize = attrs.vdim * attrs.embed_dim; return TensorShape{ TensorDims{FFOrdered{ @@ -126,6 +126,51 @@ tl::expected }; } +tl::expected + get_input_bias_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + MultiHeadAttentionInputs parsed = ({ + tl::expected parse_result = + parse_attention_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + parse_result.value(); + }); + + return TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(attrs.kdim + attrs.kdim + attrs.vdim), + }}, + parsed.datatype, + }; +} + +tl::expected + get_output_bias_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + MultiHeadAttentionInputs parsed = ({ + tl::expected parse_result = + parse_attention_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + parse_result.value(); + }); + + return TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(attrs.embed_dim), + }}, + parsed.datatype, + }; +} + + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &input_q, @@ -158,6 +203,82 @@ tl::expected FFOrdered{joined_dim_degree, head_dim_degree}); } +tl::expected + get_input_bias_shape(MultiHeadAttentionAttrs const &attrs, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + MultiHeadAttentionParallelInputs parsed = ({ + tl::expected parse_result = + parse_attention_parallel_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + + parse_result.value(); + }); + + TensorShape unpar_shape = ({ + tl::expected result_unpar = get_input_bias_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + + result_unpar.value(); + }); + + SumDegree sum_degree = SumDegree{1}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + parsed.batch_dim.degree * parsed.discard_copy_degree.value + }; + FFOrdered shard_degrees = FFOrdered{1}; + return lift_to_parallel_with_degrees(unpar_shape, + sum_degree, + discard_copy_degree, + shard_degrees); +} + +tl::expected + get_output_bias_shape(MultiHeadAttentionAttrs const &attrs, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + MultiHeadAttentionParallelInputs parsed = ({ + tl::expected parse_result = + parse_attention_parallel_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + + parse_result.value(); + }); + + TensorShape unpar_shape = ({ + tl::expected result_unpar = get_output_bias_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + + result_unpar.value(); + }); + + SumDegree sum_degree = SumDegree{1}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + parsed.batch_dim.degree * parsed.discard_copy_degree.value + }; + FFOrdered shard_degrees = FFOrdered{1}; + return lift_to_parallel_with_degrees(unpar_shape, + sum_degree, + discard_copy_degree, + shard_degrees); +} + tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &input_q, diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc index e4ab178a7e..e4ae8f9759 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast.cc @@ -2,6 +2,35 @@ namespace FlexFlow { +tl::expected + get_output_shape(CastAttrs const &attrs, TensorShape const &input) { + + if (!can_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) { + return tl::unexpected(fmt::format("Cast cannot strictly promote input datatype {} to output datatype {}", + input.data_type, + attrs.dtype)); + } + + TensorShape output = input; + output.data_type = attrs.dtype; + return output; +} + +tl::expected + get_output_shape(CastAttrs const &attrs, ParallelTensorShape const &input) { + + if (!can_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) { + return tl::unexpected(fmt::format("Cast cannot strictly promote input datatype {} to output datatype {}", + input.data_type, + attrs.dtype)); + } + + ParallelTensorShape output = input; + output.data_type = attrs.dtype; + + return output; +} + /* bool CastAttrs::is_valid(ParallelTensorShape const &input) const { */ /* bool valid = input.is_valid(); */ /* valid &= (input.at(input.num_dims() - 1).degree == 1); */ diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/ops/attention.cc index 2c7121e4a8..7f69d57cd7 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/ops/attention.cc @@ -7,13 +7,14 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " "TensorShape, TensorShape)") { int embed_dim = 32; + int num_heads = 10; /* Parameter meanings match those at * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html */ MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ /*embed_dim=*/embed_dim, - /*num_heads=*/10, + /*num_heads=*/num_heads, /*kdim=*/embed_dim, /*vdim=*/embed_dim, /*dropout=*/0.0, @@ -24,259 +25,291 @@ TEST_SUITE(FF_TEST_SUITE) { size_t batch_size = 40; size_t seq_len = 48; + size_t feature_size = 36; TensorShape input_q = TensorShape{ - TensorDims{FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.embed_dim), - }}, + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + feature_size, + }, + }, DataType::FLOAT, }; TensorShape input_k = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.kdim), - }, + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + feature_size, }, + }, DataType::FLOAT, }; TensorShape input_v = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.vdim), - }, + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + feature_size, }, + }, DataType::FLOAT, }; - SUBCASE("get_output_shape") { - tl::expected result = - get_output_shape(attrs, input_q, input_k, input_v); - - tl::expected correct = TensorShape{ - TensorDims{FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.embed_dim), - }}, - DataType::FLOAT, - }; - - CHECK(result == correct); - } - - SUBCASE("get_weights_shape") { - tl::expected result = - get_weights_shape(attrs, input_q, input_k, input_v); - - int qProjPerHeadWeightSize = - attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); - int kProjPerHeadWeightSize = - attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); - int vProjPerHeadWeightSize = - attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); - int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim; - int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + - vProjPerHeadWeightSize + oProjPerHeadWeightSize; - - tl::expected correct = TensorShape{ - TensorDims{FFOrdered{ - size_t_from_int(perHeadWeightSize), - size_t_from_int(attrs.num_heads), - }}, - DataType::FLOAT, - }; - - CHECK(result == correct); - } - } - - TEST_CASE("parallel shape inference for MultiHeadAttentionAttrs") { - int embed_dim = 32; - - /* Parameter meanings can be found at - * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html - */ - MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ - /*embed_dim=*/embed_dim, - /*num_heads=*/10, - /*kdim=*/embed_dim, - /*vdim=*/embed_dim, - /*dropout=*/0.0, - /*bias=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }, + }, + DataType::FLOAT, }; - size_t batchsize = 40; - size_t seq_len = 48; - size_t q_size = 56; - size_t k_size = 64; - size_t v_size = 72; - - TensorShape unpar_q_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batchsize, - seq_len, - q_size, - }, + TensorShape weights = TensorShape{ + TensorDims{ + FFOrdered{ + (feature_size * embed_dim) * 3 + (embed_dim * embed_dim), + size_t_from_int(num_heads), }, - DataType::FLOAT, + }, + DataType::FLOAT, }; - TensorShape unpar_k_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batchsize, - seq_len, - k_size, - }, + TensorShape input_bias = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(embed_dim * 3), }, - DataType::FLOAT, + }, + DataType::FLOAT, }; - TensorShape unpar_v_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batchsize, - seq_len, - v_size, - }, + TensorShape output_bias = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(embed_dim), }, - DataType::FLOAT, + }, + DataType::FLOAT, }; - tl::expected result_unpar_o_shape = - get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); - REQUIRE(result_unpar_o_shape.has_value()); - TensorShape unpar_o_shape = result_unpar_o_shape.value(); - - tl::expected result_unpar_w_shape = - get_weights_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); - REQUIRE(result_unpar_o_shape.has_value()); - TensorShape unpar_w_shape = result_unpar_w_shape.value(); - - auto make_q = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_batch, - int o_seq_len, - int o_q) { - return lift_to_parallel_with_degrees( - unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); - }; + SUBCASE("get_output_shape") { + tl::expected result = + get_output_shape(attrs, input_q, input_k, input_v); - auto make_k = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_batch, - int o_seq_len, - int o_k) { - return lift_to_parallel_with_degrees( - unpar_k_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); - }; + tl::expected correct = output; + CHECK(result == correct); + } - auto make_v = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_batch, - int o_seq_len, - int o_v) { - return lift_to_parallel_with_degrees( - unpar_v_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); - }; + SUBCASE("get_weights_shape") { + tl::expected result = + get_weights_shape(attrs, input_q, input_k, input_v); - auto make_o = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_batch, - int o_seq_len, - int o_o) { - return lift_to_parallel_with_degrees( - unpar_o_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); - }; + tl::expected correct = weights; + CHECK(result == correct); + } - auto make_w = - [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_e, int o_h) { - return lift_to_parallel_with_degrees( - unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); - }; - - SUBCASE("data parallelism") { - int o_b = 4; - ParallelTensorShape q = - make_q(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); - ParallelTensorShape k = - make_k(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); - ParallelTensorShape v = - make_v(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); - - tl::expected result_o = - get_output_shape(attrs, q, k, v); - tl::expected correct_o = - make_o(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); - - CHECK(result_o == correct_o); - - tl::expected result_w = - get_weights_shape(attrs, q, k, v); - tl::expected correct_w = - make_w(SumDegree{1}, DiscardCopyDegree{o_b}, 1, 1); - - CHECK(result_w == correct_w); + SUBCASE("get_input_bias_shape") { + tl::expected result = + get_input_bias_shape(attrs, input_q, input_k, input_v); + tl::expected correct = input_bias; + CHECK(result == correct); } - SUBCASE("attention head parallelism") { - int o_h = 2; - ParallelTensorShape q = - make_q(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); - ParallelTensorShape k = - make_k(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); - ParallelTensorShape v = - make_v(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); + SUBCASE("get_output_bias_shape") { + tl::expected result = + get_output_bias_shape(attrs, input_q, input_k, input_v); + tl::expected correct = output_bias; + CHECK(result == correct); + } - tl::expected result_o = - get_output_shape(attrs, q, k, v); - tl::expected correct_o = - make_o(SumDegree{o_h}, DiscardCopyDegree{1}, 1, 1, 1); + SUBCASE("parallel shape inference") { + auto make_q = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_seq_len, + int o_q) { + return lift_to_parallel_with_degrees( + input_q, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); + }; - CHECK(result_o == correct_o); + auto make_k = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_seq_len, + int o_k) { + return lift_to_parallel_with_degrees( + input_k, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); + }; - tl::expected result_w = - get_weights_shape(attrs, q, k, v); - tl::expected correct_w = - make_w(SumDegree{1}, DiscardCopyDegree{1}, 1, o_h); + auto make_v = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_seq_len, + int o_v) { + return lift_to_parallel_with_degrees( + input_v, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); + }; - CHECK(result_w == correct_w); - } + auto make_o = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_seq_len, + int o_o) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); + }; - SUBCASE("combined data & attention head parallelism") { - int o_b = 4; - int o_h = 2; - ParallelTensorShape q = - make_q(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); - ParallelTensorShape k = - make_k(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); - ParallelTensorShape v = - make_v(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); - - tl::expected result_o = - get_output_shape(attrs, q, k, v); - tl::expected correct_o = - make_o(SumDegree{o_h}, DiscardCopyDegree{1}, o_b, 1, 1); - - CHECK(result_o == correct_o); - - tl::expected result_w = - get_weights_shape(attrs, q, k, v); - tl::expected correct_w = - make_w(SumDegree{1}, DiscardCopyDegree{o_b}, 1, o_h); - - CHECK(result_w == correct_w); + auto make_w = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_e, int o_h) { + return lift_to_parallel_with_degrees( + weights, o_sum, o_eq, FFOrdered{o_e, o_h}); + }; + + auto make_input_bias = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_in_proj_channel) { + return lift_to_parallel_with_degrees( + input_bias, o_sum, o_eq, FFOrdered{o_in_proj_channel}); + }; + + auto make_output_bias = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_out_proj_channel) { + return lift_to_parallel_with_degrees( + output_bias, o_sum, o_eq, FFOrdered{o_out_proj_channel}); + }; + + SUBCASE("data parallelism") { + int o_b = 4; + ParallelTensorShape q = + make_q(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); + ParallelTensorShape k = + make_k(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); + ParallelTensorShape v = + make_v(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); + + SUBCASE("get_output_shape") { + tl::expected result = + get_output_shape(attrs, q, k, v); + tl::expected correct = + make_o(SumDegree{1}, DiscardCopyDegree{1}, o_b, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_weights_shape") { + tl::expected result = + get_weights_shape(attrs, q, k, v); + tl::expected correct = + make_w(SumDegree{1}, DiscardCopyDegree{o_b}, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_input_bias_shape") { + tl::expected result = + get_input_bias_shape(attrs, q, k, v); + tl::expected correct = + make_input_bias(SumDegree{1}, DiscardCopyDegree{o_b}, 1); + CHECK(result == correct); + } + + SUBCASE("get_output_bias_shape") { + tl::expected result = + get_output_bias_shape(attrs, q, k, v); + tl::expected correct = + make_output_bias(SumDegree{1}, DiscardCopyDegree{o_b}, 1); + CHECK(result == correct); + } + } + + SUBCASE("attention head parallelism") { + int o_h = 2; + ParallelTensorShape q = + make_q(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); + ParallelTensorShape k = + make_k(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); + ParallelTensorShape v = + make_v(SumDegree{1}, DiscardCopyDegree{o_h}, 1, 1, 1); + + SUBCASE("get_output_shape") { + tl::expected result = + get_output_shape(attrs, q, k, v); + tl::expected correct = + make_o(SumDegree{o_h}, DiscardCopyDegree{1}, 1, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_weight_shape") { + tl::expected result = + get_weights_shape(attrs, q, k, v); + tl::expected correct = + make_w(SumDegree{1}, DiscardCopyDegree{1}, 1, o_h); + CHECK(result == correct); + } + + SUBCASE("get_input_bias_shape") { + tl::expected result = + get_input_bias_shape(attrs, q, k, v); + tl::expected correct = + make_input_bias(SumDegree{1}, DiscardCopyDegree{o_h}, 1); + CHECK(result == correct); + } + + SUBCASE("get_output_bias_shape") { + tl::expected result = + get_output_bias_shape(attrs, q, k, v); + tl::expected correct = + make_output_bias(SumDegree{1}, DiscardCopyDegree{o_h}, 1); + CHECK(result == correct); + } + } + + SUBCASE("combined data & attention head parallelism") { + int o_b = 4; + int o_h = 2; + ParallelTensorShape q = + make_q(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); + ParallelTensorShape k = + make_k(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); + ParallelTensorShape v = + make_v(SumDegree{1}, DiscardCopyDegree{o_h}, o_b, 1, 1); + + SUBCASE("get_output_shape") { + tl::expected result = + get_output_shape(attrs, q, k, v); + tl::expected correct = + make_o(SumDegree{o_h}, DiscardCopyDegree{1}, o_b, 1, 1); + CHECK(result == correct); + } + + SUBCASE("get_weights_shape") { + tl::expected result = + get_weights_shape(attrs, q, k, v); + tl::expected correct = + make_w(SumDegree{1}, DiscardCopyDegree{o_b}, 1, o_h); + CHECK(result == correct); + } + + SUBCASE("get_input_bias_shape") { + tl::expected result = + get_input_bias_shape(attrs, q, k, v); + tl::expected correct = + make_input_bias(SumDegree{1}, DiscardCopyDegree{o_b * o_h}, 1); + CHECK(result == correct); + } + + SUBCASE("get_output_bias_shape") { + tl::expected result = + get_output_bias_shape(attrs, q, k, v); + tl::expected correct = + make_output_bias(SumDegree{1}, DiscardCopyDegree{o_b * o_h}, 1); + CHECK(result == correct); + } + } } } } diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc new file mode 100644 index 0000000000..9d2e79dfd9 --- /dev/null +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -0,0 +1,58 @@ +#include "op-attrs/ops/cast.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Cast shape inference") { + DataType input_datatype = DataType::FLOAT; + DataType output_datatype = DataType::DOUBLE; + + CastAttrs attrs = CastAttrs{output_datatype}; + + size_t d1 = 12; + size_t d2 = 16; + TensorShape input = TensorShape{ + TensorDims{FFOrdered{d1, d2}}, + input_datatype, + }; + + TensorShape output = TensorShape{ + TensorDims{FFOrdered{d1, d2}}, + output_datatype, + }; + + SUBCASE("get_output_shape(CastAttrs, TensorShape)") { + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = output; + CHECK(result == correct); + } + + SUBCASE("get_output_shape(CastAttrs, ParallelTensorShape)") { + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_features) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_features}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_outchannels) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_outchannels}); + }; + + SumDegree sum_degree = SumDegree{2}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{3}; + int batch_degree = 4; + int feature_degree = 8; + ParallelTensorShape par_input = make_input(sum_degree, discard_copy_degree, batch_degree, feature_degree); + + tl::expected result = get_output_shape(attrs, par_input); + tl::expected correct = make_output(sum_degree, discard_copy_degree, batch_degree, feature_degree); + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index cdeb846af3..5b34ee641a 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -73,13 +73,15 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t const &value, int embed_dim, int num_heads, - int kdim = 0, - int vdim = 0, + std::optional kdim = std::nullopt, + std::optional vdim = std::nullopt, float dropout = 0.0f, bool bias = true, bool add_bias_kv = false, bool add_zero_attn = false, std::optional initializer = std::nullopt, + std::optional input_bias_initializer = std::nullopt, + std::optional output_bias_initializer = std::nullopt, std::optional const &name = std::nullopt); parallel_tensor_guid_t diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 90bc327a9a..9785982b08 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -50,24 +50,68 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor( parallel_tensor_guid_t ParallelComputationGraphBuilder::add( parallel_tensor_guid_t const &lhs, parallel_tensor_guid_t const &rhs, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + ParallelTensorShape lhs_shape = this->get_shape(lhs); + ParallelTensorShape rhs_shape = this->get_shape(rhs); + + DataType datatype = [&] { + if (lhs_shape.data_type != rhs_shape.data_type) { + throw mk_runtime_error(fmt::format("Datatypes do not match: {} (lhs) != {} (rhs)", lhs_shape.data_type, rhs_shape.data_type)); + } else { + return lhs_shape.data_type; + } + }(); + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + datatype, + false, + false, + }; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, lhs_shape, rhs_shape)); + + return this->add_layer(layer, {lhs, rhs}, {}, output_shape); } parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_matmul( parallel_tensor_guid_t const &a, parallel_tensor_guid_t const &b, - /* int a_seq_length_dim = -1, */ - /* int b_seq_length_dim = -1, */ - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + BatchMatmulAttrs attrs = BatchMatmulAttrs{ + /*a_seq_length_dim=*/-1, + /*b_seq_length_dim=*/-1, + }; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(a), this->get_shape(b))); + + return this->add_layer(layer, {a, b}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::cast( parallel_tensor_guid_t const &input, DataType result_type, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + CastAttrs attrs = CastAttrs{result_type}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::conv2d( @@ -129,8 +173,38 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( DataType data_type, std::optional const &kernel_initializer, std::optional const &bias_initializer, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + LinearAttrs attrs = LinearAttrs{ + outDim, + use_bias, + data_type, + activation, + std::nullopt, + }; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + { + ParallelTensorShape kernel_shape = throw_if_unexpected(get_kernel_shape(attrs, input_shape)); + weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + } + + if (use_bias) { + ParallelTensorShape bias_shape = throw_if_unexpected(get_bias_shape(attrs, input_shape)); + weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); + } else if (bias_initializer.has_value()) { + throw mk_runtime_error("Dense received unexpected bias initializer even though use_bias is set to false"); + } + + return this->add_layer(layer, {input}, weights, output_shape); } parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( @@ -140,8 +214,27 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( AggregateOp aggr, DataType dtype, std::optional const &kernel_initializer, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + EmbeddingAttrs attrs = EmbeddingAttrs{ + num_entries, + outDim, + aggr, + dtype, + }; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); + ParallelTensorShape weights_shape = throw_if_unexpected(get_weights_shape(attrs, input_shape)); + + ParallelTensorAttrs weights_attrs = make_weight_attrs(weights_shape, kernel_initializer); + + return this->add_layer(layer, {input}, {weights_attrs}, output_shape); } parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( @@ -150,51 +243,150 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t const &value, int embed_dim, int num_heads, - int kdim, - int vdim, + std::optional maybe_kdim, + std::optional maybe_vdim, float dropout, bool bias, bool add_bias_kv, bool add_zero_attn, std::optional initializer, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional input_bias_initializer, + std::optional output_bias_initializer, + std::optional const &maybe_name) { + + int kdim = maybe_kdim.value_or(embed_dim); + int vdim = maybe_vdim.value_or(embed_dim); + + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/kdim, + /*vdim=*/vdim, + /*dropout=*/dropout, + /*bias=*/bias, + /*add_bias_kv=*/add_bias_kv, + /*add_zero_attn=*/add_zero_attn, + }; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape query_shape = this->get_shape(query); + ParallelTensorShape key_shape = this->get_shape(key); + ParallelTensorShape value_shape = this->get_shape(value); + + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, query_shape, key_shape, value_shape)); + + std::vector weights; + + ParallelTensorAttrs weight_attrs = [&] { + ParallelTensorShape weight_shape = throw_if_unexpected(get_weights_shape(attrs, query_shape, key_shape, value_shape)); + return make_weight_attrs(weight_shape, initializer); + }(); + + weights.push_back(weight_attrs); + + if (bias) { + ParallelTensorShape input_bias_shape = throw_if_unexpected(get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back(make_weight_attrs(input_bias_shape, input_bias_initializer)); + ParallelTensorShape output_bias_shape = throw_if_unexpected(get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back(make_weight_attrs(output_bias_shape, output_bias_initializer)); + + } else if (input_bias_initializer.has_value()) { + throw mk_runtime_error("MultiheadAttention received unexpected input bias initializer even though bias is set to false"); + } else if (output_bias_initializer.has_value()) { + throw mk_runtime_error("MultiheadAttention received unexpected output bias initializer even though bias is set to false"); + } + + return this->add_layer(layer, {query, key, value}, weights, output_shape); } parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( parallel_tensor_guid_t const &input, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( - parallel_tensor_guid_t const &x, + parallel_tensor_guid_t const &input, ff_dim_t dim, int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + RepartitionAttrs attrs = RepartitionAttrs{dim, degree}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_combine( - parallel_tensor_guid_t const &x, + parallel_tensor_guid_t const &input, ff_dim_t dim, int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + CombineAttrs attrs = CombineAttrs{dim, degree}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_replicate( - parallel_tensor_guid_t const &x, + parallel_tensor_guid_t const &input, int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + ReplicateAttrs attrs = ReplicateAttrs{degree}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + + return this->add_layer(layer, {input}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_reduce( - parallel_tensor_guid_t const &x, + parallel_tensor_guid_t const &input, int degree, - std::optional const &name) { - NOT_IMPLEMENTED(); + std::optional const &maybe_name) { + + ReductionAttrs attrs = ReductionAttrs{degree}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {}, {output_shape}); } parallel_tensor_guid_t ParallelComputationGraphBuilder::as_type( diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 0eaf78966f..8561548a87 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -7,7 +7,152 @@ #include "utils/containers/without_nullopts.h" TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("ParallelComputationGraphBuilder") { + TEST_CASE("ParallelComputationGraphBuilder::add") { + ParallelComputationGraphBuilder b; + + ShardParallelDim d1 = ShardParallelDim{10, 2}; + ShardParallelDim d2 = ShardParallelDim{15, 3}; + + ParallelTensorShape lhs_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{15, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape rhs_shape = lhs_shape; + + parallel_tensor_guid_t lhs = b.create_input_tensor(lhs_shape); + parallel_tensor_guid_t rhs = b.create_input_tensor(rhs_shape); + + parallel_tensor_guid_t out = b.add(lhs, rhs); + parallel_layer_guid_t layer = get_source_layer(b.pcg, out); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { lhs, rhs }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { out }; + CHECK(result == correct); + } + + SUBCASE("op attrs") { + PCGOperatorAttrs result = get_parallel_layer_attrs(b.pcg, layer).op_attrs; + PCGOperatorAttrs correct = PCGOperatorAttrs{ElementBinaryAttrs{OperatorType::EW_ADD, DataType::FLOAT, false, false}}; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::batch_matmul") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{4, 2}; + + ParallelTensorShape a_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + ShardParallelDim{10, 1}, + ShardParallelDim{15, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape b_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + ShardParallelDim{15, 3}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t a_tensor = b.create_input_tensor(a_shape); + parallel_tensor_guid_t b_tensor = b.create_input_tensor(b_shape); + + parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); + parallel_layer_guid_t layer = get_source_layer(b.pcg, out); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { a_tensor, b_tensor }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { out }; + CHECK(result == correct); + } + + SUBCASE("op attrs") { + PCGOperatorAttrs result = get_parallel_layer_attrs(b.pcg, layer).op_attrs; + PCGOperatorAttrs correct = PCGOperatorAttrs{BatchMatmulAttrs{-1, -1}}; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::cast") { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + DataType output_datatype = DataType::DOUBLE; + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.cast(input, output_datatype); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { input }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + + ParallelTensorShape output_shape = get_parallel_tensor_attrs(b.pcg, output).shape; + CHECK(output_shape.data_type == output_datatype); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::conv2d") { ParallelComputationGraphBuilder b; size_t batch_size = 2; @@ -122,4 +267,327 @@ TEST_SUITE(FF_TEST_SUITE) { get_parallel_tensor_attrs(b.pcg, conv_output).shape; CHECK(conv_output_shape == correct_output_shape); }; + + TEST_CASE("ParallelComputationGraphBuilder::dense") { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + int outDim = 14; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.dense(input, + outDim, + Activation::RELU, + /*use_bias=*/true, + DataType::FLOAT); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + CHECK(result.at(0) == input); + + CHECK(result.size() == 3); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::embedding") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{12, 2}; + ShardParallelDim feature_dim = ShardParallelDim{10, 1}; + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::INT32, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.embedding(input, + /*num_entries=*/32, + /*outDim=*/8, + AggregateOp::SUM, + DataType::FLOAT); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + CHECK(result.at(0) == input); + + CHECK(result.size() == 2); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::multihead_attention") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{12, 2}; + ShardParallelDim sequence_dim = ShardParallelDim{16, 1}; + ShardParallelDim feature_dim = ShardParallelDim{10, 1}; + ParallelTensorShape query_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + sequence_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape key_shape = query_shape; + ParallelTensorShape value_shape = query_shape; + + int embed_dim = 8; + int num_heads = 6; + + parallel_tensor_guid_t query = b.create_input_tensor(query_shape); + parallel_tensor_guid_t key = b.create_input_tensor(key_shape); + parallel_tensor_guid_t value = b.create_input_tensor(value_shape); + parallel_tensor_guid_t output = b.multihead_attention(query, + key, + value, + embed_dim, + num_heads); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + CHECK(result.at(0) == query); + CHECK(result.at(1) == key); + CHECK(result.at(2) == value); + CHECK(result.size() == 6); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::relu") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{18, 3}; + ShardParallelDim feature_dim = ShardParallelDim{32, 1}; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.relu(input); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { input }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::parallel_partition") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{18, 2}; + ShardParallelDim feature_dim = ShardParallelDim{10, 1}; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.parallel_partition(input, + ff_dim_t{0}, + 2); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { input }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::parallel_combine") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{18, 2}; + ShardParallelDim feature_dim = ShardParallelDim{10, 1}; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.parallel_combine(input, + ff_dim_t{0}, + 2); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { input }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::parallel_replicate") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{18, 2}; + ShardParallelDim feature_dim = ShardParallelDim{10, 1}; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.parallel_replicate(input, 2); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { input }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } + + TEST_CASE("ParallelComputationGraphBuilder::parallel_reduce") { + ParallelComputationGraphBuilder b; + + ShardParallelDim batch_dim = ShardParallelDim{18, 2}; + ShardParallelDim feature_dim = ShardParallelDim{10, 1}; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{4}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t output = b.parallel_reduce(input, 2); + parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + + SUBCASE("inputs") { + std::vector result = get_layer_inputs(b.pcg, layer); + std::vector correct = { input }; + CHECK(result == correct); + } + + SUBCASE("outputs") { + std::vector result = get_layer_outputs(b.pcg, layer); + std::vector correct = { output }; + CHECK(result == correct); + } + } } From f379539e627e2b7e6f923653ef55f053d51347fc Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 10 Jun 2024 00:58:28 -0700 Subject: [PATCH 05/71] Fix build issues in local-execution --- .proj.toml | 1 + lib/local-execution/src/ops/attention.cc | 33 ++++++++----------- .../include/pcg/computation_graph_builder.h | 5 --- lib/pcg/src/pcg/computation_graph_builder.cc | 15 +++------ .../parallel_computation_graph_builder.cc | 2 +- 5 files changed, 21 insertions(+), 35 deletions(-) diff --git a/.proj.toml b/.proj.toml index b076671498..01ae36eddd 100644 --- a/.proj.toml +++ b/.proj.toml @@ -11,6 +11,7 @@ build_targets = [ # "substitutions", # "compiler", "substitution-generator", + "local-execution", ] test_targets = [ "utils-tests", diff --git a/lib/local-execution/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc index c40e4f1e2d..d366addb91 100644 --- a/lib/local-execution/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -16,6 +16,7 @@ #include "attention.h" #include "kernels/attention_kernels.h" #include "local-execution/op_task_signature.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" namespace FlexFlow { @@ -95,31 +96,25 @@ static DeviceSpecific ParallelTensorShape value_parallel_tensor_shape = acc.get_argument(VALUE_PARALLEL_TENSOR_SHAPE); - MultiHeadAttentionInputs inputs = { - shard_dim_at_idx(query_parallel_tensor_shape, ff_dim_t{0}).size, - shard_dim_at_idx(query_parallel_tensor_shape, ff_dim_t{1}).size, - qProjSize, - kProjSize, - vProjSize, - query_parallel_tensor_shape.data_type}; - ; + MultiHeadAttentionParallelInputs parsed = throw_if_unexpected( + parse_attention_parallel_input_shape(query_parallel_tensor_shape, + key_parallel_tensor_shape, + value_parallel_tensor_shape) + ); ParallelTensorShape weight_parallel_tensor_shape = throw_if_unexpected(get_weights_shape(attrs, query_parallel_tensor_shape, key_parallel_tensor_shape, value_parallel_tensor_shape)); - int kvSeqLength = get_kvSeqLength(inputs); - int qSize = get_qSize(inputs); - int kSize = get_kSize(inputs); - int vSize = get_vSize(inputs); - - int qoSeqLength = - dim_at_idx(get_piece_shape(query_parallel_tensor_shape), ff_dim_t(1)); - int num_samples = - dim_at_idx(get_piece_shape(query_parallel_tensor_shape), ff_dim_t(2)); - int num_heads = - dim_at_idx(get_piece_shape(weight_parallel_tensor_shape), ff_dim_t(1)); + int kvSeqLength = get_kvSeqLength(parsed); + int qSize = get_qSize(parsed); + int kSize = get_kSize(parsed); + int vSize = get_vSize(parsed); + + int qoSeqLength = get_qoSeqLength(parsed); + int num_samples = get_num_samples(parsed); + int num_heads = attrs.num_heads; MHAPerDeviceState per_device_state = init_kernel(handle, allocator, diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 4bb04fc22a..fe156172e2 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -281,11 +281,6 @@ struct ComputationGraphBuilder { tensor_guid_t const &input, float scalar, std::optional const &name = std::nullopt); - tensor_guid_t - element_unary(ElementUnaryAttrs const &, - tensor_guid_t const &input, - std::optional const &name = std::nullopt); - public: ComputationGraph computation_graph; }; diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 6ef900e875..d3dcf79ca6 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -138,9 +138,13 @@ static std::string get_default_name(ComputationGraphOpAttrs const &attrs) { } tensor_guid_t ComputationGraphBuilder::element_unary( - ElementUnaryAttrs const &attrs, + OperatorType op_type, tensor_guid_t const &x, + std::optional scalar, std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{op_type, scalar}; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); @@ -155,15 +159,6 @@ tensor_guid_t ComputationGraphBuilder::element_unary( return this->add_layer(layer, {input}, {}, output_shape); } -tensor_guid_t ComputationGraphBuilder::element_unary( - OperatorType op_type, - tensor_guid_t const &input, - std::optional scalar, - std::optional const &name) { - ElementUnaryAttrs attrs = {op_type, scalar}; - return this->element_unary(attrs, input, name); -} - tensor_guid_t ComputationGraphBuilder::element_binary( OperatorType op_type, tensor_guid_t const &lhs, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 9785982b08..d09508eab9 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -307,7 +307,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( parallel_tensor_guid_t const &input, std::optional const &maybe_name) { - ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU}; + ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU, std::nullopt}; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); From 2dbb3b9e4bbf972ff920a0a7032855eeaf019d68 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 10 Jun 2024 01:04:01 -0700 Subject: [PATCH 06/71] Format --- lib/local-execution/src/ops/attention.cc | 7 +- lib/op-attrs/include/op-attrs/ops/cast.h | 8 +- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 2 +- lib/op-attrs/src/op-attrs/ops/attention.cc | 71 ++-- lib/op-attrs/src/op-attrs/ops/cast.cc | 26 +- lib/op-attrs/test/src/ops/attention.cc | 94 ++--- lib/op-attrs/test/src/ops/cast.cc | 20 +- .../include/pcg/computation_graph_builder.h | 1 + .../pcg/dataflow_graph/dataflow_graph.h | 9 +- .../parallel_computation_graph.cc | 2 +- .../parallel_computation_graph_builder.cc | 140 ++++--- .../parallel_computation_graph_builder.cc | 381 +++++++++--------- lib/utils/include/utils/fmt/unordered_set.h | 16 +- lib/utils/include/utils/fmt/vector.h | 11 +- lib/utils/include/utils/graph/multidiedge.h | 2 +- lib/utils/include/utils/stack_vector.h | 4 +- 16 files changed, 423 insertions(+), 371 deletions(-) diff --git a/lib/local-execution/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc index d366addb91..be1fae475c 100644 --- a/lib/local-execution/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -97,10 +97,9 @@ static DeviceSpecific acc.get_argument(VALUE_PARALLEL_TENSOR_SHAPE); MultiHeadAttentionParallelInputs parsed = throw_if_unexpected( - parse_attention_parallel_input_shape(query_parallel_tensor_shape, - key_parallel_tensor_shape, - value_parallel_tensor_shape) - ); + parse_attention_parallel_input_shape(query_parallel_tensor_shape, + key_parallel_tensor_shape, + value_parallel_tensor_shape)); ParallelTensorShape weight_parallel_tensor_shape = throw_if_unexpected(get_weights_shape(attrs, query_parallel_tensor_shape, diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 8a97bbafe6..ead779c553 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_CAST_ATTRS_H #define _FLEXFLOW_CAST_ATTRS_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" #include @@ -11,8 +11,10 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(CastAttrs); -tl::expected get_output_shape(CastAttrs const &, TensorShape const &); -tl::expected get_output_shape(CastAttrs const &, ParallelTensorShape const &); +tl::expected get_output_shape(CastAttrs const &, + TensorShape const &); +tl::expected + get_output_shape(CastAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index d80e4b5862..72d1123c39 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_CONV_2D_ATTRS_H #define _FLEXFLOW_CONV_2D_ATTRS_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 834c3b7330..3e4095eca8 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -128,9 +128,9 @@ tl::expected tl::expected get_input_bias_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v) { + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { MultiHeadAttentionInputs parsed = ({ tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); @@ -141,18 +141,18 @@ tl::expected }); return TensorShape{ - TensorDims{FFOrdered{ - size_t_from_int(attrs.kdim + attrs.kdim + attrs.vdim), - }}, - parsed.datatype, + TensorDims{FFOrdered{ + size_t_from_int(attrs.kdim + attrs.kdim + attrs.vdim), + }}, + parsed.datatype, }; } tl::expected get_output_bias_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v) { + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { MultiHeadAttentionInputs parsed = ({ tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); @@ -163,14 +163,13 @@ tl::expected }); return TensorShape{ - TensorDims{FFOrdered{ - size_t_from_int(attrs.embed_dim), - }}, - parsed.datatype, + TensorDims{FFOrdered{ + size_t_from_int(attrs.embed_dim), + }}, + parsed.datatype, }; } - tl::expected get_weights_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &input_q, @@ -219,10 +218,11 @@ tl::expected }); TensorShape unpar_shape = ({ - tl::expected result_unpar = get_input_bias_shape(attrs, - get_reduced_shape(input_q), - get_reduced_shape(input_k), - get_reduced_shape(input_v)); + tl::expected result_unpar = + get_input_bias_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } @@ -232,20 +232,17 @@ tl::expected SumDegree sum_degree = SumDegree{1}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - parsed.batch_dim.degree * parsed.discard_copy_degree.value - }; + parsed.batch_dim.degree * parsed.discard_copy_degree.value}; FFOrdered shard_degrees = FFOrdered{1}; - return lift_to_parallel_with_degrees(unpar_shape, - sum_degree, - discard_copy_degree, - shard_degrees); + return lift_to_parallel_with_degrees( + unpar_shape, sum_degree, discard_copy_degree, shard_degrees); } tl::expected get_output_bias_shape(MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v) { + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { MultiHeadAttentionParallelInputs parsed = ({ tl::expected parse_result = parse_attention_parallel_input_shape(input_q, input_k, input_v); @@ -257,10 +254,11 @@ tl::expected }); TensorShape unpar_shape = ({ - tl::expected result_unpar = get_output_bias_shape(attrs, - get_reduced_shape(input_q), - get_reduced_shape(input_k), - get_reduced_shape(input_v)); + tl::expected result_unpar = + get_output_bias_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } @@ -270,13 +268,10 @@ tl::expected SumDegree sum_degree = SumDegree{1}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - parsed.batch_dim.degree * parsed.discard_copy_degree.value - }; + parsed.batch_dim.degree * parsed.discard_copy_degree.value}; FFOrdered shard_degrees = FFOrdered{1}; - return lift_to_parallel_with_degrees(unpar_shape, - sum_degree, - discard_copy_degree, - shard_degrees); + return lift_to_parallel_with_degrees( + unpar_shape, sum_degree, discard_copy_degree, shard_degrees); } tl::expected diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc index e4ae8f9759..444409ffcb 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast.cc @@ -2,30 +2,32 @@ namespace FlexFlow { -tl::expected - get_output_shape(CastAttrs const &attrs, TensorShape const &input) { +tl::expected + get_output_shape(CastAttrs const &attrs, TensorShape const &input) { if (!can_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) { - return tl::unexpected(fmt::format("Cast cannot strictly promote input datatype {} to output datatype {}", - input.data_type, - attrs.dtype)); + return tl::unexpected(fmt::format( + "Cast cannot strictly promote input datatype {} to output datatype {}", + input.data_type, + attrs.dtype)); } - + TensorShape output = input; output.data_type = attrs.dtype; return output; } tl::expected - get_output_shape(CastAttrs const &attrs, ParallelTensorShape const &input) { + get_output_shape(CastAttrs const &attrs, ParallelTensorShape const &input) { if (!can_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) { - return tl::unexpected(fmt::format("Cast cannot strictly promote input datatype {} to output datatype {}", - input.data_type, - attrs.dtype)); + return tl::unexpected(fmt::format( + "Cast cannot strictly promote input datatype {} to output datatype {}", + input.data_type, + attrs.dtype)); } - - ParallelTensorShape output = input; + + ParallelTensorShape output = input; output.data_type = attrs.dtype; return output; diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/ops/attention.cc index 7f69d57cd7..ade219a6a9 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/ops/attention.cc @@ -28,75 +28,75 @@ TEST_SUITE(FF_TEST_SUITE) { size_t feature_size = 36; TensorShape input_q = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - feature_size, + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + feature_size, + }, }, - }, DataType::FLOAT, }; TensorShape input_k = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - feature_size, + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + feature_size, + }, }, - }, DataType::FLOAT, }; TensorShape input_v = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - feature_size, + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + feature_size, + }, }, - }, DataType::FLOAT, }; TensorShape output = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.embed_dim), + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape weights = TensorShape{ - TensorDims{ - FFOrdered{ - (feature_size * embed_dim) * 3 + (embed_dim * embed_dim), - size_t_from_int(num_heads), + TensorDims{ + FFOrdered{ + (feature_size * embed_dim) * 3 + (embed_dim * embed_dim), + size_t_from_int(num_heads), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape input_bias = TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(embed_dim * 3), + TensorDims{ + FFOrdered{ + size_t_from_int(embed_dim * 3), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape output_bias = TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(embed_dim), + TensorDims{ + FFOrdered{ + size_t_from_int(embed_dim), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; SUBCASE("get_output_shape") { @@ -116,14 +116,14 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("get_input_bias_shape") { - tl::expected result = + tl::expected result = get_input_bias_shape(attrs, input_q, input_k, input_v); tl::expected correct = input_bias; CHECK(result == correct); } SUBCASE("get_output_bias_shape") { - tl::expected result = + tl::expected result = get_output_bias_shape(attrs, input_q, input_k, input_v); tl::expected correct = output_bias; CHECK(result == correct); @@ -169,19 +169,19 @@ TEST_SUITE(FF_TEST_SUITE) { auto make_w = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_e, int o_h) { return lift_to_parallel_with_degrees( - weights, o_sum, o_eq, FFOrdered{o_e, o_h}); + weights, o_sum, o_eq, FFOrdered{o_e, o_h}); }; - auto make_input_bias = + auto make_input_bias = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_in_proj_channel) { return lift_to_parallel_with_degrees( - input_bias, o_sum, o_eq, FFOrdered{o_in_proj_channel}); + input_bias, o_sum, o_eq, FFOrdered{o_in_proj_channel}); }; auto make_output_bias = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_out_proj_channel) { return lift_to_parallel_with_degrees( - output_bias, o_sum, o_eq, FFOrdered{o_out_proj_channel}); + output_bias, o_sum, o_eq, FFOrdered{o_out_proj_channel}); }; SUBCASE("data parallelism") { diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc index 9d2e79dfd9..086d25d042 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -11,17 +11,18 @@ TEST_SUITE(FF_TEST_SUITE) { size_t d1 = 12; size_t d2 = 16; TensorShape input = TensorShape{ - TensorDims{FFOrdered{d1, d2}}, - input_datatype, + TensorDims{FFOrdered{d1, d2}}, + input_datatype, }; TensorShape output = TensorShape{ - TensorDims{FFOrdered{d1, d2}}, - output_datatype, + TensorDims{FFOrdered{d1, d2}}, + output_datatype, }; SUBCASE("get_output_shape(CastAttrs, TensorShape)") { - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = output; CHECK(result == correct); } @@ -47,10 +48,13 @@ TEST_SUITE(FF_TEST_SUITE) { DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{3}; int batch_degree = 4; int feature_degree = 8; - ParallelTensorShape par_input = make_input(sum_degree, discard_copy_degree, batch_degree, feature_degree); + ParallelTensorShape par_input = make_input( + sum_degree, discard_copy_degree, batch_degree, feature_degree); - tl::expected result = get_output_shape(attrs, par_input); - tl::expected correct = make_output(sum_degree, discard_copy_degree, batch_degree, feature_degree); + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = make_output( + sum_degree, discard_copy_degree, batch_degree, feature_degree); CHECK(result == correct); } diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index fe156172e2..7c0c73dd1d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -281,6 +281,7 @@ struct ComputationGraphBuilder { tensor_guid_t const &input, float scalar, std::optional const &name = std::nullopt); + public: ComputationGraph computation_graph; }; diff --git a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h index e90acf533d..1ff1d86ed6 100644 --- a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h +++ b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h @@ -66,14 +66,17 @@ struct DataflowGraph { return this->g.at(o); } - std::unordered_map> const &get_output_map() const { + std::unordered_map> const & + get_output_map() const { return this->output_map; } + private: OutputLabelledMultiDiGraph g; bidict port_mapping; - std::unordered_map> output_map; // NOTE(@lockshaw): temporary workaround until not tracking outputs - // independent of edges in multidigraph is resolved + std::unordered_map> + output_map; // NOTE(@lockshaw): temporary workaround until not tracking + // outputs independent of edges in multidigraph is resolved }; template diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 491ac67708..daf2b67303 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -30,7 +30,7 @@ std::vector [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); } -parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, +parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, parallel_tensor_guid_t const &t) { return parallel_layer_guid_t{t.raw_graph_output.src}; } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index d09508eab9..29723ed078 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -57,24 +57,28 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add( DataType datatype = [&] { if (lhs_shape.data_type != rhs_shape.data_type) { - throw mk_runtime_error(fmt::format("Datatypes do not match: {} (lhs) != {} (rhs)", lhs_shape.data_type, rhs_shape.data_type)); + throw mk_runtime_error( + fmt::format("Datatypes do not match: {} (lhs) != {} (rhs)", + lhs_shape.data_type, + rhs_shape.data_type)); } else { return lhs_shape.data_type; } }(); ElementBinaryAttrs attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - datatype, - false, - false, + OperatorType::EW_ADD, + datatype, + false, + false, }; - std::string name = - maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); - + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, lhs_shape, rhs_shape)); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, lhs_shape, rhs_shape)); return this->add_layer(layer, {lhs, rhs}, {}, output_shape); } @@ -85,15 +89,16 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_matmul( std::optional const &maybe_name) { BatchMatmulAttrs attrs = BatchMatmulAttrs{ - /*a_seq_length_dim=*/-1, - /*b_seq_length_dim=*/-1, + /*a_seq_length_dim=*/-1, + /*b_seq_length_dim=*/-1, }; - std::string name = - maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(a), this->get_shape(b))); + ParallelTensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, this->get_shape(a), this->get_shape(b))); return this->add_layer(layer, {a, b}, {}, {output_shape}); } @@ -105,11 +110,12 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::cast( CastAttrs attrs = CastAttrs{result_type}; - std::string name = - maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, {output_shape}); } @@ -175,33 +181,37 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::optional const &bias_initializer, std::optional const &maybe_name) { LinearAttrs attrs = LinearAttrs{ - outDim, - use_bias, - data_type, - activation, - std::nullopt, + outDim, + use_bias, + data_type, + activation, + std::nullopt, }; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); - + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; ParallelTensorShape input_shape = this->get_shape(input); - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); std::vector weights; { - ParallelTensorShape kernel_shape = throw_if_unexpected(get_kernel_shape(attrs, input_shape)); + ParallelTensorShape kernel_shape = + throw_if_unexpected(get_kernel_shape(attrs, input_shape)); weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); } if (use_bias) { - ParallelTensorShape bias_shape = throw_if_unexpected(get_bias_shape(attrs, input_shape)); + ParallelTensorShape bias_shape = + throw_if_unexpected(get_bias_shape(attrs, input_shape)); weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); } else if (bias_initializer.has_value()) { - throw mk_runtime_error("Dense received unexpected bias initializer even though use_bias is set to false"); + throw mk_runtime_error("Dense received unexpected bias initializer even " + "though use_bias is set to false"); } return this->add_layer(layer, {input}, weights, output_shape); @@ -217,10 +227,10 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( std::optional const &maybe_name) { EmbeddingAttrs attrs = EmbeddingAttrs{ - num_entries, - outDim, - aggr, - dtype, + num_entries, + outDim, + aggr, + dtype, }; std::string name = @@ -229,10 +239,13 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::embedding( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; ParallelTensorShape input_shape = this->get_shape(input); - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - ParallelTensorShape weights_shape = throw_if_unexpected(get_weights_shape(attrs, input_shape)); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); + ParallelTensorShape weights_shape = + throw_if_unexpected(get_weights_shape(attrs, input_shape)); - ParallelTensorAttrs weights_attrs = make_weight_attrs(weights_shape, kernel_initializer); + ParallelTensorAttrs weights_attrs = + make_weight_attrs(weights_shape, kernel_initializer); return this->add_layer(layer, {input}, {weights_attrs}, output_shape); } @@ -258,14 +271,14 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( int vdim = maybe_vdim.value_or(embed_dim); MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ - /*embed_dim=*/embed_dim, - /*num_heads=*/num_heads, - /*kdim=*/kdim, - /*vdim=*/vdim, - /*dropout=*/dropout, - /*bias=*/bias, - /*add_bias_kv=*/add_bias_kv, - /*add_zero_attn=*/add_zero_attn, + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/kdim, + /*vdim=*/vdim, + /*dropout=*/dropout, + /*bias=*/bias, + /*add_bias_kv=*/add_bias_kv, + /*add_zero_attn=*/add_zero_attn, }; std::string name = @@ -277,27 +290,35 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( ParallelTensorShape key_shape = this->get_shape(key); ParallelTensorShape value_shape = this->get_shape(value); - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, query_shape, key_shape, value_shape)); + ParallelTensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, query_shape, key_shape, value_shape)); std::vector weights; - + ParallelTensorAttrs weight_attrs = [&] { - ParallelTensorShape weight_shape = throw_if_unexpected(get_weights_shape(attrs, query_shape, key_shape, value_shape)); + ParallelTensorShape weight_shape = throw_if_unexpected( + get_weights_shape(attrs, query_shape, key_shape, value_shape)); return make_weight_attrs(weight_shape, initializer); }(); weights.push_back(weight_attrs); if (bias) { - ParallelTensorShape input_bias_shape = throw_if_unexpected(get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); - weights.push_back(make_weight_attrs(input_bias_shape, input_bias_initializer)); - ParallelTensorShape output_bias_shape = throw_if_unexpected(get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); - weights.push_back(make_weight_attrs(output_bias_shape, output_bias_initializer)); + ParallelTensorShape input_bias_shape = throw_if_unexpected( + get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back( + make_weight_attrs(input_bias_shape, input_bias_initializer)); + ParallelTensorShape output_bias_shape = throw_if_unexpected( + get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back( + make_weight_attrs(output_bias_shape, output_bias_initializer)); } else if (input_bias_initializer.has_value()) { - throw mk_runtime_error("MultiheadAttention received unexpected input bias initializer even though bias is set to false"); + throw mk_runtime_error("MultiheadAttention received unexpected input bias " + "initializer even though bias is set to false"); } else if (output_bias_initializer.has_value()) { - throw mk_runtime_error("MultiheadAttention received unexpected output bias initializer even though bias is set to false"); + throw mk_runtime_error("MultiheadAttention received unexpected output bias " + "initializer even though bias is set to false"); } return this->add_layer(layer, {query, key, value}, weights, output_shape); @@ -314,7 +335,8 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, {output_shape}); } @@ -324,7 +346,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( ff_dim_t dim, int degree, std::optional const &maybe_name) { - + RepartitionAttrs attrs = RepartitionAttrs{dim, degree}; std::string name = @@ -332,7 +354,8 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, {output_shape}); } @@ -350,7 +373,8 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_combine( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, {output_shape}); } @@ -367,7 +391,8 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_replicate( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + ParallelTensorShape output_shape = + get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, {output_shape}); } @@ -384,7 +409,8 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_reduce( ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; - ParallelTensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, {output_shape}); } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 8561548a87..50ad727c12 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -14,17 +14,17 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim d2 = ShardParallelDim{15, 3}; ParallelTensorShape lhs_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 2}, - ShardParallelDim{15, 3}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{15, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, }, - ReplicaParallelDimSet{ - SumDegree{2}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; ParallelTensorShape rhs_shape = lhs_shape; @@ -36,57 +36,60 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, out); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { lhs, rhs }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {lhs, rhs}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { out }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {out}; CHECK(result == correct); } SUBCASE("op attrs") { PCGOperatorAttrs result = get_parallel_layer_attrs(b.pcg, layer).op_attrs; - PCGOperatorAttrs correct = PCGOperatorAttrs{ElementBinaryAttrs{OperatorType::EW_ADD, DataType::FLOAT, false, false}}; + PCGOperatorAttrs correct = PCGOperatorAttrs{ElementBinaryAttrs{ + OperatorType::EW_ADD, DataType::FLOAT, false, false}}; CHECK(result == correct); } } TEST_CASE("ParallelComputationGraphBuilder::batch_matmul") { - ParallelComputationGraphBuilder b; + ParallelComputationGraphBuilder b; ShardParallelDim batch_dim = ShardParallelDim{4, 2}; ParallelTensorShape a_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - ShardParallelDim{10, 1}, - ShardParallelDim{15, 3}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + ShardParallelDim{10, 1}, + ShardParallelDim{15, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; ParallelTensorShape b_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - ShardParallelDim{15, 3}, - ShardParallelDim{12, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + ShardParallelDim{15, 3}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; parallel_tensor_guid_t a_tensor = b.create_input_tensor(a_shape); @@ -96,14 +99,16 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, out); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { a_tensor, b_tensor }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {a_tensor, b_tensor}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { out }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {out}; CHECK(result == correct); } @@ -118,17 +123,17 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraphBuilder b; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 2}, - ShardParallelDim{12, 1}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{1}, + }, }, - ReplicaParallelDimSet{ - SumDegree{3}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; DataType output_datatype = DataType::DOUBLE; @@ -137,17 +142,20 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { input }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {input}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); - ParallelTensorShape output_shape = get_parallel_tensor_attrs(b.pcg, output).shape; + ParallelTensorShape output_shape = + get_parallel_tensor_attrs(b.pcg, output).shape; CHECK(output_shape.data_type == output_datatype); } } @@ -269,63 +277,65 @@ TEST_SUITE(FF_TEST_SUITE) { }; TEST_CASE("ParallelComputationGraphBuilder::dense") { - ParallelComputationGraphBuilder b; + ParallelComputationGraphBuilder b; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 2}, - ShardParallelDim{16, 1}, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; int outDim = 14; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); - parallel_tensor_guid_t output = b.dense(input, + parallel_tensor_guid_t output = b.dense(input, outDim, Activation::RELU, /*use_bias=*/true, DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(b.pcg, output); - + SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); + std::vector result = + get_layer_inputs(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 3); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } TEST_CASE("ParallelComputationGraphBuilder::embedding") { - ParallelComputationGraphBuilder b; + ParallelComputationGraphBuilder b; ShardParallelDim batch_dim = ShardParallelDim{12, 2}; ShardParallelDim feature_dim = ShardParallelDim{10, 1}; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - feature_dim, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::INT32, + DataType::INT32, }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); @@ -337,15 +347,17 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); + std::vector result = + get_layer_inputs(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 2); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } @@ -357,18 +369,18 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim sequence_dim = ShardParallelDim{16, 1}; ShardParallelDim feature_dim = ShardParallelDim{10, 1}; ParallelTensorShape query_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - sequence_dim, - feature_dim, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + sequence_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; ParallelTensorShape key_shape = query_shape; @@ -380,15 +392,13 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t query = b.create_input_tensor(query_shape); parallel_tensor_guid_t key = b.create_input_tensor(key_shape); parallel_tensor_guid_t value = b.create_input_tensor(value_shape); - parallel_tensor_guid_t output = b.multihead_attention(query, - key, - value, - embed_dim, - num_heads); + parallel_tensor_guid_t output = + b.multihead_attention(query, key, value, embed_dim, num_heads); parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); + std::vector result = + get_layer_inputs(b.pcg, layer); CHECK(result.at(0) == query); CHECK(result.at(1) == key); CHECK(result.at(2) == value); @@ -396,8 +406,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } @@ -409,17 +420,17 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim feature_dim = ShardParallelDim{32, 1}; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - feature_dim, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); @@ -427,14 +438,16 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { input }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {input}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } @@ -446,34 +459,34 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim feature_dim = ShardParallelDim{10, 1}; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - feature_dim, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); - parallel_tensor_guid_t output = b.parallel_partition(input, - ff_dim_t{0}, - 2); + parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { input }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {input}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } @@ -485,34 +498,34 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim feature_dim = ShardParallelDim{10, 1}; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - feature_dim, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); - parallel_tensor_guid_t output = b.parallel_combine(input, - ff_dim_t{0}, - 2); + parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { input }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {input}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } @@ -524,17 +537,17 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim feature_dim = ShardParallelDim{10, 1}; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - feature_dim, - }, - ReplicaParallelDimSet{ - SumDegree{1}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); @@ -542,14 +555,16 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { input }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {input}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } @@ -561,17 +576,17 @@ TEST_SUITE(FF_TEST_SUITE) { ShardParallelDim feature_dim = ShardParallelDim{10, 1}; ParallelTensorShape input_shape = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - batch_dim, - feature_dim, - }, - ReplicaParallelDimSet{ - SumDegree{4}, - DiscardCopyDegree{1}, + ParallelTensorDims{ + FFOrdered{ + batch_dim, + feature_dim, + }, + ReplicaParallelDimSet{ + SumDegree{4}, + DiscardCopyDegree{1}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); @@ -579,14 +594,16 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_layer_guid_t layer = get_source_layer(b.pcg, output); SUBCASE("inputs") { - std::vector result = get_layer_inputs(b.pcg, layer); - std::vector correct = { input }; + std::vector result = + get_layer_inputs(b.pcg, layer); + std::vector correct = {input}; CHECK(result == correct); } SUBCASE("outputs") { - std::vector result = get_layer_outputs(b.pcg, layer); - std::vector correct = { output }; + std::vector result = + get_layer_outputs(b.pcg, layer); + std::vector correct = {output}; CHECK(result == correct); } } diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 1ce36fa97a..8954faf7c5 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_SET_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_SET_H -#include +#include "utils/check_fmtable.h" #include "utils/join_strings.h" #include -#include "utils/check_fmtable.h" +#include namespace fmt { @@ -12,20 +12,22 @@ template struct formatter< ::std::unordered_set, Char, - std::enable_if_t>::value>> + std::enable_if_t>::value>> : formatter<::std::string> { template auto format(::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()) { + -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::string result = ::FlexFlow::join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + std::string result = + ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); return formatter::format("{" + result + "}", ctx); } }; -} +} // namespace fmt namespace FlexFlow { diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 82cdcfdb3c..5d9ca0aeae 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VECTOR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VECTOR_H -#include +#include "utils/check_fmtable.h" #include "utils/join_strings.h" #include -#include "utils/check_fmtable.h" +#include namespace fmt { @@ -19,8 +19,10 @@ struct formatter< -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::string result = ::FlexFlow::join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + std::string result = + ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); return formatter::format("[" + result + "]", ctx); } }; @@ -38,5 +40,4 @@ std::ostream &operator<<(std::ostream &s, std::vector const &v) { } // namespace FlexFlow - #endif diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h index 2a9a417e7e..de4ab4fd82 100644 --- a/lib/utils/include/utils/graph/multidiedge.h +++ b/lib/utils/include/utils/graph/multidiedge.h @@ -4,9 +4,9 @@ #include "diedge.h" #include "node.h" #include "node_port.h" +#include "utils/fmt/pair.h" #include "utils/strong_typedef.h" #include "utils/visitable.h" -#include "utils/fmt/pair.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 142d4fe7b5..c2fdbe0afe 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -5,6 +5,7 @@ #include "hash-utils.h" #include "rapidcheck.h" #include "utils/fmt.h" +#include "utils/fmt/vector.h" #include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" @@ -12,7 +13,6 @@ #include #include #include -#include "utils/fmt/vector.h" namespace FlexFlow { @@ -316,7 +316,7 @@ struct stack_vector { template std::ostream &operator<<(std::ostream &s, stack_vector const &v) { - return s << fmt::to_string(v); + return s << fmt::to_string(v); } template From 4050c996e24a62ba25e8f5d3632a733318abcf04 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 16 Jun 2024 22:14:11 -0700 Subject: [PATCH 07/71] Address Reyna comments, add topological_order function for PCG --- .../test/src/test_regularizer_attrs.cc | 15 ++--- .../include/pcg/dataflow_graph/algorithms.h | 36 ++++++++++ .../pcg/dataflow_graph/dataflow_graph.h | 21 ------ lib/pcg/include/pcg/initializer_attrs.dtg.h | 11 ++- .../pcg/initializer_attrs.variant.toml | 3 +- .../constant_initializer_attrs.dtg.h | 10 ++- .../constant_initializer_attrs.struct.toml | 2 +- .../uniform_initializer_attrs.dtg.h | 2 +- .../initializers/uniform_initializer_attrs.h | 16 +++++ .../uniform_initializer_attrs.struct.toml | 1 - .../parallel_computation_graph.h | 8 +++ .../parallel_layer_added_result.dtg.h | 44 ++++++++++++ .../parallel_layer_added_result.struct.toml | 23 +++++++ .../parallel_layer_attrs.dtg.h | 10 ++- .../parallel_layer_attrs.struct.toml | 2 +- .../parallel_tensor_attrs.dtg.h | 10 ++- .../parallel_tensor_attrs.struct.toml | 2 +- lib/pcg/src/pcg/dataflow_graph/algorithms.cc | 1 + lib/pcg/src/pcg/initializer_attrs.dtg.cc | 18 ++++- .../constant_initializer_attrs.dtg.cc | 10 ++- .../initializers/uniform_initializer_attrs.cc | 20 ++++++ .../uniform_initializer_attrs.dtg.cc | 2 +- .../parallel_computation_graph.cc | 21 ++++++ .../parallel_layer_added_result.dtg.cc | 67 +++++++++++++++++++ .../parallel_layer_attrs.dtg.cc | 11 ++- .../parallel_tensor_attrs.dtg.cc | 13 +++- .../algorithms.cc} | 32 ++++++++- .../initializers/uniform_initializer_attrs.cc | 11 +++ .../parallel_computation_graph.cc | 31 +++++++++ lib/utils/include/utils/graph/algorithms.h | 2 - lib/utils/include/utils/stack_string.h | 19 ++++++ .../test/common/include/test/utils/all.h | 7 +- .../common/include/test/utils/rapidcheck.h | 4 ++ .../include/test/utils/rapidcheck/some.h | 16 +++++ lib/utils/test/src/test_stack_string.cc | 6 ++ lib/utils/test/src/test_stack_vector.cc | 10 +-- 36 files changed, 457 insertions(+), 60 deletions(-) create mode 100644 lib/pcg/include/pcg/dataflow_graph/algorithms.h create mode 100644 lib/pcg/include/pcg/initializers/uniform_initializer_attrs.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml create mode 100644 lib/pcg/src/pcg/dataflow_graph/algorithms.cc create mode 100644 lib/pcg/src/pcg/initializers/uniform_initializer_attrs.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc rename lib/pcg/test/src/pcg/{dataflow_graph.cc => dataflow_graph/algorithms.cc} (59%) create mode 100644 lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc create mode 100644 lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc create mode 100644 lib/utils/test/common/include/test/utils/rapidcheck.h create mode 100644 lib/utils/test/common/include/test/utils/rapidcheck/some.h diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc index 198c3add38..63471d2bd3 100644 --- a/lib/op-attrs/test/src/test_regularizer_attrs.cc +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -1,14 +1,11 @@ -#include "doctest/doctest.h" #include "op-attrs/regularizer_attrs.dtg.h" -#include - -using namespace FlexFlow; +#include "test/utils/doctest.h" +#include "test/utils/rapidcheck.h" TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("RC") { - CHECK(rc::check("valid variant", [](RegularizerAttrs reg) { - return reg.has() || reg.has(); - })); + TEST_CASE("Arbitrary") { + rc::dc_check("valid variant", [](RegularizerAttrs reg) { + RC_ASSERT(reg.has() || reg.has()); + }); } } diff --git a/lib/pcg/include/pcg/dataflow_graph/algorithms.h b/lib/pcg/include/pcg/dataflow_graph/algorithms.h new file mode 100644 index 0000000000..413fecd92a --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph/algorithms.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_ALGORITHMS_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_ALGORITHMS_H + +#include "pcg/dataflow_graph/dataflow_graph.h" + +namespace FlexFlow { + +template +std::vector + get_inputs(DataflowGraph const &g, Node const &n) { + std::vector> input_edges = + transform(as_vector(get_incoming_edges(g.get_raw_graph(), + std::unordered_set{n})), + [&](MultiDiEdge const &e) { + int idx = g.idx_for_port(e.dst_idx); + MultiDiOutput val = static_cast(e); + return std::make_pair(idx, val); + }); + + return vector_from_indexed_set(input_edges); +} + +template +std::vector + get_outputs(DataflowGraph const &g, Node const &n) { + return g.get_output_map().at(n); +} + +template +std::vector topological_ordering(DataflowGraph const &g) { + return get_topological_ordering(g.get_raw_graph()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h index 1ff1d86ed6..c0650bc9b4 100644 --- a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h +++ b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h @@ -100,27 +100,6 @@ std::vector }); } -template -std::vector - get_inputs(DataflowGraph const &g, Node const &n) { - std::vector> input_edges = - transform(as_vector(get_incoming_edges(g.get_raw_graph(), - std::unordered_set{n})), - [&](MultiDiEdge const &e) { - int idx = g.idx_for_port(e.dst_idx); - MultiDiOutput val = static_cast(e); - return std::make_pair(idx, val); - }); - - return vector_from_indexed_set(input_edges); -} - -template -std::vector - get_outputs(DataflowGraph const &g, Node const &n) { - return g.get_output_map().at(n); -} - } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializer_attrs.dtg.h index 7f5a470a90..3de94dcc86 100644 --- a/lib/pcg/include/pcg/initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializer_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/initializer_attrs.variant.toml /* proj-data { - "generated_from": "f66f3a89ea937e96a058d83ab52e2826" + "generated_from": "f4d932a4a7728ebfc28a23f2e6ca3201" } */ @@ -15,8 +15,9 @@ #include "pcg/initializers/constant_initializer_attrs.dtg.h" #include "pcg/initializers/glorot_uniform_attrs.dtg.h" #include "pcg/initializers/norm_initializer_attrs.dtg.h" -#include "pcg/initializers/uniform_initializer_attrs.dtg.h" +#include "pcg/initializers/uniform_initializer_attrs.h" #include "pcg/initializers/zero_initializer_attrs.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -161,6 +162,12 @@ struct adl_serializer<::FlexFlow::InitializerAttrs> { static void to_json(json &, ::FlexFlow::InitializerAttrs const &); }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::InitializerAttrs> { + static Gen<::FlexFlow::InitializerAttrs> arbitrary(); +}; +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::InitializerAttrs const &); std::ostream &operator<<(std::ostream &, ::FlexFlow::InitializerAttrs const &); diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 14a5cfdcac..1ea9ce05a6 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -6,12 +6,13 @@ features = [ "hash", "json", "fmt", + "rapidcheck", ] includes = [ "pcg/initializers/glorot_uniform_attrs.dtg.h", "pcg/initializers/zero_initializer_attrs.dtg.h", - "pcg/initializers/uniform_initializer_attrs.dtg.h", + "pcg/initializers/uniform_initializer_attrs.h", "pcg/initializers/norm_initializer_attrs.dtg.h", "pcg/initializers/constant_initializer_attrs.dtg.h", ] diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h index 1512cb8e18..18876046b2 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml /* proj-data { - "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" + "generated_from": "4ffc8ccd7dfdb7674556487433ea9913" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/datatype.h" +#include "rapidcheck.h" #include "utils/json.h" #include #include @@ -48,6 +49,13 @@ struct adl_serializer<::FlexFlow::ConstantInitializerAttrs> { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ConstantInitializerAttrs> { + static Gen<::FlexFlow::ConstantInitializerAttrs> arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ConstantInitializerAttrs const &); std::ostream &operator<<(std::ostream &, ConstantInitializerAttrs const &); diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml index 3a80559d7b..511ec057fa 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h index 9493d2ffff..2ff17a9e54 100644 --- a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml /* proj-data { - "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" + "generated_from": "dd9cbe65dc4495b031aef40d353db928" } */ diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.h b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.h new file mode 100644 index 0000000000..a6d6a50a12 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_H + +#include "pcg/initializers/uniform_initializer_attrs.dtg.h" +#include + +namespace rc { + +template <> +struct Arbitrary<::FlexFlow::UniformInitializerAttrs> { + static Gen<::FlexFlow::UniformInitializerAttrs> arbitrary(); +}; + +} + +#endif diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml index 11a6597c0a..8ee67b9d4b 100644 --- a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml @@ -5,7 +5,6 @@ features = [ "ord", "hash", "json", - # "rapidcheck", "fmt", ] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 6dda689d35..340794405f 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -4,6 +4,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" namespace FlexFlow { @@ -12,6 +13,11 @@ ParallelComputationGraph empty_parallel_computation_graph(); std::unordered_set get_parallel_layers(ParallelComputationGraph const &); +ParallelLayerAddedResult add_parallel_layer(ParallelComputationGraph &pcg, + ParallelLayerAttrs const &layer_attrs, + std::vector const &inputs, + std::vector const &output_labels); + std::vector get_layer_inputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -27,6 +33,8 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); +std::vector topological_ordering(ParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h new file mode 100644 index 0000000000..8b59ab2b2f --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h @@ -0,0 +1,44 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "cb4fa8a3a6319d9b7de628a58d08bfed" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_ADDED_RESULT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_ADDED_RESULT_DTG_H + +#include "fmt/format.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "utils/fmt/vector.h" +#include +#include +#include + +namespace FlexFlow { +struct ParallelLayerAddedResult { + ParallelLayerAddedResult() = delete; + explicit ParallelLayerAddedResult( + ::FlexFlow::parallel_layer_guid_t const ¶llel_layer, + std::vector<::FlexFlow::parallel_tensor_guid_t> const &outputs); + + bool operator==(ParallelLayerAddedResult const &) const; + bool operator!=(ParallelLayerAddedResult const &) const; + bool operator<(ParallelLayerAddedResult const &) const; + bool operator>(ParallelLayerAddedResult const &) const; + bool operator<=(ParallelLayerAddedResult const &) const; + bool operator>=(ParallelLayerAddedResult const &) const; + ::FlexFlow::parallel_layer_guid_t parallel_layer; + std::vector<::FlexFlow::parallel_tensor_guid_t> outputs; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(ParallelLayerAddedResult const &); +std::ostream &operator<<(std::ostream &, ParallelLayerAddedResult const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_ADDED_RESULT_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml new file mode 100644 index 0000000000..f3113255ef --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelLayerAddedResult" + +features = [ + "eq", + "ord", + "fmt", +] + +includes = [ + "", + "utils/fmt/vector.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", +] + +[[fields]] +name = "parallel_layer" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::parallel_tensor_guid_t>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h index cf0011d4ba..8b23599f1d 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml /* proj-data { - "generated_from": "9bb6e3cb7b0e523fae8f33bd8ad80d6d" + "generated_from": "1b3a0491865fd43c79afcf4939b56fae" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/operator_attrs.h" +#include "rapidcheck.h" #include "utils/stack_string.h" #include #include @@ -52,6 +53,13 @@ struct adl_serializer<::FlexFlow::ParallelLayerAttrs> { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ParallelLayerAttrs> { + static Gen<::FlexFlow::ParallelLayerAttrs> arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelLayerAttrs const &); std::ostream &operator<<(std::ostream &, ParallelLayerAttrs const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index f3f3c6a8bb..1ba9ac5487 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h index a9dfb1d163..c6baa1e138 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml /* proj-data { - "generated_from": "b3e086b380bbc41d99332e1463a34b28" + "generated_from": "3d641c90950f49a7bef664d0153c97f6" } */ @@ -16,6 +16,7 @@ #include "op-attrs/param_sync.dtg.h" #include "pcg/create_grad.dtg.h" #include "pcg/initializer_attrs.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -58,6 +59,13 @@ struct adl_serializer<::FlexFlow::ParallelTensorAttrs> { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ParallelTensorAttrs> { + static Gen<::FlexFlow::ParallelTensorAttrs> arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelTensorAttrs const &); std::ostream &operator<<(std::ostream &, ParallelTensorAttrs const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml index 1f81b56ec8..faf7159ad7 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/pcg/src/pcg/dataflow_graph/algorithms.cc b/lib/pcg/src/pcg/dataflow_graph/algorithms.cc new file mode 100644 index 0000000000..3ef04c95a3 --- /dev/null +++ b/lib/pcg/src/pcg/dataflow_graph/algorithms.cc @@ -0,0 +1 @@ +#include "pcg/dataflow_graph/algorithms.h" diff --git a/lib/pcg/src/pcg/initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializer_attrs.dtg.cc index 2a4e97db1e..44e1135869 100644 --- a/lib/pcg/src/pcg/initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializer_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/initializer_attrs.variant.toml /* proj-data { - "generated_from": "f66f3a89ea937e96a058d83ab52e2826" + "generated_from": "f4d932a4a7728ebfc28a23f2e6ca3201" } */ @@ -114,6 +114,22 @@ void adl_serializer<::FlexFlow::InitializerAttrs>::to_json( } } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::InitializerAttrs> + Arbitrary<::FlexFlow::InitializerAttrs>::arbitrary() { + return gen::oneOf( + gen::construct<::FlexFlow::InitializerAttrs>( + gen::arbitrary<::FlexFlow::GlorotUniformAttrs>()), + gen::construct<::FlexFlow::InitializerAttrs>( + gen::arbitrary<::FlexFlow::ZeroInitializerAttrs>()), + gen::construct<::FlexFlow::InitializerAttrs>( + gen::arbitrary<::FlexFlow::UniformInitializerAttrs>()), + gen::construct<::FlexFlow::InitializerAttrs>( + gen::arbitrary<::FlexFlow::NormInitializerAttrs>()), + gen::construct<::FlexFlow::InitializerAttrs>( + gen::arbitrary<::FlexFlow::ConstantInitializerAttrs>())); +} +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::InitializerAttrs const &x) { std::ostringstream oss; diff --git a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc index 2848d420b7..6c1ae1dfac 100644 --- a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml /* proj-data { - "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" + "generated_from": "4ffc8ccd7dfdb7674556487433ea9913" } */ @@ -67,6 +67,14 @@ void adl_serializer<::FlexFlow::ConstantInitializerAttrs>::to_json( } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ConstantInitializerAttrs> + Arbitrary<::FlexFlow::ConstantInitializerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ConstantInitializerAttrs>( + gen::arbitrary<::FlexFlow::DataTypeValue>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ConstantInitializerAttrs const &x) { std::ostringstream oss; diff --git a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.cc b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.cc new file mode 100644 index 0000000000..947892a554 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.cc @@ -0,0 +1,20 @@ +#include "pcg/initializers/uniform_initializer_attrs.h" + +namespace rc { + +using ::FlexFlow::UniformInitializerAttrs; + +Gen Arbitrary::arbitrary() { + return gen::map>([](std::tuple const &generated) { + auto [f1, f2, seed] = generated; + float minval = std::min(f1, f2); + float maxval = std::max(f1, f2); + return ::FlexFlow::UniformInitializerAttrs{ + seed, + minval, + maxval, + }; + }); +}; + +} diff --git a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc index 4eb3bdc015..b66544d4b3 100644 --- a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml /* proj-data { - "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" + "generated_from": "dd9cbe65dc4495b031aef40d353db928" } */ diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index daf2b67303..174ac07977 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,5 +1,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers.h" +#include "pcg/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -14,6 +15,21 @@ std::unordered_set [&](Node const &n) { return parallel_layer_guid_t{n}; }); } +ParallelLayerAddedResult add_parallel_layer(ParallelComputationGraph &pcg, + ParallelLayerAttrs const &layer_attrs, + std::vector const &inputs, + std::vector const &output_labels) { + std::vector unwrapped_inputs = transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); + OperatorAddedResult op_added = pcg.raw_graph.add_operator(layer_attrs, + unwrapped_inputs, + output_labels); + return ParallelLayerAddedResult{ + parallel_layer_guid_t{op_added.node}, + transform(op_added.outputs, + [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }), + }; +} + std::vector get_layer_inputs(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -46,4 +62,9 @@ ParallelTensorAttrs return pcg.raw_graph.at(t.raw_graph_output); } +std::vector topological_ordering(ParallelComputationGraph const &pcg) { + return transform(topological_ordering(pcg.raw_graph), + [](Node const &n) { return parallel_layer_guid_t{n}; }); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc new file mode 100644 index 0000000000..7b2dbf8de1 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "cb4fa8a3a6319d9b7de628a58d08bfed" +} +*/ + +#include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "utils/fmt/vector.h" +#include +#include + +namespace FlexFlow { +ParallelLayerAddedResult::ParallelLayerAddedResult( + ::FlexFlow::parallel_layer_guid_t const ¶llel_layer, + std::vector<::FlexFlow::parallel_tensor_guid_t> const &outputs) + : parallel_layer(parallel_layer), outputs(outputs) {} +bool ParallelLayerAddedResult::operator==( + ParallelLayerAddedResult const &other) const { + return std::tie(this->parallel_layer, this->outputs) == + std::tie(other.parallel_layer, other.outputs); +} +bool ParallelLayerAddedResult::operator!=( + ParallelLayerAddedResult const &other) const { + return std::tie(this->parallel_layer, this->outputs) != + std::tie(other.parallel_layer, other.outputs); +} +bool ParallelLayerAddedResult::operator<( + ParallelLayerAddedResult const &other) const { + return std::tie(this->parallel_layer, this->outputs) < + std::tie(other.parallel_layer, other.outputs); +} +bool ParallelLayerAddedResult::operator>( + ParallelLayerAddedResult const &other) const { + return std::tie(this->parallel_layer, this->outputs) > + std::tie(other.parallel_layer, other.outputs); +} +bool ParallelLayerAddedResult::operator<=( + ParallelLayerAddedResult const &other) const { + return std::tie(this->parallel_layer, this->outputs) <= + std::tie(other.parallel_layer, other.outputs); +} +bool ParallelLayerAddedResult::operator>=( + ParallelLayerAddedResult const &other) const { + return std::tie(this->parallel_layer, this->outputs) >= + std::tie(other.parallel_layer, other.outputs); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(ParallelLayerAddedResult const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelLayerAddedResult const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc index a16998c698..5a982b13ab 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml /* proj-data { - "generated_from": "9bb6e3cb7b0e523fae8f33bd8ad80d6d" + "generated_from": "1b3a0491865fd43c79afcf4939b56fae" } */ @@ -74,6 +74,15 @@ void adl_serializer<::FlexFlow::ParallelLayerAttrs>::to_json( } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ParallelLayerAttrs> + Arbitrary<::FlexFlow::ParallelLayerAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ParallelLayerAttrs>( + gen::arbitrary<::FlexFlow::PCGOperatorAttrs>(), + gen::arbitrary>>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelLayerAttrs const &x) { std::ostringstream oss; diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc index 13be5e839f..88f7ed4d3c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml /* proj-data { - "generated_from": "b3e086b380bbc41d99332e1463a34b28" + "generated_from": "3d641c90950f49a7bef664d0153c97f6" } */ @@ -117,6 +117,17 @@ void adl_serializer<::FlexFlow::ParallelTensorAttrs>::to_json( } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ParallelTensorAttrs> + Arbitrary<::FlexFlow::ParallelTensorAttrs>::arbitrary() { + return gen::construct<::FlexFlow::ParallelTensorAttrs>( + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary>(), + gen::arbitrary>(), + gen::arbitrary<::FlexFlow::CreateGrad>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelTensorAttrs const &x) { std::ostringstream oss; diff --git a/lib/pcg/test/src/pcg/dataflow_graph.cc b/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc similarity index 59% rename from lib/pcg/test/src/pcg/dataflow_graph.cc rename to lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc index 0b4b31512b..7032133cdb 100644 --- a/lib/pcg/test/src/pcg/dataflow_graph.cc +++ b/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc @@ -1,9 +1,9 @@ -#include "pcg/dataflow_graph/dataflow_graph.h" +#include "pcg/dataflow_graph/algorithms.h" #include "test/utils/doctest.h" #include "utils/fmt/unordered_set.h" TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("DataflowGraph") { + TEST_CASE("get_inputs/get_outputs") { DataflowGraph g; int n1_label = 1; @@ -45,4 +45,32 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("topological_ordering") { + DataflowGraph g; + + int n1_label = 1; + int n2_label = 2; + int n3_label = 3; + + std::string o1_label = "o1"; + std::string o2_label = "o2"; + std::string o3_label = "o3"; + + OperatorAddedResult n1_added = g.add_operator(n1_label, {}, {o1_label}); + Node n1 = n1_added.node; + MultiDiOutput o1 = get_only(n1_added.outputs); + + OperatorAddedResult n2_added = g.add_operator(n2_label, {o1}, {o2_label}); + Node n2 = n2_added.node; + MultiDiOutput o2 = get_only(n2_added.outputs); + + OperatorAddedResult n3_added = g.add_operator(n3_label, {o2}, {o3_label}); + Node n3 = n3_added.node; + MultiDiOutput o3 = get_only(n3_added.outputs); + + std::vector result = topological_ordering(g); + std::vector correct = { n1, n2, n3 }; + CHECK(result == correct); + } } diff --git a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc new file mode 100644 index 0000000000..dd272cab3c --- /dev/null +++ b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc @@ -0,0 +1,11 @@ +#include "test/utils/doctest.h" +#include "test/utils/rapidcheck.h" +#include "pcg/initializers/uniform_initializer_attrs.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Arbitrary") { + rc::dc_check("arbitrary generates valid", [](UniformInitializerAttrs const &attrs) { + RC_ASSERT(attrs.max_val >= attrs.min_val); + }); + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc new file mode 100644 index 0000000000..25d7e3afe7 --- /dev/null +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -0,0 +1,31 @@ +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "test/utils/rapidcheck.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("topological_ordering") { + // TODO(@lockshaw) should probably be replaced with a rapidcheck test that compares + // ParallelComputationGraph to DataflowGraph, but since we currently don't have rapidcheck + // generation for DataflowGraph this will have to do for now + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAttrs layer_label = some(); + ParallelTensorAttrs tensor_label = some(); + + ParallelLayerAddedResult layer1_added = add_parallel_layer(pcg, layer_label, {}, {tensor_label}); + parallel_layer_guid_t layer1 = layer1_added.parallel_layer; + parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + + ParallelLayerAddedResult layer2_added = add_parallel_layer(pcg, layer_label, {tensor1}, {tensor_label}); + parallel_layer_guid_t layer2 = layer2_added.parallel_layer; + parallel_tensor_guid_t tensor2 = get_only(layer2_added.outputs); + + ParallelLayerAddedResult layer3_added = add_parallel_layer(pcg, layer_label, {tensor2}, {tensor_label}); + parallel_layer_guid_t layer3 = layer3_added.parallel_layer; + parallel_tensor_guid_t tensor3 = get_only(layer3_added.outputs); + + std::vector result = topological_ordering(pcg); + std::vector correct = { layer1, layer2, layer3 }; + CHECK(result == correct); + } +} diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 87b42a90d2..4114b7a936 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -266,8 +266,6 @@ std::vector get_bfs_ordering(DiGraphView const &, std::unordered_set const &starting_points); std::vector get_topological_ordering(DiGraphView const &); -// std::vector get_topological_ordering(MultiDiGraphView const &); -// std::vector get_topological_ordering(OpenMultiDiGraphView const &); std::vector get_unchecked_topological_ordering(DiGraphView const &); std::vector get_edge_topological_ordering(DiGraphView const &); diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 0074877768..884b840afb 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -8,6 +8,7 @@ #include "utils/type_traits.h" #include #include +#include namespace FlexFlow { @@ -17,6 +18,10 @@ struct stack_basic_string { stack_basic_string(Char const *c) : contents(c, c + std::strlen(c)) {} + template + stack_basic_string(Iterator start, Iterator end) + : contents(start, end) { } + stack_basic_string(std::basic_string const &s) : stack_basic_string(s.c_str()) {} @@ -92,6 +97,20 @@ struct hash<::FlexFlow::stack_basic_string> { } // namespace std +namespace rc { + +template +struct Arbitrary<::FlexFlow::stack_basic_string> { + static Gen<::FlexFlow::stack_basic_string> arbitrary() { + return gen::mapcat(gen::inRange(0, MAXSIZE), [](size_t size) { + return gen::container<::FlexFlow::stack_basic_string>( + size, gen::arbitrary()); + }); + } +}; + +} // namespace rc + namespace FlexFlow { static_assert(is_default_constructible>::value, diff --git a/lib/utils/test/common/include/test/utils/all.h b/lib/utils/test/common/include/test/utils/all.h index 308b58e630..ced1c9ce38 100644 --- a/lib/utils/test/common/include/test/utils/all.h +++ b/lib/utils/test/common/include/test/utils/all.h @@ -1,5 +1,2 @@ -#include "doctest.h" -#include "doctest/doctest.h" -#include "rapidcheck/doctest.h" -#include "rapidcheck/gen.h" -#include "rapidcheck/visitable.h" +#include "test/utils/doctest.h" +#include "test/utils/rapidcheck.h" diff --git a/lib/utils/test/common/include/test/utils/rapidcheck.h b/lib/utils/test/common/include/test/utils/rapidcheck.h new file mode 100644 index 0000000000..6f1ff5ad87 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/rapidcheck.h @@ -0,0 +1,4 @@ +#include "rapidcheck/doctest.h" +#include "rapidcheck/gen.h" +#include "rapidcheck/visitable.h" +#include "rapidcheck/some.h" diff --git a/lib/utils/test/common/include/test/utils/rapidcheck/some.h b/lib/utils/test/common/include/test/utils/rapidcheck/some.h new file mode 100644 index 0000000000..3db5e35052 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/rapidcheck/some.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_RAPIDCHECK_SOME_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_RAPIDCHECK_SOME_H + +#include + +namespace FlexFlow { + +template +T some() { + rc::Random r{}; + return rc::gen::arbitrary()(r).value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index 1836e0824a..308209b37b 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -1,4 +1,5 @@ #include "test/utils/doctest.h" +#include "test/utils/rapidcheck.h" #include "utils/stack_string.h" using namespace FlexFlow; @@ -80,4 +81,9 @@ TEST_SUITE(FF_TEST_SUITE) { std::string stdStr = static_cast(str); CHECK(stdStr == "Hello"); } + + TEST_CASE("Arbitrary") { + constexpr std::size_t MAXSIZE = 10; + rc::dc_check("arbitrary returns valid", [](stack_string const &s) {} ); + } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 141cd30e95..b70bfce553 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,7 +1,7 @@ #include "test/utils/doctest.h" +#include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" #include -#include using namespace FlexFlow; @@ -78,10 +78,10 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.back() == 20); } - TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { + TEST_CASE_TEMPLATE("Arbitrary", T, int, double, char) { constexpr std::size_t MAXSIZE = 10; - CHECK(rc::check("within bound", [](stack_vector v) { - return v.size() <= MAXSIZE; - })); + rc::dc_check("within bound" , [&](stack_vector v) { + RC_ASSERT(v.size() <= MAXSIZE); + }); } } From 42c1968fa6ab2f55fb22843befb239832cfd9200 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 18 Jun 2024 21:28:21 -0700 Subject: [PATCH 08/71] Pre multidigraph refactor --- flake.lock | 6 +- lib/utils/include/utils/graph/algorithms.h | 86 +------- .../graph/dataflow_graph/dataflow_edge.dtg.h | 49 +++++ .../dataflow_graph/dataflow_edge.struct.toml | 21 ++ .../dataflow_graph/dataflow_edge_query.dtg.h | 54 +++++ .../dataflow_edge_query.struct.toml | 29 +++ .../graph/dataflow_graph/dataflow_graph.h | 21 ++ .../dataflow_graph/dataflow_graph_view.h | 31 +++ .../graph/dataflow_graph/dataflow_input.dtg.h | 46 ++++ .../dataflow_graph/dataflow_input.struct.toml | 16 ++ .../dataflow_graph/dataflow_output.dtg.h | 47 ++++ .../dataflow_output.struct.toml | 20 ++ .../dataflow_output_query.dtg.h | 50 +++++ .../dataflow_output_query.struct.toml | 21 ++ .../graph/dataflow_graph/i_dataflow_graph.h | 18 ++ .../dataflow_graph/i_dataflow_graph_view.h | 24 ++ .../dataflow_graph/node_added_result.dtg.h | 45 ++++ .../node_added_result.struct.toml | 24 ++ lib/utils/include/utils/graph/diedge.h | 41 ---- .../graph/{ => digraph}/adjacency_digraph.h | 2 +- .../utils/graph/digraph/di_input.dtg.h | 46 ++++ .../utils/graph/digraph/di_input.struct.toml | 16 ++ .../utils/graph/digraph/di_output.dtg.h | 39 ++++ .../utils/graph/digraph/di_output.struct.toml | 15 ++ .../utils/graph/{ => digraph}/digraph.h | 40 +--- .../utils/graph/digraph/digraph_view.h | 41 ++++ .../utils/graph/digraph/directed_edge.dtg.h | 48 ++++ .../graph/digraph/directed_edge.struct.toml | 20 ++ .../graph/digraph/directed_edge_query.dtg.h | 42 ++++ .../utils/graph/digraph/directed_edge_query.h | 17 ++ .../digraph/directed_edge_query.struct.toml | 20 ++ .../include/utils/graph/digraph/i_digraph.h | 21 ++ .../utils/graph/digraph/i_digraph_view.h | 28 +++ .../include/utils/graph/digraph_interfaces.h | 37 ---- .../downward_open_multi_di_edge.dtg.h | 116 ++++++++++ .../downward_open_multi_di_edge.variant.toml | 21 ++ .../downward_open_multi_di_edge_query.dtg.h | 50 +++++ ...nward_open_multi_di_edge_query.struct.toml | 21 ++ .../downward_open_multidigraph.h | 48 ++++ .../downward_open_multidigraph_view.h | 43 ++++ .../i_downward_open_multidigraph.h | 22 ++ .../i_downward_open_multidigraph_view.h | 23 ++ lib/utils/include/utils/graph/multidiedge.h | 101 --------- lib/utils/include/utils/graph/multidigraph.h | 79 ------- .../adjacency_multidigraph.h | 5 +- .../utils/graph/multidigraph/i_multidigraph.h | 26 +++ .../graph/multidigraph/i_multidigraph_view.h | 24 ++ .../graph/multidigraph/multi_di_edge.dtg.h | 50 +++++ .../multidigraph/multi_di_edge.struct.toml | 24 ++ .../multidigraph/multi_di_edge_query.dtg.h | 50 +++++ .../graph/multidigraph/multi_di_edge_query.h | 17 ++ .../multi_di_edge_query.struct.toml | 21 ++ .../utils/graph/multidigraph/multidigraph.h | 47 ++++ .../graph/multidigraph/multidigraph_view.h | 40 ++++ .../utils/graph/multidigraph_interfaces.h | 45 ---- lib/utils/include/utils/graph/node.h | 110 ---------- lib/utils/include/utils/graph/node/graph.h | 41 ++++ .../include/utils/graph/node/graph_view.h | 31 +++ lib/utils/include/utils/graph/node/i_graph.h | 23 ++ .../include/utils/graph/node/i_graph_view.h | 21 ++ lib/utils/include/utils/graph/node/node.dtg.h | 46 ++++ .../include/utils/graph/node/node.struct.toml | 16 ++ .../include/utils/graph/node/node_query.dtg.h | 47 ++++ .../include/utils/graph/node/node_query.h | 15 ++ .../utils/graph/node/node_query.struct.toml | 17 ++ lib/utils/include/utils/graph/node_port.h | 22 -- lib/utils/include/utils/graph/open_edge.h | 73 ------- .../utils/graph/open_graph_interfaces.h | 82 ------- lib/utils/include/utils/graph/open_graphs.h | 206 ------------------ .../adjacency_openmultidigraph.h | 4 +- .../open_multidigraph/i_open_multidigraph.h | 20 ++ .../i_open_multidigraph_view.h | 20 ++ .../input_multi_di_edge.dtg.h | 48 ++++ .../open_multidigraph/input_multi_di_edge.h | 13 ++ .../input_multi_di_edge.struct.toml | 21 ++ .../input_multi_di_edge_query.dtg.h | 48 ++++ .../input_multi_di_edge_query.h | 13 ++ .../input_multi_di_edge_query.struct.toml | 17 ++ .../open_multi_di_edge.dtg.h | 126 +++++++++++ .../open_multidigraph/open_multi_di_edge.h | 14 ++ .../open_multi_di_edge.variant.toml | 26 +++ .../open_multi_di_edge_query.dtg.h | 53 +++++ .../open_multi_di_edge_query.h | 12 + .../open_multi_di_edge_query.struct.toml | 26 +++ .../open_multidigraph/open_multidigraph.h | 46 ++++ .../open_multidigraph_view.h | 41 ++++ .../output_multi_di_edge.dtg.h | 49 +++++ .../open_multidigraph/output_multi_di_edge.h | 13 ++ .../output_multi_di_edge.struct.toml | 21 ++ .../output_multi_di_edge_query.dtg.h | 48 ++++ .../output_multi_di_edge_query.h | 13 ++ .../output_multi_di_edge_query.struct.toml | 17 ++ lib/utils/include/utils/graph/query_set.h | 4 + lib/utils/include/utils/graph/traversal.h | 4 +- lib/utils/include/utils/graph/undirected.h | 113 ---------- .../hashmap_undirected_graph.h | 2 +- .../graph/undirected/i_undirected_graph.h | 23 ++ .../undirected/i_undirected_graph_view.h | 30 +++ .../utils/graph/undirected/undirected_edge.h | 33 +++ .../undirected/undirected_edge_query.dtg.h | 48 ++++ .../graph/undirected/undirected_edge_query.h | 15 ++ .../undirected_edge_query.struct.toml | 17 ++ .../utils/graph/undirected/undirected_graph.h | 47 ++++ .../graph/undirected/undirected_graph_view.h | 41 ++++ .../include/utils/graph/undirected_edge.h | 35 --- .../i_upward_open_multidigraph.h | 20 ++ .../i_upward_open_multidigraph_view.h | 21 ++ .../upward_open_multi_di_edge.dtg.h | 114 ++++++++++ .../upward_open_multi_di_edge.h | 13 ++ .../upward_open_multi_di_edge.variant.toml | 21 ++ .../upward_open_multi_di_edge_query.dtg.h | 50 +++++ ...pward_open_multi_di_edge_query.struct.toml | 21 ++ .../upward_open_multidigraph.h | 46 ++++ .../upward_open_multidigraph_view.h | 42 ++++ .../utils/graph/views/join_node_key.dtg.h | 49 +++++ .../graph/views/join_node_key.struct.toml | 21 ++ .../utils/graph/views/lr_direction.dtg.h | 40 ++++ .../utils/graph/views/lr_direction.enum.toml | 14 ++ .../include/utils/graph/{ => views}/views.h | 31 +-- lib/utils/src/graph/multidiedge.cc | 155 ------------- lib/utils/src/graph/node.cc | 63 ------ lib/utils/src/graph/open_edge.cc | 75 ------- lib/utils/src/graph/undirected_edge.cc | 23 -- lib/utils/src/{ => utils}/graph/algorithms.cc | 0 .../graph/dataflow_graph/dataflow_edge.dtg.cc | 64 ++++++ .../dataflow_graph/dataflow_edge_query.dtg.cc | 100 +++++++++ .../graph/dataflow_graph/dataflow_graph.cc | 10 + .../dataflow_graph/dataflow_graph_view.cc | 21 ++ .../dataflow_graph/dataflow_input.dtg.cc | 61 ++++++ .../dataflow_graph/dataflow_output.dtg.cc | 62 ++++++ .../dataflow_output_query.dtg.cc | 71 ++++++ .../dataflow_graph/i_dataflow_graph_view.cc | 20 ++ .../dataflow_graph/node_added_result.dtg.cc | 62 ++++++ .../graph/digraph}/adjacency_digraph.cc | 4 +- .../src/utils/graph/digraph/di_input.dtg.cc | 57 +++++ .../src/utils/graph/digraph/di_output.dtg.cc | 44 ++++ .../{graph => utils/graph/digraph}/digraph.cc | 18 +- .../src/utils/graph/digraph/digraph_view.cc | 18 ++ .../utils/graph/digraph/directed_edge.dtg.cc | 63 ++++++ .../graph/digraph/directed_edge_query.dtg.cc | 53 +++++ .../directed_graph/directed_edge_query.cc} | 8 +- .../downward_open_multi_di_edge.dtg.cc | 80 +++++++ .../downward_open_multi_di_edge_query.dtg.cc | 80 +++++++ .../i_downward_open_multidigraph.cc | 1 + .../i_downward_open_multidigraph_view.cc | 13 ++ .../src/{ => utils}/graph/labelled_graphs.cc | 0 .../multidigraph}/adjacency_multidigraph.cc | 2 +- .../graph/multidigraph/i_multidigraph.cc | 1 + .../graph/multidigraph/i_multidigraph_view.cc | 15 ++ .../graph/multidigraph/multi_di_edge.dtg.cc | 73 +++++++ .../graph/multidigraph/multi_di_edge_query.cc | 46 ++++ .../multidigraph/multi_di_edge_query.dtg.cc | 65 ++++++ .../graph/multidigraph}/multidigraph.cc | 21 +- lib/utils/src/utils/graph/node/graph.cc | 30 +++ lib/utils/src/utils/graph/node/graph_view.cc | 15 ++ .../src/utils/graph/node/i_graph_view.cc | 1 + lib/utils/src/utils/graph/node/node.dtg.cc | 57 +++++ lib/utils/src/utils/graph/node/node_query.cc | 31 +++ .../src/utils/graph/node/node_query.dtg.cc | 60 +++++ .../open_multidigraph/input_multi_di_edge.cc | 9 + .../input_multi_di_edge.dtg.cc | 70 ++++++ .../input_multi_di_edge_query.cc | 13 ++ .../input_multi_di_edge_query.dtg.cc | 67 ++++++ .../open_multidigraph/open_multi_di_edge.cc | 18 ++ .../open_multi_di_edge.dtg.cc | 79 +++++++ .../open_multi_di_edge_query.cc | 16 ++ .../open_multi_di_edge_query.dtg.cc | 101 +++++++++ .../open_multidigraph/output_multi_di_edge.cc | 9 + .../output_multi_di_edge.dtg.cc | 70 ++++++ .../output_multi_di_edge_query.cc | 13 ++ .../output_multi_di_edge_query.dtg.cc | 67 ++++++ .../src/{ => utils}/graph/serialparallel.cc | 0 .../graph/serialparallel_internal.h | 0 lib/utils/src/{ => utils}/graph/traversal.cc | 0 .../undirected}/hashmap_undirected_graph.cc | 4 +- .../utils/graph/undirected/undirected_edge.cc | 40 ++++ .../graph/undirected/undirected_edge_query.cc | 16 ++ .../undirected/undirected_edge_query.dtg.cc | 61 ++++++ .../graph/undirected/undirected_graph.cc} | 20 +- .../graph/undirected/undirected_graph_view.cc | 20 ++ .../i_upward_open_multidigraph.cc | 11 + .../i_upward_open_multidigraph_view.cc | 16 ++ .../upward_open_multi_di_edge.cc | 13 ++ .../upward_open_multi_di_edge.dtg.cc | 79 +++++++ .../upward_open_multi_di_edge_query.dtg.cc | 78 +++++++ lib/utils/src/{ => utils}/graph/views.cc | 124 ++++++----- .../utils/graph/views/join_node_key.dtg.cc | 70 ++++++ .../src/utils/graph/views/lr_direction.dtg.cc | 70 ++++++ 188 files changed, 5592 insertions(+), 1533 deletions(-) create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml delete mode 100644 lib/utils/include/utils/graph/diedge.h rename lib/utils/include/utils/graph/{ => digraph}/adjacency_digraph.h (96%) create mode 100644 lib/utils/include/utils/graph/digraph/di_input.dtg.h create mode 100644 lib/utils/include/utils/graph/digraph/di_input.struct.toml create mode 100644 lib/utils/include/utils/graph/digraph/di_output.dtg.h create mode 100644 lib/utils/include/utils/graph/digraph/di_output.struct.toml rename lib/utils/include/utils/graph/{ => digraph}/digraph.h (51%) create mode 100644 lib/utils/include/utils/graph/digraph/digraph_view.h create mode 100644 lib/utils/include/utils/graph/digraph/directed_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/digraph/directed_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/digraph/directed_edge_query.h create mode 100644 lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/digraph/i_digraph.h create mode 100644 lib/utils/include/utils/graph/digraph/i_digraph_view.h delete mode 100644 lib/utils/include/utils/graph/digraph_interfaces.h create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph.h create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h create mode 100644 lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h delete mode 100644 lib/utils/include/utils/graph/multidiedge.h delete mode 100644 lib/utils/include/utils/graph/multidigraph.h rename lib/utils/include/utils/graph/{ => multidigraph}/adjacency_multidigraph.h (92%) create mode 100644 lib/utils/include/utils/graph/multidigraph/i_multidigraph.h create mode 100644 lib/utils/include/utils/graph/multidigraph/i_multidigraph_view.h create mode 100644 lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.h create mode 100644 lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/multidigraph/multidigraph.h create mode 100644 lib/utils/include/utils/graph/multidigraph/multidigraph_view.h delete mode 100644 lib/utils/include/utils/graph/multidigraph_interfaces.h delete mode 100644 lib/utils/include/utils/graph/node.h create mode 100644 lib/utils/include/utils/graph/node/graph.h create mode 100644 lib/utils/include/utils/graph/node/graph_view.h create mode 100644 lib/utils/include/utils/graph/node/i_graph.h create mode 100644 lib/utils/include/utils/graph/node/i_graph_view.h create mode 100644 lib/utils/include/utils/graph/node/node.dtg.h create mode 100644 lib/utils/include/utils/graph/node/node.struct.toml create mode 100644 lib/utils/include/utils/graph/node/node_query.dtg.h create mode 100644 lib/utils/include/utils/graph/node/node_query.h create mode 100644 lib/utils/include/utils/graph/node/node_query.struct.toml delete mode 100644 lib/utils/include/utils/graph/node_port.h delete mode 100644 lib/utils/include/utils/graph/open_edge.h delete mode 100644 lib/utils/include/utils/graph/open_graph_interfaces.h delete mode 100644 lib/utils/include/utils/graph/open_graphs.h rename lib/utils/include/utils/graph/{ => open_multidigraph}/adjacency_openmultidigraph.h (94%) create mode 100644 lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph_view.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multidigraph.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/open_multidigraph_view.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.h create mode 100644 lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml delete mode 100644 lib/utils/include/utils/graph/undirected.h rename lib/utils/include/utils/graph/{ => undirected}/hashmap_undirected_graph.h (96%) create mode 100644 lib/utils/include/utils/graph/undirected/i_undirected_graph.h create mode 100644 lib/utils/include/utils/graph/undirected/i_undirected_graph_view.h create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge.h create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge_query.h create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/undirected/undirected_graph.h create mode 100644 lib/utils/include/utils/graph/undirected/undirected_graph_view.h delete mode 100644 lib/utils/include/utils/graph/undirected_edge.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph.h create mode 100644 lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h create mode 100644 lib/utils/include/utils/graph/views/join_node_key.dtg.h create mode 100644 lib/utils/include/utils/graph/views/join_node_key.struct.toml create mode 100644 lib/utils/include/utils/graph/views/lr_direction.dtg.h create mode 100644 lib/utils/include/utils/graph/views/lr_direction.enum.toml rename lib/utils/include/utils/graph/{ => views}/views.h (95%) delete mode 100644 lib/utils/src/graph/multidiedge.cc delete mode 100644 lib/utils/src/graph/node.cc delete mode 100644 lib/utils/src/graph/open_edge.cc delete mode 100644 lib/utils/src/graph/undirected_edge.cc rename lib/utils/src/{ => utils}/graph/algorithms.cc (100%) create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc rename lib/utils/src/{graph => utils/graph/digraph}/adjacency_digraph.cc (92%) create mode 100644 lib/utils/src/utils/graph/digraph/di_input.dtg.cc create mode 100644 lib/utils/src/utils/graph/digraph/di_output.dtg.cc rename lib/utils/src/{graph => utils/graph/digraph}/digraph.cc (66%) create mode 100644 lib/utils/src/utils/graph/digraph/digraph_view.cc create mode 100644 lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc rename lib/utils/src/{graph/diedge.cc => utils/graph/directed_graph/directed_edge_query.cc} (83%) create mode 100644 lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc create mode 100644 lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc rename lib/utils/src/{ => utils}/graph/labelled_graphs.cc (100%) rename lib/utils/src/{graph => utils/graph/multidigraph}/adjacency_multidigraph.cc (97%) create mode 100644 lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc create mode 100644 lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc create mode 100644 lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc create mode 100644 lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc rename lib/utils/src/{graph => utils/graph/multidigraph}/multidigraph.cc (70%) create mode 100644 lib/utils/src/utils/graph/node/graph.cc create mode 100644 lib/utils/src/utils/graph/node/graph_view.cc create mode 100644 lib/utils/src/utils/graph/node/i_graph_view.cc create mode 100644 lib/utils/src/utils/graph/node/node.dtg.cc create mode 100644 lib/utils/src/utils/graph/node/node_query.cc create mode 100644 lib/utils/src/utils/graph/node/node_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc create mode 100644 lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc rename lib/utils/src/{ => utils}/graph/serialparallel.cc (100%) rename lib/utils/src/{ => utils}/graph/serialparallel_internal.h (100%) rename lib/utils/src/{ => utils}/graph/traversal.cc (100%) rename lib/utils/src/{graph => utils/graph/undirected}/hashmap_undirected_graph.cc (93%) create mode 100644 lib/utils/src/utils/graph/undirected/undirected_edge.cc create mode 100644 lib/utils/src/utils/graph/undirected/undirected_edge_query.cc create mode 100644 lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc rename lib/utils/src/{graph/undirected.cc => utils/graph/undirected/undirected_graph.cc} (66%) create mode 100644 lib/utils/src/utils/graph/undirected/undirected_graph_view.cc create mode 100644 lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc create mode 100644 lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc create mode 100644 lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc create mode 100644 lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc rename lib/utils/src/{ => utils}/graph/views.cc (80%) create mode 100644 lib/utils/src/utils/graph/views/join_node_key.dtg.cc create mode 100644 lib/utils/src/utils/graph/views/lr_direction.dtg.cc diff --git a/flake.lock b/flake.lock index dde0c989c3..3a9fffbdd1 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1717990636, - "narHash": "sha256-wqIc2qAkRfVp2d+NAVIYPKMx7YYpu8iBGHHT1U5sxhE=", + "lastModified": 1718643207, + "narHash": "sha256-VhPjZi4Zl4XgaagzqI0Z2bgFoJhF2SblwUq4eZR08DU=", "owner": "lockshaw", "repo": "proj", - "rev": "f7e20a9c232dda1b945a775d91e1ed4f525b5f51", + "rev": "5dc9d970f0fe67e65146b2ba1d7aa44d11324d48", "type": "github" }, "original": { diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 4114b7a936..52a1f71d31 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -1,21 +1,14 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H #define _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H -#include "digraph.h" -#include "labelled_graphs.h" -#include "multidigraph.h" -#include "node.h" -#include "open_graphs.h" -#include "undirected.h" -#include "utils/containers.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/node/graph.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/open_multidigraph/open_multidigraph.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h" +#include "utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h" #include "utils/dot_file.h" -#include "utils/exception.h" -#include "utils/graph/multidiedge.h" -#include "utils/graph/open_graph_interfaces.h" -#include "utils/optional.h" -#include "views.h" -#include -#include namespace FlexFlow { @@ -25,11 +18,7 @@ std::vector add_nodes(DiGraph &, int); std::vector add_nodes(MultiDiGraph &, int); std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); -std::vector add_node_ports(MultiDiGraph &, int); - std::unordered_set get_nodes(GraphView const &); -std::unordered_set get_present_node_ports(MultiDiGraphView const &); - std::unordered_set get_nodes(OpenMultiDiEdge const &); std::unordered_set query_nodes(GraphView const &, @@ -104,9 +93,6 @@ std::unordered_set get_edges(OpenMultiDiGraphView const &); std::unordered_set get_node_edges(UndirectedGraphView const &, Node const &); -std::unordered_set get_outputs(MultiDiGraphView const &); -std::unordered_set get_inputs(MultiDiGraphView const &); - std::unordered_set get_open_outputs(OpenMultiDiGraphView const &); std::unordered_set @@ -128,11 +114,6 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &, std::unordered_set get_incoming_edges(DiGraphView const &, std::unordered_set const &); -std::unordered_map> - get_incoming_edges_by_idx(MultiDiGraphView const &, Node const &); -std::unordered_map> - get_outgoing_edges_by_idx(MultiDiGraphView const &, Node const &); - std::unordered_set get_outgoing_edges(MultiDiGraphView const &, Node const &); std::unordered_set get_outgoing_edges(DiGraphView const &, @@ -165,59 +146,6 @@ Node get_dst_node(MultiDiEdge const &); Node get_dst_node(InputMultiDiEdge const &); Node get_src_node(OutputMultiDiEdge const &); -struct GetSrcNodeFunctor { - template - Node operator()(T const &t) const { - return get_src_node(t); - } -}; - -struct GetDstNodeFunctor { - template - Node operator()(T const &t) const { - return get_dst_node(t); - } -}; - -template -Node get_src_node(std::variant const &t) { - return visit(GetSrcNodeFunctor{}, t); -} - -template -Node get_dst_node(std::variant const &t) { - return visit(GetDstNodeFunctor{}, t); -} - -NodePort get_src_idx(MultiDiEdge const &); -NodePort get_dst_idx(MultiDiEdge const &); -NodePort get_dst_idx(InputMultiDiEdge const &); -NodePort get_src_idx(OutputMultiDiEdge const &); - -struct GetSrcIdxFunctor { - template - NodePort operator()(T const &t) const { - return get_src_idx(t); - } -}; - -struct GetDstIdxFunctor { - template - NodePort operator()(T const &t) const { - return get_dst_idx(t); - } -}; - -template -NodePort get_src_idx(std::variant const &t) { - return visit(GetSrcIdxFunctor{}, t); -} - -template -NodePort get_dst_idx(std::variant const &t) { - return visit(GetDstIdxFunctor{}, t); -} - std::unordered_set get_neighbors(UndirectedGraphView const &, Node const &); std::unordered_set get_neighbors(DiGraphView const &, Node const &); diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.h new file mode 100644 index 0000000000..6cad9ddb16 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml +/* proj-data +{ + "generated_from": "4728f139efc6884057f39e38f44a791b" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_input.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowEdge { + DataflowEdge() = delete; + explicit DataflowEdge(::FlexFlow::DataflowOutput const &src, + ::FlexFlow::DataflowInput const &dst); + + bool operator==(DataflowEdge const &) const; + bool operator!=(DataflowEdge const &) const; + bool operator<(DataflowEdge const &) const; + bool operator>(DataflowEdge const &) const; + bool operator<=(DataflowEdge const &) const; + bool operator>=(DataflowEdge const &) const; + ::FlexFlow::DataflowOutput src; + ::FlexFlow::DataflowInput dst; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowEdge> { + size_t operator()(::FlexFlow::DataflowEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowEdge const &); +std::ostream &operator<<(std::ostream &, DataflowEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml new file mode 100644 index 0000000000..a3237dde09 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::DataflowOutput" + +[[fields]] +name = "dst" +type = "::FlexFlow::DataflowInput" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h new file mode 100644 index 0000000000..49fdb24992 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml +/* proj-data +{ + "generated_from": "684726a7add4aa912e194335fcfe91ab" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowEdgeQuery { + DataflowEdgeQuery() = delete; + explicit DataflowEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &src_nodes, + ::FlexFlow::query_set const &src_idxs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dst_nodes, + ::FlexFlow::query_set const &dst_idxs); + + bool operator==(DataflowEdgeQuery const &) const; + bool operator!=(DataflowEdgeQuery const &) const; + bool operator<(DataflowEdgeQuery const &) const; + bool operator>(DataflowEdgeQuery const &) const; + bool operator<=(DataflowEdgeQuery const &) const; + bool operator>=(DataflowEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> src_nodes; + ::FlexFlow::query_set src_idxs; + ::FlexFlow::query_set<::FlexFlow::Node> dst_nodes; + ::FlexFlow::query_set dst_idxs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowEdgeQuery> { + size_t operator()(::FlexFlow::DataflowEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowEdgeQuery const &); +std::ostream &operator<<(std::ostream &, DataflowEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml new file mode 100644 index 0000000000..c941bbf985 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "DataflowEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node.dtg.h", +] + +[[fields]] +name = "src_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "src_idxs" +type = "::FlexFlow::query_set" + +[[fields]] +name = "dst_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dst_idxs" +type = "::FlexFlow::query_set" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h new file mode 100644 index 0000000000..f5f0e669b4 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" +#include "utils/graph/dataflow_graph/i_dataflow_graph.h" + +namespace FlexFlow { + +struct DataflowGraph : virtual DataflowGraphView { +public: + NodeAddedResult add_node(std::vector const &inputs, + int num_outputs); +private: + IDataflowGraph const &get_interface() const; + IDataflowGraph &get_interface(); +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h new file mode 100644 index 0000000000..bdbf204882 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" + +namespace FlexFlow { + +struct DataflowGraphView : virtual MultiDiGraphView { + DataflowGraphView(DataflowGraphView const &) = default; + DataflowGraphView &operator=(DataflowGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(DataflowEdgeQuery const &) const; + std::unordered_set query_outputs(DataflowOutputQuery const &) const; + + template + static typename std::enable_if::value, + DataflowGraphView>::type + create(Args &&...args) { + return DataflowGraphView(make_cow_ptr(std::forward(args)...)); + } +private: + IDataflowGraphView const &get_interface() const; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h new file mode 100644 index 0000000000..e98994ecf1 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml +/* proj-data +{ + "generated_from": "9fc7657f7fcc71fdad9e6a5040771ad7" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_INPUT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_INPUT_DTG_H + +#include "fmt/format.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowInput { + DataflowInput() = delete; + explicit DataflowInput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(DataflowInput const &) const; + bool operator!=(DataflowInput const &) const; + bool operator<(DataflowInput const &) const; + bool operator>(DataflowInput const &) const; + bool operator<=(DataflowInput const &) const; + bool operator>=(DataflowInput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowInput> { + size_t operator()(::FlexFlow::DataflowInput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowInput const &); +std::ostream &operator<<(std::ostream &, DataflowInput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_INPUT_DTG_H diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml new file mode 100644 index 0000000000..19da01ab9f --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DataflowInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h new file mode 100644 index 0000000000..4938220290 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml +/* proj-data +{ + "generated_from": "b704f2549a69ee6bfc1c5e28df421f9c" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowOutput { + DataflowOutput() = delete; + explicit DataflowOutput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(DataflowOutput const &) const; + bool operator!=(DataflowOutput const &) const; + bool operator<(DataflowOutput const &) const; + bool operator>(DataflowOutput const &) const; + bool operator<=(DataflowOutput const &) const; + bool operator>=(DataflowOutput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowOutput> { + size_t operator()(::FlexFlow::DataflowOutput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowOutput const &); +std::ostream &operator<<(std::ostream &, DataflowOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_DTG_H diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml new file mode 100644 index 0000000000..6f2ce25f2b --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "DataflowOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h new file mode 100644 index 0000000000..5a122c6d51 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h @@ -0,0 +1,50 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml +/* proj-data +{ + "generated_from": "6f662c3c4d285a4fd3c60713e6fc67fa" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowOutputQuery { + DataflowOutputQuery() = delete; + explicit DataflowOutputQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &nodes, + ::FlexFlow::query_set const &output_idxs); + + bool operator==(DataflowOutputQuery const &) const; + bool operator!=(DataflowOutputQuery const &) const; + bool operator<(DataflowOutputQuery const &) const; + bool operator>(DataflowOutputQuery const &) const; + bool operator<=(DataflowOutputQuery const &) const; + bool operator>=(DataflowOutputQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> nodes; + ::FlexFlow::query_set output_idxs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowOutputQuery> { + size_t operator()(::FlexFlow::DataflowOutputQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowOutputQuery const &); +std::ostream &operator<<(std::ostream &, DataflowOutputQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml new file mode 100644 index 0000000000..a61edbcdb0 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowOutputQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node.dtg.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "output_idxs" +type = "::FlexFlow::query_set" diff --git a/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h new file mode 100644 index 0000000000..94fd54802b --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" +#include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" + +namespace FlexFlow { + +struct IDataflowGraph : virtual public IDataflowGraphView { + virtual NodeAddedResult add_node(std::vector const &inputs, + int num_outputs) = 0; + virtual IDataflowGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h new file mode 100644 index 0000000000..5a1d29f9dc --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/multidigraph/i_multidigraph_view.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" + +namespace FlexFlow { + +struct IDataflowGraphView : virtual public IMultiDiGraphView { + virtual std::unordered_set query_edges(DataflowEdgeQuery const &) const = 0; + virtual std::unordered_set query_outputs(DataflowOutputQuery const &) const = 0; + + std::unordered_set query_edges(MultiDiEdgeQuery const &) const override final; + + virtual ~IDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h new file mode 100644 index 0000000000..2a8159576c --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml +/* proj-data +{ + "generated_from": "4536bb54376e2e221e0ff29347e81662" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_NODE_ADDED_RESULT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_NODE_ADDED_RESULT_DTG_H + +#include "fmt/format.h" +#include "utils/fmt/vector.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct NodeAddedResult { + NodeAddedResult() = delete; + explicit NodeAddedResult( + ::FlexFlow::Node const &node, + std::vector<::FlexFlow::DataflowOutput> const &outputs); + + bool operator==(NodeAddedResult const &) const; + bool operator!=(NodeAddedResult const &) const; + bool operator<(NodeAddedResult const &) const; + bool operator>(NodeAddedResult const &) const; + bool operator<=(NodeAddedResult const &) const; + bool operator>=(NodeAddedResult const &) const; + ::FlexFlow::Node node; + std::vector<::FlexFlow::DataflowOutput> outputs; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(NodeAddedResult const &); +std::ostream &operator<<(std::ostream &, NodeAddedResult const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_NODE_ADDED_RESULT_DTG_H diff --git a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml new file mode 100644 index 0000000000..515541eb71 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "NodeAddedResult" + +features = [ + "eq", + "ord", + "fmt", +] + +includes = [ + "", + "utils/graph/node.dtg.h", + "utils/graph/multidigraph/multi_di_edge.dtg.h", + "utils/fmt/vector.h", + "utils/graph/dataflow_graph/dataflow_output.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/diedge.h b/lib/utils/include/utils/graph/diedge.h deleted file mode 100644 index 75b5068271..0000000000 --- a/lib/utils/include/utils/graph/diedge.h +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_DIEDGE -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_DIEDGE - -#include "node.h" -#include "query_set.h" - -namespace FlexFlow { - -struct DiInput { - Node dst; -}; -FF_VISITABLE_STRUCT(DiInput, dst); -FF_VISIT_FMTABLE(DiInput); - -struct DiOutput { - Node src; -}; -FF_VISITABLE_STRUCT(DiOutput, src); -FF_VISIT_FMTABLE(DiOutput); - -struct DirectedEdge : DiInput, DiOutput {}; -FF_VISITABLE_STRUCT(DirectedEdge, src, dst); -FF_VISIT_FMTABLE(DirectedEdge); - -struct DirectedEdgeQuery { - query_set srcs; - query_set dsts; - - static DirectedEdgeQuery all(); -}; -FF_VISITABLE_STRUCT(DirectedEdgeQuery, srcs, dsts); -FF_VISIT_FMTABLE(DirectedEdgeQuery); - -bool matches_edge(DirectedEdgeQuery const &, DirectedEdge const &); - -DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &, - DirectedEdgeQuery const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/adjacency_digraph.h b/lib/utils/include/utils/graph/digraph/adjacency_digraph.h similarity index 96% rename from lib/utils/include/utils/graph/adjacency_digraph.h rename to lib/utils/include/utils/graph/digraph/adjacency_digraph.h index 6909821382..9a2e13a3a5 100644 --- a/lib/utils/include/utils/graph/adjacency_digraph.h +++ b/lib/utils/include/utils/graph/digraph/adjacency_digraph.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_ADJACENCY_DIGRAPH_H #define _FLEXFLOW_UTILS_GRAPH_ADJACENCY_DIGRAPH_H -#include "digraph.h" +#include "utils/graph/digraph/digraph.h" #include #include diff --git a/lib/utils/include/utils/graph/digraph/di_input.dtg.h b/lib/utils/include/utils/graph/digraph/di_input.dtg.h new file mode 100644 index 0000000000..ed65cfd7a9 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/di_input.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/di_input.struct.toml +/* proj-data +{ + "generated_from": "19ab2e465577ae9e7add8b73c63e671f" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_INPUT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_INPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DiInput { + DiInput() = delete; + explicit DiInput(::FlexFlow::Node const &dst); + + bool operator==(DiInput const &) const; + bool operator!=(DiInput const &) const; + bool operator<(DiInput const &) const; + bool operator>(DiInput const &) const; + bool operator<=(DiInput const &) const; + bool operator>=(DiInput const &) const; + ::FlexFlow::Node dst; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DiInput> { + size_t operator()(::FlexFlow::DiInput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DiInput const &); +std::ostream &operator<<(std::ostream &, DiInput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_INPUT_DTG_H diff --git a/lib/utils/include/utils/graph/digraph/di_input.struct.toml b/lib/utils/include/utils/graph/digraph/di_input.struct.toml new file mode 100644 index 0000000000..1bd11e069c --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/di_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DiInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "dst" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/di_output.dtg.h b/lib/utils/include/utils/graph/digraph/di_output.dtg.h new file mode 100644 index 0000000000..e88d929a5d --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/di_output.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/di_output.struct.toml +/* proj-data +{ + "generated_from": "a8f3fc2ad9e00f3c29a6dcd4658199ba" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_OUTPUT_DTG_H + +#include "utils/graph/node.dtg.h" +#include +#include + +namespace FlexFlow { +struct DiOutput { + DiOutput() = delete; + explicit DiOutput(::FlexFlow::Node const &src); + + bool operator==(DiOutput const &) const; + bool operator!=(DiOutput const &) const; + bool operator<(DiOutput const &) const; + bool operator>(DiOutput const &) const; + bool operator<=(DiOutput const &) const; + bool operator>=(DiOutput const &) const; + ::FlexFlow::Node src; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DiOutput> { + size_t operator()(::FlexFlow::DiOutput const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_OUTPUT_DTG_H diff --git a/lib/utils/include/utils/graph/digraph/di_output.struct.toml b/lib/utils/include/utils/graph/digraph/di_output.struct.toml new file mode 100644 index 0000000000..f678af132a --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/di_output.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DiOutput" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph/node.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h similarity index 51% rename from lib/utils/include/utils/graph/digraph.h rename to lib/utils/include/utils/graph/digraph/digraph.h index 7a385563ef..016b5ce513 100644 --- a/lib/utils/include/utils/graph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -1,44 +1,14 @@ #ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_DIGRAPH #define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_DIGRAPH -#include "cow_ptr_t.h" -#include "digraph_interfaces.h" -#include "node.h" -#include "utils/optional.h" -#include "utils/unique.h" -#include "utils/visitable.h" -#include +#include "utils/graph/cow_ptr_t.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/digraph/i_digraph.h" namespace FlexFlow { -struct DiGraphView : virtual public GraphView { -public: - using Edge = DirectedEdge; - using EdgeQuery = DirectedEdgeQuery; - - DiGraphView(DiGraphView const &) = default; - DiGraphView &operator=(DiGraphView const &) = default; - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if::value, - DiGraphView>::type - create(Args &&...args) { - return DiGraphView(make_cow_ptr(std::forward(args)...)); - } - -protected: - using GraphView::GraphView; - -private: - IDiGraphView const &get_ptr() const; - - friend struct GraphInternal; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); - struct DiGraph : virtual DiGraphView { public: using Edge = DirectedEdge; diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h new file mode 100644 index 0000000000..367dcfa4a8 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIGRAPH_VIEW_H + +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/node/graph_view.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" +#include "utils/graph/digraph/i_digraph_view.h" + +namespace FlexFlow { + +struct DiGraphView : virtual public GraphView { +public: + using Edge = DirectedEdge; + using EdgeQuery = DirectedEdgeQuery; + + DiGraphView(DiGraphView const &) = default; + DiGraphView &operator=(DiGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if::value, + DiGraphView>::type + create(Args &&...args) { + return DiGraphView(make_cow_ptr(std::forward(args)...)); + } + +protected: + using GraphView::GraphView; + +private: + IDiGraphView const &get_ptr() const; + + friend struct GraphInternal; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/directed_edge.dtg.h b/lib/utils/include/utils/graph/digraph/directed_edge.dtg.h new file mode 100644 index 0000000000..ad64c8184f --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge.dtg.h @@ -0,0 +1,48 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/directed_edge.struct.toml +/* proj-data +{ + "generated_from": "406f818eb74797f6ea07231506a56f81" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIRECTED_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIRECTED_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DirectedEdge { + DirectedEdge() = delete; + explicit DirectedEdge(::FlexFlow::Node const &src, + ::FlexFlow::Node const &dst); + + bool operator==(DirectedEdge const &) const; + bool operator!=(DirectedEdge const &) const; + bool operator<(DirectedEdge const &) const; + bool operator>(DirectedEdge const &) const; + bool operator<=(DirectedEdge const &) const; + bool operator>=(DirectedEdge const &) const; + ::FlexFlow::Node src; + ::FlexFlow::Node dst; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DirectedEdge> { + size_t operator()(::FlexFlow::DirectedEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DirectedEdge const &); +std::ostream &operator<<(std::ostream &, DirectedEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIRECTED_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/digraph/directed_edge.struct.toml b/lib/utils/include/utils/graph/digraph/directed_edge.struct.toml new file mode 100644 index 0000000000..9c17bb0325 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "DirectedEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::Node" + +[[fields]] +name = "dst" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h b/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h new file mode 100644 index 0000000000..716a3c5fc6 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h @@ -0,0 +1,42 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml +/* proj-data +{ + "generated_from": "294ae0103df2a3c388a2ce140c271f4e" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIRECTED_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIRECTED_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include + +namespace FlexFlow { +struct DirectedEdgeQuery { + DirectedEdgeQuery() = delete; + explicit DirectedEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &srcs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts); + + bool operator==(DirectedEdgeQuery const &) const; + bool operator!=(DirectedEdgeQuery const &) const; + bool operator<(DirectedEdgeQuery const &) const; + bool operator>(DirectedEdgeQuery const &) const; + bool operator<=(DirectedEdgeQuery const &) const; + bool operator>=(DirectedEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> srcs; + ::FlexFlow::query_set<::FlexFlow::Node> dsts; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(DirectedEdgeQuery const &); +std::ostream &operator<<(std::ostream &, DirectedEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DIRECTED_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.h b/lib/utils/include/utils/graph/digraph/directed_edge_query.h new file mode 100644 index 0000000000..f7e2aac86d --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge_query.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIRECTED_GRAPH_DIRECTED_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIRECTED_GRAPH_DIRECTED_EDGE_QUERY_H + +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" + +namespace FlexFlow { + +DirectedEdgeQuery directed_edge_query_all(); +bool matches_edge(DirectedEdgeQuery const &, DirectedEdge const &); +DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &, + DirectedEdgeQuery const &); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml b/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml new file mode 100644 index 0000000000..2ede557642 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "DirectedEdgeQuery" +features = [ + "eq", + "ord", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dsts" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/digraph/i_digraph.h b/lib/utils/include/utils/graph/digraph/i_digraph.h new file mode 100644 index 0000000000..3f4d5a44d4 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/i_digraph.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_I_DIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_I_DIGRAPH_H + +#include "utils/graph/digraph/i_digraph_view.h" + +namespace FlexFlow { + +struct IDiGraph : virtual public IDiGraphView { + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &) = 0; + virtual void remove_node_unsafe(Node const &) = 0; + virtual void add_edge(Edge const &) = 0; + virtual void remove_edge(Edge const &) = 0; + virtual IDiGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDiGraph); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/i_digraph_view.h b/lib/utils/include/utils/graph/digraph/i_digraph_view.h new file mode 100644 index 0000000000..8108d22cd0 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/i_digraph_view.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_I_DIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_I_DIGRAPH_VIEW_H + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/node/i_graph_view.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" + +namespace FlexFlow { + +struct IDiGraphView : virtual public IGraphView { +public: + using Edge = DirectedEdge; + using EdgeQuery = DirectedEdgeQuery; + + IDiGraphView() = default; + + IDiGraphView(IDiGraphView const &) = delete; + IDiGraphView &operator=(IDiGraphView const &) = delete; + + virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; + virtual ~IDiGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDiGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph_interfaces.h b/lib/utils/include/utils/graph/digraph_interfaces.h deleted file mode 100644 index 812caee902..0000000000 --- a/lib/utils/include/utils/graph/digraph_interfaces.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_DIGRAPH_INTERFACES -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_DIGRAPH_INTERFACES - -#include "diedge.h" -#include "node.h" -#include "utils/type_traits.h" - -namespace FlexFlow { - -struct IDiGraphView : virtual public IGraphView { -public: - using Edge = DirectedEdge; - using EdgeQuery = DirectedEdgeQuery; - - IDiGraphView() = default; - - IDiGraphView(IDiGraphView const &) = delete; - IDiGraphView &operator=(IDiGraphView const &) = delete; - - virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; - virtual ~IDiGraphView() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDiGraphView); - -struct IDiGraph : virtual public IDiGraphView { - virtual Node add_node() = 0; - virtual void add_node_unsafe(Node const &) = 0; - virtual void remove_node_unsafe(Node const &) = 0; - virtual void add_edge(Edge const &) = 0; - virtual void remove_edge(Edge const &) = 0; - virtual IDiGraph *clone() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h new file mode 100644 index 0000000000..b41badcba0 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h @@ -0,0 +1,116 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml +/* proj-data +{ + "generated_from": "a48025d66b3bdc8eec931e33694b0a22" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTI_DI_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTI_DI_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct DownwardOpenMultiDiEdge { + DownwardOpenMultiDiEdge() = delete; + explicit DownwardOpenMultiDiEdge(::FlexFlow::OutputMultiDiEdge const &); + explicit DownwardOpenMultiDiEdge(::FlexFlow::MultiDiEdge const &); + template + static constexpr bool IsPartOfDownwardOpenMultiDiEdge_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::OutputMultiDiEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::MultiDiEdge>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type DownwardOpenMultiDiEdge", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::OutputMultiDiEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::MultiDiEdge>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type DownwardOpenMultiDiEdge", + this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfDownwardOpenMultiDiEdge_v, + "DownwardOpenMultiDiEdge::has() expected one of " + "[::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfDownwardOpenMultiDiEdge_v, + "DownwardOpenMultiDiEdge::get() expected one of " + "[::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfDownwardOpenMultiDiEdge_v, + "DownwardOpenMultiDiEdge::get() expected one of " + "[::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(DownwardOpenMultiDiEdge const &) const; + bool operator!=(DownwardOpenMultiDiEdge const &) const; + bool operator<(DownwardOpenMultiDiEdge const &) const; + bool operator>(DownwardOpenMultiDiEdge const &) const; + bool operator<=(DownwardOpenMultiDiEdge const &) const; + bool operator>=(DownwardOpenMultiDiEdge const &) const; + std::variant<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::DownwardOpenMultiDiEdge> { + size_t operator()(::FlexFlow::DownwardOpenMultiDiEdge const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::DownwardOpenMultiDiEdge const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::DownwardOpenMultiDiEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTI_DI_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml new file mode 100644 index 0000000000..ba6b304636 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DownwardOpenMultiDiEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_multidigraph/output_multi_di_edge.dtg.h", + "utils/graph/multidigraph/multi_di_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::OutputMultiDiEdge" +key = "output_edge" + +[[values]] +type = "::FlexFlow::MultiDiEdge" +key = "standard_edge" diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h new file mode 100644 index 0000000000..ad3b468582 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h @@ -0,0 +1,50 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "396fddca0f20f2459ee9938138d3fc40" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTI_DI_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTI_DI_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DownwardOpenMultiDiEdgeQuery { + DownwardOpenMultiDiEdgeQuery() = delete; + explicit DownwardOpenMultiDiEdgeQuery( + ::FlexFlow::OutputMultiDiEdgeQuery const &output_edge_query, + ::FlexFlow::MultiDiEdgeQuery const &standard_edge_query); + + bool operator==(DownwardOpenMultiDiEdgeQuery const &) const; + bool operator!=(DownwardOpenMultiDiEdgeQuery const &) const; + bool operator<(DownwardOpenMultiDiEdgeQuery const &) const; + bool operator>(DownwardOpenMultiDiEdgeQuery const &) const; + bool operator<=(DownwardOpenMultiDiEdgeQuery const &) const; + bool operator>=(DownwardOpenMultiDiEdgeQuery const &) const; + ::FlexFlow::OutputMultiDiEdgeQuery output_edge_query; + ::FlexFlow::MultiDiEdgeQuery standard_edge_query; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DownwardOpenMultiDiEdgeQuery> { + size_t operator()(::FlexFlow::DownwardOpenMultiDiEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DownwardOpenMultiDiEdgeQuery const &); +std::ostream &operator<<(std::ostream &, DownwardOpenMultiDiEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTI_DI_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml new file mode 100644 index 0000000000..5fc93066e8 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DownwardOpenMultiDiEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h", + "utils/graph/multidigraph/multi_di_edge_query.dtg.h", +] + +[[fields]] +name = "output_edge_query" +type = "::FlexFlow::OutputMultiDiEdgeQuery" + +[[fields]] +name = "standard_edge_query" +type = "::FlexFlow::MultiDiEdgeQuery" diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph.h b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph.h new file mode 100644 index 0000000000..72883af253 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_H + +#include "utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h" +#include "utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h" + +namespace FlexFlow { + +struct DownwardOpenMultiDiGraph : virtual DownwardOpenMultiDiGraphView { +public: + using Edge = DownwardOpenMultiDiEdge; + using EdgeQuery = DownwardOpenMultiDiEdgeQuery; + + DownwardOpenMultiDiGraph() = delete; + DownwardOpenMultiDiGraph(DownwardOpenMultiDiGraph const &) = default; + DownwardOpenMultiDiGraph & + operator=(DownwardOpenMultiDiGraph const &) = default; + + Node add_node(); + void add_node_unsafe(Node const &); + void remove_node_unsafe(Node const &); + + void add_edge(Edge const &); + void remove_edge(Edge const &); + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if< + std::is_base_of::value, + DownwardOpenMultiDiGraph>::type + create() { + return DownwardOpenMultiDiGraph(make_cow_ptr()); + } + +private: + using DownwardOpenMultiDiGraphView::DownwardOpenMultiDiGraphView; + + IDownwardOpenMultiDiGraph &get_ptr(); + IDownwardOpenMultiDiGraph const &get_ptr() const; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraph); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h new file mode 100644 index 0000000000..2ee0ce0d5e --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h" +#include "utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h" +#include "utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h" +#include "utils/graph/multidigraph/multidigraph_view.h" + +namespace FlexFlow { + +struct DownwardOpenMultiDiGraphView : virtual MultiDiGraphView { +public: + using Edge = DownwardOpenMultiDiEdge; + using EdgeQuery = DownwardOpenMultiDiEdgeQuery; + using Interface = IDownwardOpenMultiDiGraphView; + + DownwardOpenMultiDiGraphView(DownwardOpenMultiDiGraphView const &) = default; + DownwardOpenMultiDiGraphView & + operator=(DownwardOpenMultiDiGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if< + std::is_base_of::value, + DownwardOpenMultiDiGraphView>::type + create(Args &&...args) { + return DownwardOpenMultiDiGraphView( + make_cow_ptr(std::forward(args)...)); + } + +private: + using MultiDiGraphView::MultiDiGraphView; + + Interface const &get_ptr() const; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraphView); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h b/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h new file mode 100644 index 0000000000..8bb6d0f569 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_I_DOWNWARD_OPEN_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_I_DOWNWARD_OPEN_MULTIDIGRAPH_H + +#include "utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h" + +namespace FlexFlow { + +struct IDownwardOpenMultiDiGraph + : virtual public IDownwardOpenMultiDiGraphView { + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &node) = 0; + virtual void remove_node_unsafe(Node const &node) = 0; + virtual void add_edge(DownwardOpenMultiDiEdge const &) = 0; + virtual void remove_edge(DownwardOpenMultiDiEdge const &) = 0; + virtual IDownwardOpenMultiDiGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDownwardOpenMultiDiGraph); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h b/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h new file mode 100644 index 0000000000..243851bac9 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_I_DOWNWARD_OPEN_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_MULTIDIGRAPH_I_DOWNWARD_OPEN_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/open_multidigraph/i_open_multidigraph_view.h" +#include "utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h" +#include "utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/open_multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +struct IDownwardOpenMultiDiGraphView : virtual public IOpenMultiDiGraphView { + virtual std::unordered_set + query_edges(DownwardOpenMultiDiEdgeQuery const &) const = 0; + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const final; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDownwardOpenMultiDiGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h deleted file mode 100644 index de4ab4fd82..0000000000 --- a/lib/utils/include/utils/graph/multidiedge.h +++ /dev/null @@ -1,101 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_MULTIDIEDGE -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_MULTIDIEDGE - -#include "diedge.h" -#include "node.h" -#include "node_port.h" -#include "utils/fmt/pair.h" -#include "utils/strong_typedef.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct MultiDiInput : DiInput { - NodePort dst_idx; -}; -FF_VISITABLE_STRUCT(MultiDiInput, dst, dst_idx); -FF_VISIT_FMTABLE(MultiDiInput); - -struct MultiDiOutput : DiOutput { - NodePort src_idx; - - bool operator>(MultiDiOutput const &) const; - bool operator>=(MultiDiOutput const &) const; - bool operator<=(MultiDiOutput const &) const; -}; -FF_VISITABLE_STRUCT(MultiDiOutput, src, src_idx); -FF_VISIT_FMTABLE(MultiDiOutput); - -using edge_uid_t = std::pair; - -struct InputMultiDiEdge : MultiDiInput { - req uid; // necessary to differentiate multiple input edges from - // different sources resulting from a graph cut -}; -FF_VISITABLE_STRUCT(InputMultiDiEdge, dst, dst_idx, uid); -FF_VISIT_FMTABLE(InputMultiDiEdge); - -struct OutputMultiDiEdge : MultiDiOutput { - req uid; // necessary to differentiate multiple output edges from - // different sources resulting from a graph cut -}; -FF_VISITABLE_STRUCT(OutputMultiDiEdge, src, src_idx, uid); -FF_VISIT_FMTABLE(OutputMultiDiEdge); - -struct OutputMultiDiEdgeQuery { - query_set srcs; - query_set srcIdxs; - - OutputMultiDiEdgeQuery with_src_nodes(query_set const &) const; - - static OutputMultiDiEdgeQuery all(); - static OutputMultiDiEdgeQuery none(); -}; -FF_VISITABLE_STRUCT(OutputMultiDiEdgeQuery, srcs, srcIdxs); - -struct InputMultiDiEdgeQuery { - query_set dsts; - query_set dstIdxs; - - InputMultiDiEdgeQuery with_dst_nodes(query_set const &) const; - - static InputMultiDiEdgeQuery all(); - static InputMultiDiEdgeQuery none(); -}; -FF_VISITABLE_STRUCT(InputMultiDiEdgeQuery, dsts, dstIdxs); - -struct MultiDiEdge : MultiDiInput, MultiDiOutput { - edge_uid_t get_uid() const { - return std::make_pair(src_idx.value(), dst_idx.value()); - } -}; -FF_VISITABLE_STRUCT(MultiDiEdge, dst, dst_idx, src, src_idx); -FF_VISIT_FMTABLE(MultiDiEdge); - -struct MultiDiEdgeQuery { - query_set srcs; - query_set dsts; - query_set srcIdxs; - query_set dstIdxs; - - MultiDiEdgeQuery with_src_nodes(query_set const &) const; - MultiDiEdgeQuery with_dst_nodes(query_set const &) const; - MultiDiEdgeQuery with_src_idxs(query_set const &) const; - MultiDiEdgeQuery with_dst_idxs(query_set const &) const; - - static MultiDiEdgeQuery all(); - static MultiDiEdgeQuery none(); -}; -FF_VISITABLE_STRUCT(MultiDiEdgeQuery, srcs, dsts, srcIdxs, dstIdxs); - -MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &, - MultiDiEdgeQuery const &); -MultiDiEdgeQuery query_union(MultiDiEdgeQuery const &, - MultiDiEdgeQuery const &); - -InputMultiDiEdge to_inputmultidiedge(MultiDiEdge const &e); -OutputMultiDiEdge to_outputmultidiedge(MultiDiEdge const &e); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h deleted file mode 100644 index effbad8a1e..0000000000 --- a/lib/utils/include/utils/graph/multidigraph.h +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_MULTIDIGRAPH_H -#define _FLEXFLOW_UTILS_GRAPH_MULTIDIGRAPH_H - -#include "cow_ptr_t.h" -#include "digraph.h" -#include "multidiedge.h" -#include "multidigraph_interfaces.h" -#include "node.h" - -namespace FlexFlow { -struct MultiDiGraphView : virtual DiGraphView { -public: - using Edge = MultiDiEdge; - using EdgeQuery = MultiDiEdgeQuery; - - MultiDiGraphView(MultiDiGraphView const &) = default; - MultiDiGraphView &operator=(MultiDiGraphView const &) = default; - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if::value, - MultiDiGraphView>::type - create(Args &&...args) { - return MultiDiGraphView(make_cow_ptr(std::forward(args)...)); - } - -protected: - using DiGraphView::DiGraphView; - -private: - IMultiDiGraphView const &get_ptr() const; - - friend struct GraphInternal; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraphView); - -struct MultiDiGraph : virtual MultiDiGraphView { -public: - using Edge = MultiDiEdge; - using EdgeQuery = MultiDiEdgeQuery; - - MultiDiGraph() = delete; - MultiDiGraph(MultiDiGraph const &) = default; - MultiDiGraph &operator=(MultiDiGraph const &) = default; - - Node add_node(); - NodePort add_node_port(); - void add_node_unsafe(Node const &); - void add_node_port_unsafe(NodePort const &); - void remove_node_unsafe(Node const &); - - void add_edge(Edge const &e); - void remove_edge(Edge const &e); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if::value, - MultiDiGraph>::type - create() { - return MultiDiGraph(make_cow_ptr()); - } - -private: - using MultiDiGraphView::MultiDiGraphView; - - IMultiDiGraph const &get_ptr() const; - IMultiDiGraph &get_ptr(); - - friend struct GraphInternal; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/adjacency_multidigraph.h b/lib/utils/include/utils/graph/multidigraph/adjacency_multidigraph.h similarity index 92% rename from lib/utils/include/utils/graph/adjacency_multidigraph.h rename to lib/utils/include/utils/graph/multidigraph/adjacency_multidigraph.h index f486016138..c750b894e6 100644 --- a/lib/utils/include/utils/graph/adjacency_multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/adjacency_multidigraph.h @@ -1,10 +1,7 @@ #ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_ADJACENCY_MULTIDIGRAPH #define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_ADJACENCY_MULTIDIGRAPH -#include "multidigraph.h" -#include "utils/type_traits.h" -#include -#include +#include "utils/graph/multidigraph/i_multidigraph.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/multidigraph/i_multidigraph.h b/lib/utils/include/utils/graph/multidigraph/i_multidigraph.h new file mode 100644 index 0000000000..6f958d27c2 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/i_multidigraph.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_I_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_I_MULTIDIGRAPH_H + +#include "utils/graph/multidigraph/i_multidigraph_view.h" + +namespace FlexFlow { + +struct IMultiDiGraph : virtual public IMultiDiGraphView { + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &) = 0; + virtual void remove_node_unsafe(Node const &) = 0; + virtual void add_edge(Edge const &) = 0; + virtual void remove_edge(Edge const &) = 0; + + virtual std::unordered_set + query_nodes(NodeQuery const &query) const override { + return static_cast(this)->query_nodes(query); + } + + virtual IMultiDiGraph *clone() const override = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph/i_multidigraph_view.h b/lib/utils/include/utils/graph/multidigraph/i_multidigraph_view.h new file mode 100644 index 0000000000..78bf508be4 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/i_multidigraph_view.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_I_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_I_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/digraph/i_digraph_view.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +struct IMultiDiGraphView : virtual public IDiGraphView { + using Edge = MultiDiEdge; + using EdgeQuery = MultiDiEdgeQuery; + + virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; + std::unordered_set + query_edges(DirectedEdgeQuery const &) const override final; + virtual ~IMultiDiGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraphView); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h b/lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h new file mode 100644 index 0000000000..0c471c8c35 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h @@ -0,0 +1,50 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "73b001bfb7a0b75c42cd5037bb8dc686" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct MultiDiEdge { + MultiDiEdge() = delete; + explicit MultiDiEdge(::FlexFlow::Node const &src, + ::FlexFlow::Node const &dst, + std::pair const &raw_edge_uid); + + bool operator==(MultiDiEdge const &) const; + bool operator!=(MultiDiEdge const &) const; + bool operator<(MultiDiEdge const &) const; + bool operator>(MultiDiEdge const &) const; + bool operator<=(MultiDiEdge const &) const; + bool operator>=(MultiDiEdge const &) const; + ::FlexFlow::Node src; + ::FlexFlow::Node dst; + std::pair raw_edge_uid; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::MultiDiEdge> { + size_t operator()(::FlexFlow::MultiDiEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(MultiDiEdge const &); +std::ostream &operator<<(std::ostream &, MultiDiEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml b/lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml new file mode 100644 index 0000000000..41b08deb18 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MultiDiEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::Node" + +[[fields]] +name = "dst" +type = "::FlexFlow::Node" + +[[fields]] +name = "raw_edge_uid" +type = "size_t" diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h new file mode 100644 index 0000000000..47b30da97b --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h @@ -0,0 +1,50 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "bede7a523428098275e26ba89bb30eb0" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct MultiDiEdgeQuery { + MultiDiEdgeQuery() = delete; + explicit MultiDiEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &srcs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts); + + bool operator==(MultiDiEdgeQuery const &) const; + bool operator!=(MultiDiEdgeQuery const &) const; + bool operator<(MultiDiEdgeQuery const &) const; + bool operator>(MultiDiEdgeQuery const &) const; + bool operator<=(MultiDiEdgeQuery const &) const; + bool operator>=(MultiDiEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> srcs; + ::FlexFlow::query_set<::FlexFlow::Node> dsts; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::MultiDiEdgeQuery> { + size_t operator()(::FlexFlow::MultiDiEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(MultiDiEdgeQuery const &); +std::ostream &operator<<(std::ostream &, MultiDiEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.h b/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.h new file mode 100644 index 0000000000..af7b6d01b7 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_QUERY_H + +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +MultiDiEdgeQuery multidiedge_query_all(); +MultiDiEdgeQuery multidiedge_query_none(); +MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &, + MultiDiEdgeQuery const &); +MultiDiEdgeQuery query_union(MultiDiEdgeQuery const &, + MultiDiEdgeQuery const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml new file mode 100644 index 0000000000..1d555b2626 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "MultiDiEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dsts" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h new file mode 100644 index 0000000000..0fc498b8ac --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTIDIGRAPH_H + +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/multidigraph/i_multidigraph.h" + +namespace FlexFlow { + +struct MultiDiGraph : virtual MultiDiGraphView { +public: + using Edge = MultiDiEdge; + using EdgeQuery = MultiDiEdgeQuery; + + MultiDiGraph() = delete; + MultiDiGraph(MultiDiGraph const &) = default; + MultiDiGraph &operator=(MultiDiGraph const &) = default; + + Node add_node(); + void add_node_unsafe(Node const &); + void remove_node_unsafe(Node const &); + + void add_edge(Edge const &e); + void remove_edge(Edge const &e); + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if::value, + MultiDiGraph>::type + create() { + return MultiDiGraph(make_cow_ptr()); + } + +private: + using MultiDiGraphView::MultiDiGraphView; + + IMultiDiGraph const &get_ptr() const; + IMultiDiGraph &get_ptr(); + + friend struct GraphInternal; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph_view.h b/lib/utils/include/utils/graph/multidigraph/multidigraph_view.h new file mode 100644 index 0000000000..ba73a2f6af --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph_view.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/multidigraph/i_multidigraph_view.h" + +namespace FlexFlow { + +struct MultiDiGraphView : virtual DiGraphView { +public: + using Edge = MultiDiEdge; + using EdgeQuery = MultiDiEdgeQuery; + + MultiDiGraphView(MultiDiGraphView const &) = default; + MultiDiGraphView &operator=(MultiDiGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if::value, + MultiDiGraphView>::type + create(Args &&...args) { + return MultiDiGraphView(make_cow_ptr(std::forward(args)...)); + } + +protected: + using DiGraphView::DiGraphView; + +private: + IMultiDiGraphView const &get_ptr() const; + + friend struct GraphInternal; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph_interfaces.h b/lib/utils/include/utils/graph/multidigraph_interfaces.h deleted file mode 100644 index e48fc2a1a9..0000000000 --- a/lib/utils/include/utils/graph/multidigraph_interfaces.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_INTERFACES -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_INTERFACES - -#include "digraph_interfaces.h" -#include "multidiedge.h" -#include "node.h" -#include "query_set.h" -#include "utils/optional.h" -#include "utils/strong_typedef.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct IMultiDiGraphView : virtual public IDiGraphView { - using Edge = MultiDiEdge; - using EdgeQuery = MultiDiEdgeQuery; - - virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override final; - virtual ~IMultiDiGraphView() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraphView); - -struct IMultiDiGraph : virtual public IMultiDiGraphView { - virtual Node add_node() = 0; - virtual void add_node_unsafe(Node const &) = 0; - virtual void remove_node_unsafe(Node const &) = 0; - virtual NodePort add_node_port() = 0; - virtual void add_node_port_unsafe(NodePort const &) = 0; - virtual void add_edge(Edge const &) = 0; - virtual void remove_edge(Edge const &) = 0; - - virtual std::unordered_set - query_nodes(NodeQuery const &query) const override { - return static_cast(this)->query_nodes(query); - } - - virtual IMultiDiGraph *clone() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/node.h b/lib/utils/include/utils/graph/node.h deleted file mode 100644 index 2e35ba8131..0000000000 --- a/lib/utils/include/utils/graph/node.h +++ /dev/null @@ -1,110 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_NODE_H -#define _FLEXFLOW_UTILS_GRAPH_NODE_H - -#include "cow_ptr_t.h" -#include "query_set.h" -#include "utils/fmt.h" -#include "utils/optional.h" -#include "utils/strong_typedef.h" -#include "utils/type_traits.h" -#include "utils/unique.h" -#include "utils/visitable.h" -#include -#include -#include -#include -#include - -namespace FlexFlow { - -struct Node : public strong_typedef { - using strong_typedef::strong_typedef; -}; -FF_TYPEDEF_HASHABLE(Node); -FF_TYPEDEF_PRINTABLE(Node, "Node"); - -struct NodeQuery { - NodeQuery(query_set const &nodes) : nodes(nodes) {} - - query_set nodes; - - static NodeQuery all(); -}; -FF_VISITABLE_STRUCT(NodeQuery, nodes); - -NodeQuery query_intersection(NodeQuery const &, NodeQuery const &); -NodeQuery query_union(NodeQuery const &, NodeQuery const &); - -struct IGraphView { - IGraphView() = default; - IGraphView(IGraphView const &) = delete; - IGraphView &operator=(IGraphView const &) = delete; - - virtual IGraphView *clone() const = 0; - - virtual std::unordered_set query_nodes(NodeQuery const &) const = 0; - virtual ~IGraphView(){}; -}; - -struct GraphView { - std::unordered_set query_nodes(NodeQuery const &) const; - friend bool is_ptr_equal(GraphView const &, GraphView const &); - - template - static typename std::enable_if::value, - GraphView>::type - create(Args &&...args) { - return GraphView(make_cow_ptr(std::forward(args)...)); - } - -protected: - GraphView() : ptr(nullptr) {} - cow_ptr_t ptr; - GraphView(cow_ptr_t ptr); - - friend struct GraphInternal; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IGraphView); - -struct IGraph : virtual IGraphView { - IGraph() = default; - IGraph(IGraph const &) = delete; - IGraph &operator=(IGraph const &) = delete; - - virtual Node add_node() = 0; - virtual void add_node_unsafe(Node const &) = 0; - virtual void remove_node_unsafe(Node const &) = 0; - virtual IGraph *clone() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IGraph); - -struct Graph : virtual GraphView { -public: - Graph(Graph const &) = default; - - Graph &operator=(Graph const &) = default; - - Node add_node(); - void add_node_unsafe(Node const &); - void remove_node_unsafe(Node const &); - - std::unordered_set query_nodes(NodeQuery const &) const; - - template - static typename std::enable_if::value, Graph>::type - create() { - return Graph(make_cow_ptr()); - } - - using GraphView::GraphView; - -private: - IGraph const &get_ptr() const; - IGraph &get_ptr(); - - friend struct GraphInternal; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h new file mode 100644 index 0000000000..81c16a3147 --- /dev/null +++ b/lib/utils/include/utils/graph/node/graph.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_GRAPH_H + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/node/node_query.dtg.h" +#include "utils/graph/node/i_graph.h" +#include "utils/graph/node/graph_view.h" + +namespace FlexFlow { + +struct Graph : virtual GraphView { +public: + Graph(Graph const &) = default; + + Graph &operator=(Graph const &) = default; + + Node add_node(); + void add_node_unsafe(Node const &); + void remove_node_unsafe(Node const &); + + std::unordered_set query_nodes(NodeQuery const &) const; + + template + static typename std::enable_if::value, Graph>::type + create() { + return Graph(make_cow_ptr()); + } + + using GraphView::GraphView; + +private: + IGraph const &get_ptr() const; + IGraph &get_ptr(); + + friend struct GraphInternal; +}; + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h new file mode 100644 index 0000000000..ad8001b8e4 --- /dev/null +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_GRAPH_VIEW_H + +#include "utils/graph/node/node_query.dtg.h" +#include "utils/graph/node/i_graph_view.h" +#include "utils/graph/cow_ptr_t.h" + +namespace FlexFlow { + +struct GraphView { + std::unordered_set query_nodes(NodeQuery const &) const; + friend bool is_ptr_equal(GraphView const &, GraphView const &); + + template + static typename std::enable_if::value, + GraphView>::type + create(Args &&...args) { + return GraphView(make_cow_ptr(std::forward(args)...)); + } + +protected: + GraphView() : ptr(nullptr) {} + cow_ptr_t ptr; + GraphView(cow_ptr_t ptr); + + friend struct GraphInternal; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/i_graph.h b/lib/utils/include/utils/graph/node/i_graph.h new file mode 100644 index 0000000000..1b87fb4b9c --- /dev/null +++ b/lib/utils/include/utils/graph/node/i_graph.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_GRAPH_H + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/node/i_graph_view.h" + +namespace FlexFlow { + +struct IGraph : virtual IGraphView { + IGraph() = default; + IGraph(IGraph const &) = delete; + IGraph &operator=(IGraph const &) = delete; + + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &) = 0; + virtual void remove_node_unsafe(Node const &) = 0; + virtual IGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/i_graph_view.h b/lib/utils/include/utils/graph/node/i_graph_view.h new file mode 100644 index 0000000000..7d395bca2c --- /dev/null +++ b/lib/utils/include/utils/graph/node/i_graph_view.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_GRAPH_VIEW_H + +#include "utils/graph/node/node_query.dtg.h" +namespace FlexFlow { + +struct IGraphView { + IGraphView() = default; + IGraphView(IGraphView const &) = delete; + IGraphView &operator=(IGraphView const &) = delete; + + virtual IGraphView *clone() const = 0; + + virtual std::unordered_set query_nodes(NodeQuery const &) const = 0; + virtual ~IGraphView(){}; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/node.dtg.h b/lib/utils/include/utils/graph/node/node.dtg.h new file mode 100644 index 0000000000..d509e592e6 --- /dev/null +++ b/lib/utils/include/utils/graph/node/node.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/node/node.struct.toml +/* proj-data +{ + "generated_from": "cc4828f6a9dcc4c3435767bd6ccfc866" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_DTG_H + +#include "fmt/format.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Node { + Node() = delete; + explicit Node(size_t const &raw_uid); + + bool operator==(Node const &) const; + bool operator!=(Node const &) const; + bool operator<(Node const &) const; + bool operator>(Node const &) const; + bool operator<=(Node const &) const; + bool operator>=(Node const &) const; + size_t raw_uid; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::Node> { + size_t operator()(::FlexFlow::Node const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(Node const &); +std::ostream &operator<<(std::ostream &, Node const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_DTG_H diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml new file mode 100644 index 0000000000..0b6f348ddf --- /dev/null +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "Node" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "", +] + +[[fields]] +name = "raw_uid" +type = "size_t" diff --git a/lib/utils/include/utils/graph/node/node_query.dtg.h b/lib/utils/include/utils/graph/node/node_query.dtg.h new file mode 100644 index 0000000000..3813e8217d --- /dev/null +++ b/lib/utils/include/utils/graph/node/node_query.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/node/node_query.struct.toml +/* proj-data +{ + "generated_from": "e3e4a13f0d1a7ca9f179ba09dd4c5735" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct NodeQuery { + NodeQuery() = delete; + explicit NodeQuery(::FlexFlow::query_set<::FlexFlow::Node> const &nodes); + + bool operator==(NodeQuery const &) const; + bool operator!=(NodeQuery const &) const; + bool operator<(NodeQuery const &) const; + bool operator>(NodeQuery const &) const; + bool operator<=(NodeQuery const &) const; + bool operator>=(NodeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> nodes; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::NodeQuery> { + size_t operator()(::FlexFlow::NodeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(NodeQuery const &); +std::ostream &operator<<(std::ostream &, NodeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/node/node_query.h b/lib/utils/include/utils/graph/node/node_query.h new file mode 100644 index 0000000000..60d1d2932c --- /dev/null +++ b/lib/utils/include/utils/graph/node/node_query.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_NODE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_NODE_QUERY_H + +#include "utils/graph/node/node_query.dtg.h" + +namespace FlexFlow { + +NodeQuery node_query_all(); +NodeQuery query_intersection(NodeQuery const &, NodeQuery const &); +NodeQuery query_union(NodeQuery const &, NodeQuery const &); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/node_query.struct.toml b/lib/utils/include/utils/graph/node/node_query.struct.toml new file mode 100644 index 0000000000..0519e01650 --- /dev/null +++ b/lib/utils/include/utils/graph/node/node_query.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "NodeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/query_set.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/node_port.h b/lib/utils/include/utils/graph/node_port.h deleted file mode 100644 index cb0c973a67..0000000000 --- a/lib/utils/include/utils/graph/node_port.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_NODE_PORT -#define UTILS_GRAPH_INCLUDE_NODE_PORT - -namespace FlexFlow { - -/** - * @class NodePort - * @brief An opaque object used to disambiguate multiple edges between the same - * nodes in a MultiDiGraph - * - * Name chosen to match the terminology used by ELK - * - */ -struct NodePort : public strong_typedef { - using strong_typedef::strong_typedef; -}; -FF_TYPEDEF_HASHABLE(NodePort); -FF_TYPEDEF_PRINTABLE(NodePort, "NodePort"); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/open_edge.h b/lib/utils/include/utils/graph/open_edge.h deleted file mode 100644 index 37e98a419d..0000000000 --- a/lib/utils/include/utils/graph/open_edge.h +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_OPEN_EDGE -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_OPEN_EDGE - -#include "multidiedge.h" - -namespace FlexFlow { - -using OpenMultiDiEdge = - std::variant; - -using DownwardOpenMultiDiEdge = std::variant; - -using UpwardOpenMultiDiEdge = std::variant; - -bool is_input_edge(OpenMultiDiEdge const &); -bool is_output_edge(OpenMultiDiEdge const &); -bool is_standard_edge(OpenMultiDiEdge const &); - -struct OpenMultiDiEdgeQuery { - OpenMultiDiEdgeQuery() = delete; - OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &input_edge_query, - MultiDiEdgeQuery const &standard_edge_query, - OutputMultiDiEdgeQuery const &output_edge_query); - - OpenMultiDiEdgeQuery(MultiDiEdgeQuery const &q); - OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &q); - OpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &q); - - static OpenMultiDiEdgeQuery all(); - - InputMultiDiEdgeQuery input_edge_query; - MultiDiEdgeQuery standard_edge_query; - OutputMultiDiEdgeQuery output_edge_query; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpenMultiDiEdgeQuery, - input_edge_query, - standard_edge_query, - output_edge_query); - -struct DownwardOpenMultiDiEdgeQuery { - DownwardOpenMultiDiEdgeQuery() = delete; - DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query, - MultiDiEdgeQuery const &standard_edge_query); - DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query); - DownwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query); - - operator OpenMultiDiEdgeQuery() const; - - OutputMultiDiEdgeQuery output_edge_query; - MultiDiEdgeQuery standard_edge_query; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(DownwardOpenMultiDiEdgeQuery, - output_edge_query, - standard_edge_query); - -struct UpwardOpenMultiDiEdgeQuery { - UpwardOpenMultiDiEdgeQuery() = delete; - UpwardOpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &, - MultiDiEdgeQuery const &); - UpwardOpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &); - UpwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &); - operator OpenMultiDiEdgeQuery() const; - - InputMultiDiEdgeQuery input_edge_query; - MultiDiEdgeQuery standard_edge_query; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(UpwardOpenMultiDiEdgeQuery, - input_edge_query, - standard_edge_query); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/open_graph_interfaces.h b/lib/utils/include/utils/graph/open_graph_interfaces.h deleted file mode 100644 index 3173ea9ac1..0000000000 --- a/lib/utils/include/utils/graph/open_graph_interfaces.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_OPEN_GRAPH_INTERFACES -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_OPEN_GRAPH_INTERFACES - -#include "multidigraph.h" -#include "open_edge.h" -#include "utils/exception.h" -#include "utils/graph/multidiedge.h" -#include "utils/graph/multidigraph_interfaces.h" -#include "utils/strong_typedef.h" -#include "utils/type_traits.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct IOpenMultiDiGraphView : virtual public IMultiDiGraphView { - virtual std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &) const = 0; - virtual std::unordered_set - query_edges(MultiDiEdgeQuery const &) const override final; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenMultiDiGraphView); - -struct IDownwardOpenMultiDiGraphView : virtual public IOpenMultiDiGraphView { - virtual std::unordered_set - query_edges(DownwardOpenMultiDiEdgeQuery const &) const = 0; - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &q) const final { - return widen( - this->query_edges(DownwardOpenMultiDiEdgeQuery{q.output_edge_query, - q.standard_edge_query})); - } -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDownwardOpenMultiDiGraphView); - -struct IUpwardOpenMultiDiGraphView : virtual public IOpenMultiDiGraphView { - virtual std::unordered_set - query_edges(UpwardOpenMultiDiEdgeQuery const &) const = 0; - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &q) const final { - return widen(this->query_edges( - UpwardOpenMultiDiEdgeQuery{q.input_edge_query, q.standard_edge_query})); - } -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUpwardOpenMultiDiGraphView); - -struct IOpenMultiDiGraph : virtual public IOpenMultiDiGraphView { - virtual Node add_node() = 0; - virtual NodePort add_node_port() = 0; - virtual void add_node_unsafe(Node const &node) = 0; - virtual void remove_node_unsafe(Node const &node) = 0; - virtual void add_edge(OpenMultiDiEdge const &) = 0; - virtual void remove_edge(OpenMultiDiEdge const &) = 0; - virtual IOpenMultiDiGraph *clone() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenMultiDiGraph); - -struct IUpwardOpenMultiDiGraph : virtual public IUpwardOpenMultiDiGraphView { - virtual Node add_node() = 0; - virtual void add_node_unsafe(Node const &node) = 0; - virtual void remove_node_unsafe(Node const &node) = 0; - virtual void add_edge(UpwardOpenMultiDiEdge const &) = 0; - virtual void remove_edge(UpwardOpenMultiDiEdge const &) = 0; - virtual IUpwardOpenMultiDiGraph *clone() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUpwardOpenMultiDiGraph); - -struct IDownwardOpenMultiDiGraph - : virtual public IDownwardOpenMultiDiGraphView { - virtual Node add_node() = 0; - virtual void add_node_unsafe(Node const &node) = 0; - virtual void remove_node_unsafe(Node const &node) = 0; - virtual void add_edge(DownwardOpenMultiDiEdge const &) = 0; - virtual void remove_edge(DownwardOpenMultiDiEdge const &) = 0; - virtual IDownwardOpenMultiDiGraph *clone() const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDownwardOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h deleted file mode 100644 index 0b0db44f93..0000000000 --- a/lib/utils/include/utils/graph/open_graphs.h +++ /dev/null @@ -1,206 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_OPEN_GRAPHS_H -#define _FLEXFLOW_UTILS_GRAPH_OPEN_GRAPHS_H - -#include "multidigraph.h" -#include "node.h" -#include "open_edge.h" -#include "open_graph_interfaces.h" -#include "utils/optional.h" -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct OpenMultiDiGraphView : virtual MultiDiGraphView { -public: - using Edge = OpenMultiDiEdge; - using EdgeQuery = OpenMultiDiEdgeQuery; - - OpenMultiDiGraphView(OpenMultiDiGraphView const &) = default; - OpenMultiDiGraphView &operator=(OpenMultiDiGraphView const &) = default; - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static - typename std::enable_if::value, - OpenMultiDiGraphView>::type - create(Args &&...args) { - return OpenMultiDiGraphView(make_cow_ptr(std::forward(args)...)); - } - -protected: - using MultiDiGraphView::MultiDiGraphView; - -private: - IOpenMultiDiGraphView const &get_ptr() const; - - friend struct GraphInternal; -}; - -struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { -public: - using Edge = OpenMultiDiEdge; - using EdgeQuery = OpenMultiDiEdgeQuery; - - OpenMultiDiGraph() = delete; - OpenMultiDiGraph(OpenMultiDiGraph const &) = default; - - Node add_node(); - void add_node_unsafe(Node const &); - void remove_node_unsafe(Node const &); - NodePort add_node_port(); - - void add_edge(Edge const &); - void remove_edge(Edge const &); - - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if::value, - OpenMultiDiGraph>::type - create() { - return OpenMultiDiGraph(make_cow_ptr()); - } - -private: - using OpenMultiDiGraphView::OpenMultiDiGraphView; - - IOpenMultiDiGraph const &get_ptr() const; - IOpenMultiDiGraph &get_ptr(); - - friend struct GraphInternal; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OpenMultiDiGraph); - -struct UpwardOpenMultiDiGraphView : virtual MultiDiGraphView { -public: - using Edge = UpwardOpenMultiDiEdge; - using EdgeQuery = UpwardOpenMultiDiEdgeQuery; - - UpwardOpenMultiDiGraphView(UpwardOpenMultiDiGraphView const &) = default; - UpwardOpenMultiDiGraphView & - operator=(UpwardOpenMultiDiGraphView const &) = default; - - std::unordered_set query_nodes(NodeQuery const &); - std::unordered_set query_edges(EdgeQuery const &); - - template - static typename std::enable_if< - std::is_base_of::value, - UpwardOpenMultiDiGraphView>::type - create(Args &&...args) { - return UpwardOpenMultiDiGraphView( - cow_ptr_t(std::forward(args)...)); - } - -private: - using MultiDiGraphView::MultiDiGraphView; - - IUpwardOpenMultiDiGraphView const &get_ptr() const; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraphView); - -struct UpwardOpenMultiDiGraph : virtual UpwardOpenMultiDiGraphView { -public: - using Edge = UpwardOpenMultiDiEdge; - using EdgeQuery = UpwardOpenMultiDiEdgeQuery; - - UpwardOpenMultiDiGraph() = delete; - UpwardOpenMultiDiGraph(UpwardOpenMultiDiGraph const &) = default; - UpwardOpenMultiDiGraph &operator=(UpwardOpenMultiDiGraph const &) = default; - - Node add_node(); - void add_node_unsafe(Node const &); - void remove_node_unsafe(Node const &); - - void add_edge(Edge const &); - void remove_edge(Edge const &); - - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if< - std::is_base_of::value, - UpwardOpenMultiDiGraph>::type - create() { - return UpwardOpenMultiDiGraph(make_cow_ptr()); - } - -private: - using UpwardOpenMultiDiGraphView::UpwardOpenMultiDiGraphView; - - IUpwardOpenMultiDiGraph const &get_ptr() const; - IUpwardOpenMultiDiGraph &get_ptr(); -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraph); - -struct DownwardOpenMultiDiGraphView : virtual MultiDiGraphView { -public: - using Edge = DownwardOpenMultiDiEdge; - using EdgeQuery = DownwardOpenMultiDiEdgeQuery; - using Interface = IDownwardOpenMultiDiGraphView; - - DownwardOpenMultiDiGraphView(DownwardOpenMultiDiGraphView const &) = default; - DownwardOpenMultiDiGraphView & - operator=(DownwardOpenMultiDiGraphView const &) = default; - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if< - std::is_base_of::value, - DownwardOpenMultiDiGraphView>::type - create(Args &&...args) { - return DownwardOpenMultiDiGraphView( - make_cow_ptr(std::forward(args)...)); - } - -private: - using MultiDiGraphView::MultiDiGraphView; - - Interface const &get_ptr() const; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraphView); - -struct DownwardOpenMultiDiGraph : virtual DownwardOpenMultiDiGraphView { -public: - using Edge = DownwardOpenMultiDiEdge; - using EdgeQuery = DownwardOpenMultiDiEdgeQuery; - - DownwardOpenMultiDiGraph() = delete; - DownwardOpenMultiDiGraph(DownwardOpenMultiDiGraph const &) = default; - DownwardOpenMultiDiGraph & - operator=(DownwardOpenMultiDiGraph const &) = default; - - Node add_node(); - void add_node_unsafe(Node const &); - void remove_node_unsafe(Node const &); - - void add_edge(Edge const &); - void remove_edge(Edge const &); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if< - std::is_base_of::value, - DownwardOpenMultiDiGraph>::type - create() { - return DownwardOpenMultiDiGraph(make_cow_ptr()); - } - -private: - using DownwardOpenMultiDiGraphView::DownwardOpenMultiDiGraphView; - - IDownwardOpenMultiDiGraph &get_ptr(); - IDownwardOpenMultiDiGraph const &get_ptr() const; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/adjacency_openmultidigraph.h b/lib/utils/include/utils/graph/open_multidigraph/adjacency_openmultidigraph.h similarity index 94% rename from lib/utils/include/utils/graph/adjacency_openmultidigraph.h rename to lib/utils/include/utils/graph/open_multidigraph/adjacency_openmultidigraph.h index ff331287cc..94671d395f 100644 --- a/lib/utils/include/utils/graph/adjacency_openmultidigraph.h +++ b/lib/utils/include/utils/graph/open_multidigraph/adjacency_openmultidigraph.h @@ -1,8 +1,8 @@ #ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_ADJACENCY_OPENMULTIDIGRAPH #define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_ADJACENCY_OPENMULTIDIGRAPH -#include "adjacency_multidigraph.h" -#include "open_graph_interfaces.h" +#include "utils/graph/multidigraph/adjacency_multidigraph.h" +#include "utils/graph/open_multidigraph/i_open_multidigraph.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph.h b/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph.h new file mode 100644 index 0000000000..f2a34ba7ff --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_I_OPEN_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_I_OPEN_MULTIDIGRAPH_H + +#include "utils/graph/open_multidigraph/i_open_multidigraph_view.h" + +namespace FlexFlow { + +struct IOpenMultiDiGraph : virtual public IOpenMultiDiGraphView { + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &node) = 0; + virtual void remove_node_unsafe(Node const &node) = 0; + virtual void add_edge(OpenMultiDiEdge const &) = 0; + virtual void remove_edge(OpenMultiDiEdge const &) = 0; + virtual IOpenMultiDiGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenMultiDiGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph_view.h b/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph_view.h new file mode 100644 index 0000000000..b767791a77 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph_view.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_I_OPEN_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_I_OPEN_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/multidigraph/i_multidigraph_view.h" +#include "utils/graph/open_multidigraph/open_multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +struct IOpenMultiDiGraphView : virtual public IMultiDiGraphView { + virtual std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &) const = 0; + virtual std::unordered_set + query_edges(MultiDiEdgeQuery const &) const override final; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenMultiDiGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.dtg.h new file mode 100644 index 0000000000..94f892e8d2 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.dtg.h @@ -0,0 +1,48 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "d779a19c1f8f096dc1dfabf95633b115" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct InputMultiDiEdge { + InputMultiDiEdge() = delete; + explicit InputMultiDiEdge(::FlexFlow::Node const &dst, size_t const &raw_uid); + + bool operator==(InputMultiDiEdge const &) const; + bool operator!=(InputMultiDiEdge const &) const; + bool operator<(InputMultiDiEdge const &) const; + bool operator>(InputMultiDiEdge const &) const; + bool operator<=(InputMultiDiEdge const &) const; + bool operator>=(InputMultiDiEdge const &) const; + ::FlexFlow::Node dst; + size_t raw_uid; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::InputMultiDiEdge> { + size_t operator()(::FlexFlow::InputMultiDiEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(InputMultiDiEdge const &); +std::ostream &operator<<(std::ostream &, InputMultiDiEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.h b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.h new file mode 100644 index 0000000000..9f396ed28a --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_H + +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge.dtg.h" + +namespace FlexFlow { + +InputMultiDiEdge input_multidiedge_from_multidiedge(MultiDiEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml new file mode 100644 index 0000000000..c9519dc886 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "InputMultiDiEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +[[fields]] +name = "dst" +type = "::FlexFlow::Node" + +[[fields]] +name = "raw_uid" +type = "size_t" diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h new file mode 100644 index 0000000000..8e581edc88 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h @@ -0,0 +1,48 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "c42e43b28fae9a63d94e54f244dd3ee0" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct InputMultiDiEdgeQuery { + InputMultiDiEdgeQuery() = delete; + explicit InputMultiDiEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts); + + bool operator==(InputMultiDiEdgeQuery const &) const; + bool operator!=(InputMultiDiEdgeQuery const &) const; + bool operator<(InputMultiDiEdgeQuery const &) const; + bool operator>(InputMultiDiEdgeQuery const &) const; + bool operator<=(InputMultiDiEdgeQuery const &) const; + bool operator>=(InputMultiDiEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> dsts; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::InputMultiDiEdgeQuery> { + size_t operator()(::FlexFlow::InputMultiDiEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(InputMultiDiEdgeQuery const &); +std::ostream &operator<<(std::ostream &, InputMultiDiEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.h b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.h new file mode 100644 index 0000000000..9971bce463 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_INPUT_MULTI_DI_EDGE_QUERY_H + +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +InputMultiDiEdgeQuery input_multidiedge_query_all(); +InputMultiDiEdgeQuery input_multidiedge_query_none(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml new file mode 100644 index 0000000000..76f420f696 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "InputMultiDiEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "dsts" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.dtg.h new file mode 100644 index 0000000000..0e30970aee --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.dtg.h @@ -0,0 +1,126 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml +/* proj-data +{ + "generated_from": "f7a6881be7d51ba916f3740828c23d91" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OpenMultiDiEdge { + OpenMultiDiEdge() = delete; + explicit OpenMultiDiEdge(::FlexFlow::InputMultiDiEdge const &); + explicit OpenMultiDiEdge(::FlexFlow::OutputMultiDiEdge const &); + explicit OpenMultiDiEdge(::FlexFlow::MultiDiEdge const &); + template + static constexpr bool IsPartOfOpenMultiDiEdge_v = + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::InputMultiDiEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::OutputMultiDiEdge>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::MultiDiEdge>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenMultiDiEdge", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::InputMultiDiEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::OutputMultiDiEdge>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::MultiDiEdge>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenMultiDiEdge", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfOpenMultiDiEdge_v, + "OpenMultiDiEdge::has() expected one of [::FlexFlow::InputMultiDiEdge, " + "::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfOpenMultiDiEdge_v, + "OpenMultiDiEdge::get() expected one of [::FlexFlow::InputMultiDiEdge, " + "::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfOpenMultiDiEdge_v, + "OpenMultiDiEdge::get() expected one of [::FlexFlow::InputMultiDiEdge, " + "::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OpenMultiDiEdge const &) const; + bool operator!=(OpenMultiDiEdge const &) const; + bool operator<(OpenMultiDiEdge const &) const; + bool operator>(OpenMultiDiEdge const &) const; + bool operator<=(OpenMultiDiEdge const &) const; + bool operator>=(OpenMultiDiEdge const &) const; + std::variant<::FlexFlow::InputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::MultiDiEdge> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OpenMultiDiEdge> { + size_t operator()(::FlexFlow::OpenMultiDiEdge const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OpenMultiDiEdge const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::OpenMultiDiEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.h b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.h new file mode 100644 index 0000000000..8fd52417d3 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_H + +#include "utils/graph/open_multidigraph/open_multi_di_edge.dtg.h" + +namespace FlexFlow { + +bool is_input_edge(OpenMultiDiEdge const &); +bool is_output_edge(OpenMultiDiEdge const &); +bool is_standard_edge(OpenMultiDiEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml new file mode 100644 index 0000000000..e99ad43173 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OpenMultiDiEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_multidigraph/input_multi_di_edge.dtg.h", + "utils/graph/open_multidigraph/output_multi_di_edge.dtg.h", + "utils/graph/multidigraph/multi_di_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::InputMultiDiEdge" +key = "input_edge" + +[[values]] +type = "::FlexFlow::OutputMultiDiEdge" +key = "output_edge" + +[[values]] +type = "::FlexFlow::MultiDiEdge" +key = "standard_edge" diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h new file mode 100644 index 0000000000..7dd4945419 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "86fc384b53b6b27982dfe6ab8fff2d04" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OpenMultiDiEdgeQuery { + OpenMultiDiEdgeQuery() = delete; + explicit OpenMultiDiEdgeQuery( + ::FlexFlow::InputMultiDiEdgeQuery const &input_edge_query, + ::FlexFlow::MultiDiEdgeQuery const &standard_edge_query, + ::FlexFlow::OutputMultiDiEdgeQuery const &output_edge_query); + + bool operator==(OpenMultiDiEdgeQuery const &) const; + bool operator!=(OpenMultiDiEdgeQuery const &) const; + bool operator<(OpenMultiDiEdgeQuery const &) const; + bool operator>(OpenMultiDiEdgeQuery const &) const; + bool operator<=(OpenMultiDiEdgeQuery const &) const; + bool operator>=(OpenMultiDiEdgeQuery const &) const; + ::FlexFlow::InputMultiDiEdgeQuery input_edge_query; + ::FlexFlow::MultiDiEdgeQuery standard_edge_query; + ::FlexFlow::OutputMultiDiEdgeQuery output_edge_query; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::OpenMultiDiEdgeQuery> { + size_t operator()(::FlexFlow::OpenMultiDiEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OpenMultiDiEdgeQuery const &); +std::ostream &operator<<(std::ostream &, OpenMultiDiEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.h b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.h new file mode 100644 index 0000000000..e0a194b991 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTI_DI_EDGE_QUERY_H + +#include "utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +OpenMultiDiEdgeQuery open_multidiedge_query_all(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml new file mode 100644 index 0000000000..ddc9e062e0 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OpenMultiDiEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h", + "utils/graph/multidigraph/multi_di_edge_query.dtg.h", + "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h", +] + +[[fields]] +name = "input_edge_query" +type = "::FlexFlow::InputMultiDiEdgeQuery" + +[[fields]] +name = "standard_edge_query" +type = "::FlexFlow::MultiDiEdgeQuery" + +[[fields]] +name = "output_edge_query" +type = "::FlexFlow::OutputMultiDiEdgeQuery" diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph.h b/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph.h new file mode 100644 index 0000000000..d482a7149b --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTIDIGRAPH_H + +#include "utils/graph/open_multidigraph/open_multidigraph_view.h" +#include "utils/graph/open_multidigraph/i_open_multidigraph.h" + +namespace FlexFlow { + +struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { +public: + using Edge = OpenMultiDiEdge; + using EdgeQuery = OpenMultiDiEdgeQuery; + + OpenMultiDiGraph() = delete; + OpenMultiDiGraph(OpenMultiDiGraph const &) = default; + + Node add_node(); + void add_node_unsafe(Node const &); + void remove_node_unsafe(Node const &); + + void add_edge(Edge const &); + void remove_edge(Edge const &); + + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if::value, + OpenMultiDiGraph>::type + create() { + return OpenMultiDiGraph(make_cow_ptr()); + } + +private: + using OpenMultiDiGraphView::OpenMultiDiGraphView; + + IOpenMultiDiGraph const &get_ptr() const; + IOpenMultiDiGraph &get_ptr(); + + friend struct GraphInternal; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OpenMultiDiGraph); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph_view.h b/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph_view.h new file mode 100644 index 0000000000..b1d587b644 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph_view.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OPEN_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/open_multidigraph/i_open_multidigraph_view.h" +#include "utils/graph/open_multidigraph/open_multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +struct OpenMultiDiGraphView : virtual MultiDiGraphView { +public: + using Edge = OpenMultiDiEdge; + using EdgeQuery = OpenMultiDiEdgeQuery; + + OpenMultiDiGraphView(OpenMultiDiGraphView const &) = default; + OpenMultiDiGraphView &operator=(OpenMultiDiGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static + typename std::enable_if::value, + OpenMultiDiGraphView>::type + create(Args &&...args) { + return OpenMultiDiGraphView(make_cow_ptr(std::forward(args)...)); + } + +protected: + using MultiDiGraphView::MultiDiGraphView; + +private: + IOpenMultiDiGraphView const &get_ptr() const; + + friend struct GraphInternal; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.dtg.h new file mode 100644 index 0000000000..4c12126e67 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "2ec351b641e8ecfd79fd7df2ec13dbd4" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct OutputMultiDiEdge { + OutputMultiDiEdge() = delete; + explicit OutputMultiDiEdge(::FlexFlow::Node const &src, + size_t const &raw_uid); + + bool operator==(OutputMultiDiEdge const &) const; + bool operator!=(OutputMultiDiEdge const &) const; + bool operator<(OutputMultiDiEdge const &) const; + bool operator>(OutputMultiDiEdge const &) const; + bool operator<=(OutputMultiDiEdge const &) const; + bool operator>=(OutputMultiDiEdge const &) const; + ::FlexFlow::Node src; + size_t raw_uid; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::OutputMultiDiEdge> { + size_t operator()(::FlexFlow::OutputMultiDiEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputMultiDiEdge const &); +std::ostream &operator<<(std::ostream &, OutputMultiDiEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.h b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.h new file mode 100644 index 0000000000..c6100cfb28 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_H + +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge.dtg.h" + +namespace FlexFlow { + +OutputMultiDiEdge output_multidiedge_from_multidiedge(MultiDiEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml new file mode 100644 index 0000000000..4671b016ba --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OutputMultiDiEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +[[fields]] +name = "src" +type = "::FlexFlow::Node" + +[[fields]] +name = "raw_uid" +type = "size_t" diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h new file mode 100644 index 0000000000..6e2acd81f7 --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h @@ -0,0 +1,48 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "4833874bcc5268ec7a7f8fe92186ba17" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct OutputMultiDiEdgeQuery { + OutputMultiDiEdgeQuery() = delete; + explicit OutputMultiDiEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &srcs); + + bool operator==(OutputMultiDiEdgeQuery const &) const; + bool operator!=(OutputMultiDiEdgeQuery const &) const; + bool operator<(OutputMultiDiEdgeQuery const &) const; + bool operator>(OutputMultiDiEdgeQuery const &) const; + bool operator<=(OutputMultiDiEdgeQuery const &) const; + bool operator>=(OutputMultiDiEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> srcs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::OutputMultiDiEdgeQuery> { + size_t operator()(::FlexFlow::OutputMultiDiEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputMultiDiEdgeQuery const &); +std::ostream &operator<<(std::ostream &, OutputMultiDiEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.h b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.h new file mode 100644 index 0000000000..2262f679ba --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_MULTIDIGRAPH_OUTPUT_MULTI_DI_EDGE_QUERY_H + +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +OutputMultiDiEdgeQuery output_multidiedge_query_all(); +OutputMultiDiEdgeQuery output_multidiedge_query_none(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml new file mode 100644 index 0000000000..c6f4c7160b --- /dev/null +++ b/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "OutputMultiDiEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index dda06e997f..ff65533d2a 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -46,6 +46,10 @@ struct query_set { return {std::nullopt}; } + static query_set match_none() { + return {std::unordered_set{}}; + } + private: std::optional> query; }; diff --git a/lib/utils/include/utils/graph/traversal.h b/lib/utils/include/utils/graph/traversal.h index 3c3992cd53..44ddc39eb8 100644 --- a/lib/utils/include/utils/graph/traversal.h +++ b/lib/utils/include/utils/graph/traversal.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_TRAVERSAL_H #define _FLEXFLOW_UTILS_GRAPH_TRAVERSAL_H -#include "digraph.h" -#include "node.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/node.dtg.h" #include #include #include diff --git a/lib/utils/include/utils/graph/undirected.h b/lib/utils/include/utils/graph/undirected.h deleted file mode 100644 index d604016c31..0000000000 --- a/lib/utils/include/utils/graph/undirected.h +++ /dev/null @@ -1,113 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_UNDIRECTED_H -#define _FLEXFLOW_UTILS_GRAPH_UNDIRECTED_H - -#include "cow_ptr_t.h" -#include "node.h" -#include "undirected_edge.h" -#include "utils/exception.h" -#include "utils/optional.h" -#include "utils/type_traits.h" -#include "utils/unique.h" -#include - -namespace FlexFlow { - -struct IUndirectedGraphView : public IGraphView { - using Edge = UndirectedEdge; - using EdgeQuery = UndirectedEdgeQuery; - - IUndirectedGraphView(IUndirectedGraphView const &) = delete; - IUndirectedGraphView &operator=(IUndirectedGraphView const &) = delete; - - virtual std::unordered_set - query_edges(UndirectedEdgeQuery const &) const = 0; - virtual ~IUndirectedGraphView() = default; - - IUndirectedGraphView *clone() const override = 0; - -protected: - IUndirectedGraphView() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUndirectedGraphView); - -struct UndirectedGraphView : virtual GraphView { -public: - using Edge = UndirectedEdge; - using EdgeQuery = UndirectedEdgeQuery; - - UndirectedGraphView() = delete; - UndirectedGraphView(UndirectedGraphView const &) = default; - UndirectedGraphView &operator=(UndirectedGraphView const &) = default; - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &query) const; - - template - static - typename std::enable_if::value, - UndirectedGraphView>::type - create(Args &&...args) { - return UndirectedGraphView(make_cow_ptr(std::forward(args)...)); - } - - using GraphView::GraphView; - - friend struct GraphInternal; - -private: - IUndirectedGraphView const &get_ptr() const; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraphView); - -struct IUndirectedGraph : public IUndirectedGraphView { - virtual Node add_node() = 0; - virtual void add_node_unsafe(Node const &) = 0; - virtual void remove_node_unsafe(Node const &) = 0; - virtual void add_edge(UndirectedEdge const &) = 0; - virtual void remove_edge(UndirectedEdge const &) = 0; - - virtual std::unordered_set - query_nodes(NodeQuery const &query) const = 0; - - virtual IUndirectedGraph *clone() const override = 0; -}; - -struct UndirectedGraph : virtual UndirectedGraphView { -public: - using Edge = UndirectedEdge; - using EdgeQuery = UndirectedEdgeQuery; - - UndirectedGraph() = delete; - UndirectedGraph(UndirectedGraph const &) = default; - UndirectedGraph &operator=(UndirectedGraph const &) = default; - - Node add_node(); - void add_node_unsafe(Node const &); - void remove_node_unsafe(Node const &); - - void add_edge(Edge const &); - void remove_edge(Edge const &); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_edges(EdgeQuery const &) const; - - template - static typename std::enable_if::value, - UndirectedGraph>::type - create() { - return UndirectedGraph(make_cow_ptr()); - } - - using UndirectedGraphView::UndirectedGraphView; - - friend struct GraphInternal; - -private: - IUndirectedGraph const &get_ptr() const; - IUndirectedGraph &get_ptr(); -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/hashmap_undirected_graph.h b/lib/utils/include/utils/graph/undirected/hashmap_undirected_graph.h similarity index 96% rename from lib/utils/include/utils/graph/hashmap_undirected_graph.h rename to lib/utils/include/utils/graph/undirected/hashmap_undirected_graph.h index 5d2653bcae..8630277fe8 100644 --- a/lib/utils/include/utils/graph/hashmap_undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/hashmap_undirected_graph.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_HASHMAP_UNDIRECTED_GRAPH_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_HASHMAP_UNDIRECTED_GRAPH_H -#include "utils/graph/undirected.h" +#include "utils/graph/undirected/i_undirected_graph.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h new file mode 100644 index 0000000000..1662ec6d8c --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_UNDIRECTED_GRAPH_H + +#include "utils/graph/undirected/i_undirected_graph_view.h" + +namespace FlexFlow { + +struct IUndirectedGraph : public IUndirectedGraphView { + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &) = 0; + virtual void remove_node_unsafe(Node const &) = 0; + virtual void add_edge(UndirectedEdge const &) = 0; + virtual void remove_edge(UndirectedEdge const &) = 0; + + virtual std::unordered_set + query_nodes(NodeQuery const &query) const = 0; + + virtual IUndirectedGraph *clone() const override = 0; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/i_undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/i_undirected_graph_view.h new file mode 100644 index 0000000000..2ffe061dbe --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/i_undirected_graph_view.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_UNDIRECTED_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_UNDIRECTED_GRAPH_VIEW_H + +#include "utils/graph/node/i_graph_view.h" +#include "utils/graph/undirected/undirected_edge.h" +#include "utils/graph/undirected/undirected_edge_query.dtg.h" + +namespace FlexFlow { + +struct IUndirectedGraphView : public IGraphView { + using Edge = UndirectedEdge; + using EdgeQuery = UndirectedEdgeQuery; + + IUndirectedGraphView(IUndirectedGraphView const &) = delete; + IUndirectedGraphView &operator=(IUndirectedGraphView const &) = delete; + + virtual std::unordered_set + query_edges(UndirectedEdgeQuery const &) const = 0; + virtual ~IUndirectedGraphView() = default; + + IUndirectedGraphView *clone() const override = 0; + +protected: + IUndirectedGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUndirectedGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h new file mode 100644 index 0000000000..71c760daf3 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H + +#include "utils/graph/node/node.dtg.h" +namespace FlexFlow { + +struct UndirectedEdge { +public: + UndirectedEdge() = delete; + UndirectedEdge(Node const &src, Node const &dst); + + bool operator==(UndirectedEdge const &) const; + bool operator!=(UndirectedEdge const &) const; + bool operator<(UndirectedEdge const &) const; +public: + Node smaller; + Node bigger; +}; + +bool is_connected_to(UndirectedEdge const &, Node const &); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::UndirectedEdge> { + size_t operator()(::FlexFlow::UndirectedEdge const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.h b/lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.h new file mode 100644 index 0000000000..32fbbb4c5c --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.dtg.h @@ -0,0 +1,48 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml +/* proj-data +{ + "generated_from": "10df85f620b0fb6e70496d6585be6b43" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct UndirectedEdgeQuery { + UndirectedEdgeQuery() = delete; + explicit UndirectedEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &nodes); + + bool operator==(UndirectedEdgeQuery const &) const; + bool operator!=(UndirectedEdgeQuery const &) const; + bool operator<(UndirectedEdgeQuery const &) const; + bool operator>(UndirectedEdgeQuery const &) const; + bool operator<=(UndirectedEdgeQuery const &) const; + bool operator>=(UndirectedEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::Node> nodes; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::UndirectedEdgeQuery> { + size_t operator()(::FlexFlow::UndirectedEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(UndirectedEdgeQuery const &); +std::ostream &operator<<(std::ostream &, UndirectedEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h new file mode 100644 index 0000000000..9aa0f189ec --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H + +#include "utils/graph/undirected/undirected_edge_query.dtg.h" + +namespace FlexFlow { + +UndirectedEdgeQuery undirected_edge_query_all(); + +UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &, + UndirectedEdgeQuery const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml new file mode 100644 index 0000000000..239194a275 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "UndirectedEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h new file mode 100644 index 0000000000..d2830fceff --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_GRAPH_H + +#include "utils/graph/undirected/undirected_graph_view.h" +#include "utils/graph/undirected/i_undirected_graph.h" + +namespace FlexFlow { + +struct UndirectedGraph : virtual UndirectedGraphView { +public: + using Edge = UndirectedEdge; + using EdgeQuery = UndirectedEdgeQuery; + + UndirectedGraph() = delete; + UndirectedGraph(UndirectedGraph const &) = default; + UndirectedGraph &operator=(UndirectedGraph const &) = default; + + Node add_node(); + void add_node_unsafe(Node const &); + void remove_node_unsafe(Node const &); + + void add_edge(Edge const &); + void remove_edge(Edge const &); + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if::value, + UndirectedGraph>::type + create() { + return UndirectedGraph(make_cow_ptr()); + } + + using UndirectedGraphView::UndirectedGraphView; + + friend struct GraphInternal; + +private: + IUndirectedGraph const &get_ptr() const; + IUndirectedGraph &get_ptr(); +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h new file mode 100644 index 0000000000..7f9bcfb953 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_GRAPH_VIEW_H + +#include "utils/graph/node/graph_view.h" +#include "utils/graph/undirected/undirected_edge.h" +#include "utils/graph/undirected/i_undirected_graph_view.h" + +namespace FlexFlow { + +struct UndirectedGraphView : virtual GraphView { +public: + using Edge = UndirectedEdge; + using EdgeQuery = UndirectedEdgeQuery; + + UndirectedGraphView() = delete; + UndirectedGraphView(UndirectedGraphView const &) = default; + UndirectedGraphView &operator=(UndirectedGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &query) const; + + template + static + typename std::enable_if::value, + UndirectedGraphView>::type + create(Args &&...args) { + return UndirectedGraphView(make_cow_ptr(std::forward(args)...)); + } + + using GraphView::GraphView; + + friend struct GraphInternal; + +private: + IUndirectedGraphView const &get_ptr() const; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected_edge.h b/lib/utils/include/utils/graph/undirected_edge.h deleted file mode 100644 index 98252c315a..0000000000 --- a/lib/utils/include/utils/graph/undirected_edge.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_UNDIRECTED_EDGE -#define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_UNDIRECTED_EDGE - -#include "node.h" - -namespace FlexFlow { - -struct UndirectedEdge { -public: - UndirectedEdge() = delete; - UndirectedEdge(Node const &src, Node const &dst); - -public: - Node smaller; - Node bigger; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(UndirectedEdge, smaller, bigger); -FF_VISIT_FMTABLE(UndirectedEdge); - -bool is_connected_to(UndirectedEdge const &, Node const &); - -struct UndirectedEdgeQuery { - query_set nodes; - - static UndirectedEdgeQuery all(); -}; -FF_VISITABLE_STRUCT(UndirectedEdgeQuery, nodes); -FF_VISIT_FMTABLE(UndirectedEdgeQuery); - -UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &, - UndirectedEdgeQuery const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h b/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h new file mode 100644 index 0000000000..de0ffae9d3 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_I_UPWARD_OPEN_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_I_UPWARD_OPEN_MULTIDIGRAPH_H + +#include "utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h" + +namespace FlexFlow { + +struct IUpwardOpenMultiDiGraph : virtual public IUpwardOpenMultiDiGraphView { + virtual Node add_node() = 0; + virtual void add_node_unsafe(Node const &node) = 0; + virtual void remove_node_unsafe(Node const &node) = 0; + virtual void add_edge(UpwardOpenMultiDiEdge const &) = 0; + virtual void remove_edge(UpwardOpenMultiDiEdge const &) = 0; + virtual IUpwardOpenMultiDiGraph *clone() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUpwardOpenMultiDiGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h b/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h new file mode 100644 index 0000000000..d3daaf4c31 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_I_UPWARD_OPEN_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_I_UPWARD_OPEN_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/open_multidigraph/i_open_multidigraph_view.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +struct IUpwardOpenMultiDiGraphView : virtual public IOpenMultiDiGraphView { + virtual std::unordered_set + query_edges(UpwardOpenMultiDiEdgeQuery const &) const = 0; + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const final; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUpwardOpenMultiDiGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h new file mode 100644 index 0000000000..4bbc752eb3 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h @@ -0,0 +1,114 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml +/* proj-data +{ + "generated_from": "fbb0c70b77edf2b92ceb84523c67c2ad" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct UpwardOpenMultiDiEdge { + UpwardOpenMultiDiEdge() = delete; + explicit UpwardOpenMultiDiEdge(::FlexFlow::InputMultiDiEdge const &); + explicit UpwardOpenMultiDiEdge(::FlexFlow::MultiDiEdge const &); + template + static constexpr bool IsPartOfUpwardOpenMultiDiEdge_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::InputMultiDiEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::MultiDiEdge>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type UpwardOpenMultiDiEdge", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::InputMultiDiEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::MultiDiEdge>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type UpwardOpenMultiDiEdge", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfUpwardOpenMultiDiEdge_v, + "UpwardOpenMultiDiEdge::has() expected one of " + "[::FlexFlow::InputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfUpwardOpenMultiDiEdge_v, + "UpwardOpenMultiDiEdge::get() expected one of " + "[::FlexFlow::InputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfUpwardOpenMultiDiEdge_v, + "UpwardOpenMultiDiEdge::get() expected one of " + "[::FlexFlow::InputMultiDiEdge, ::FlexFlow::MultiDiEdge], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(UpwardOpenMultiDiEdge const &) const; + bool operator!=(UpwardOpenMultiDiEdge const &) const; + bool operator<(UpwardOpenMultiDiEdge const &) const; + bool operator>(UpwardOpenMultiDiEdge const &) const; + bool operator<=(UpwardOpenMultiDiEdge const &) const; + bool operator>=(UpwardOpenMultiDiEdge const &) const; + std::variant<::FlexFlow::InputMultiDiEdge, ::FlexFlow::MultiDiEdge> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::UpwardOpenMultiDiEdge> { + size_t operator()(::FlexFlow::UpwardOpenMultiDiEdge const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::UpwardOpenMultiDiEdge const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::UpwardOpenMultiDiEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h new file mode 100644 index 0000000000..2c2b6c5583 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_H + +#include "utils/graph/open_multidigraph/open_multi_di_edge.dtg.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h" + +namespace FlexFlow { + +OpenMultiDiEdge open_multidiedge_from_upward_open(UpwardOpenMultiDiEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml new file mode 100644 index 0000000000..20ca3ad196 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "UpwardOpenMultiDiEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_multidigraph/input_multi_di_edge.dtg.h", + "utils/graph/multidigraph/multi_di_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::InputMultiDiEdge" +key = "input_edge" + +[[values]] +type = "::FlexFlow::MultiDiEdge" +key = "standard_edge" diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h new file mode 100644 index 0000000000..7159399e91 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h @@ -0,0 +1,50 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "45db44200f5b0ff7d80004f783ce1464" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct UpwardOpenMultiDiEdgeQuery { + UpwardOpenMultiDiEdgeQuery() = delete; + explicit UpwardOpenMultiDiEdgeQuery( + ::FlexFlow::InputMultiDiEdgeQuery const &input_edge_query, + ::FlexFlow::MultiDiEdgeQuery const &standard_edge_query); + + bool operator==(UpwardOpenMultiDiEdgeQuery const &) const; + bool operator!=(UpwardOpenMultiDiEdgeQuery const &) const; + bool operator<(UpwardOpenMultiDiEdgeQuery const &) const; + bool operator>(UpwardOpenMultiDiEdgeQuery const &) const; + bool operator<=(UpwardOpenMultiDiEdgeQuery const &) const; + bool operator>=(UpwardOpenMultiDiEdgeQuery const &) const; + ::FlexFlow::InputMultiDiEdgeQuery input_edge_query; + ::FlexFlow::MultiDiEdgeQuery standard_edge_query; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::UpwardOpenMultiDiEdgeQuery> { + size_t operator()(::FlexFlow::UpwardOpenMultiDiEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(UpwardOpenMultiDiEdgeQuery const &); +std::ostream &operator<<(std::ostream &, UpwardOpenMultiDiEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTI_DI_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml new file mode 100644 index 0000000000..ea1cfcb617 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "UpwardOpenMultiDiEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h", + "utils/graph/multidigraph/multi_di_edge_query.dtg.h", +] + +[[fields]] +name = "input_edge_query" +type = "::FlexFlow::InputMultiDiEdgeQuery" + +[[fields]] +name = "standard_edge_query" +type = "::FlexFlow::MultiDiEdgeQuery" diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph.h b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph.h new file mode 100644 index 0000000000..33e0abc028 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTIDIGRAPH_H + +#include "utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h" + +namespace FlexFlow { + +struct UpwardOpenMultiDiGraph : virtual UpwardOpenMultiDiGraphView { +public: + using Edge = UpwardOpenMultiDiEdge; + using EdgeQuery = UpwardOpenMultiDiEdgeQuery; + + UpwardOpenMultiDiGraph() = delete; + UpwardOpenMultiDiGraph(UpwardOpenMultiDiGraph const &) = default; + UpwardOpenMultiDiGraph &operator=(UpwardOpenMultiDiGraph const &) = default; + + Node add_node(); + void add_node_unsafe(Node const &); + void remove_node_unsafe(Node const &); + + void add_edge(Edge const &); + void remove_edge(Edge const &); + + std::unordered_set query_edges(EdgeQuery const &) const; + + template + static typename std::enable_if< + std::is_base_of::value, + UpwardOpenMultiDiGraph>::type + create() { + return UpwardOpenMultiDiGraph(make_cow_ptr()); + } + +private: + using UpwardOpenMultiDiGraphView::UpwardOpenMultiDiGraphView; + + IUpwardOpenMultiDiGraph const &get_ptr() const; + IUpwardOpenMultiDiGraph &get_ptr(); +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraph); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h new file mode 100644 index 0000000000..1fc9907de2 --- /dev/null +++ b/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTIDIGRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UPWARD_OPEN_MULTIDIGRAPH_UPWARD_OPEN_MULTIDIGRAPH_VIEW_H + +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h" + +namespace FlexFlow { + +struct UpwardOpenMultiDiGraphView : virtual MultiDiGraphView { +public: + using Edge = UpwardOpenMultiDiEdge; + using EdgeQuery = UpwardOpenMultiDiEdgeQuery; + + UpwardOpenMultiDiGraphView(UpwardOpenMultiDiGraphView const &) = default; + UpwardOpenMultiDiGraphView & + operator=(UpwardOpenMultiDiGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &); + std::unordered_set query_edges(EdgeQuery const &); + + template + static typename std::enable_if< + std::is_base_of::value, + UpwardOpenMultiDiGraphView>::type + create(Args &&...args) { + return UpwardOpenMultiDiGraphView( + cow_ptr_t(std::forward(args)...)); + } + +private: + using MultiDiGraphView::MultiDiGraphView; + + IUpwardOpenMultiDiGraphView const &get_ptr() const; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraphView); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/views/join_node_key.dtg.h b/lib/utils/include/utils/graph/views/join_node_key.dtg.h new file mode 100644 index 0000000000..b522fc29e9 --- /dev/null +++ b/lib/utils/include/utils/graph/views/join_node_key.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/views/join_node_key.struct.toml +/* proj-data +{ + "generated_from": "d18ad1216e748a6af1a1a132f18a2284" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_JOIN_NODE_KEY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_JOIN_NODE_KEY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/views/lr_direction.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct JoinNodeKey { + JoinNodeKey() = delete; + explicit JoinNodeKey(::FlexFlow::Node const &node, + ::FlexFlow::LRDirection const &direction); + + bool operator==(JoinNodeKey const &) const; + bool operator!=(JoinNodeKey const &) const; + bool operator<(JoinNodeKey const &) const; + bool operator>(JoinNodeKey const &) const; + bool operator<=(JoinNodeKey const &) const; + bool operator>=(JoinNodeKey const &) const; + ::FlexFlow::Node node; + ::FlexFlow::LRDirection direction; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::JoinNodeKey> { + size_t operator()(::FlexFlow::JoinNodeKey const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(JoinNodeKey const &); +std::ostream &operator<<(std::ostream &, JoinNodeKey const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_JOIN_NODE_KEY_DTG_H diff --git a/lib/utils/include/utils/graph/views/join_node_key.struct.toml b/lib/utils/include/utils/graph/views/join_node_key.struct.toml new file mode 100644 index 0000000000..9dce99f0a0 --- /dev/null +++ b/lib/utils/include/utils/graph/views/join_node_key.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "JoinNodeKey" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/views/lr_direction.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "direction" +type = "::FlexFlow::LRDirection" diff --git a/lib/utils/include/utils/graph/views/lr_direction.dtg.h b/lib/utils/include/utils/graph/views/lr_direction.dtg.h new file mode 100644 index 0000000000..452e646788 --- /dev/null +++ b/lib/utils/include/utils/graph/views/lr_direction.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/views/lr_direction.enum.toml +/* proj-data +{ + "generated_from": "0fef027ec69f92967f3171795ae9ddd2" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_LR_DIRECTION_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_LR_DIRECTION_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class LRDirection { LEFT, RIGHT }; +std::string format_as(LRDirection); +std::ostream &operator<<(std::ostream &, LRDirection); +void to_json(::nlohmann::json &, LRDirection); +void from_json(::nlohmann::json const &, LRDirection &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LRDirection) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_LR_DIRECTION_DTG_H diff --git a/lib/utils/include/utils/graph/views/lr_direction.enum.toml b/lib/utils/include/utils/graph/views/lr_direction.enum.toml new file mode 100644 index 0000000000..878a937b0b --- /dev/null +++ b/lib/utils/include/utils/graph/views/lr_direction.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "LRDirection" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT" + +[[values]] +name = "RIGHT" diff --git a/lib/utils/include/utils/graph/views.h b/lib/utils/include/utils/graph/views/views.h similarity index 95% rename from lib/utils/include/utils/graph/views.h rename to lib/utils/include/utils/graph/views/views.h index a0ef837796..8330ef51bf 100644 --- a/lib/utils/include/utils/graph/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -1,17 +1,12 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_VIEWS_H -#define _FLEXFLOW_UTILS_GRAPH_VIEWS_H - -#include "adjacency_digraph.h" -#include "digraph.h" -#include "labelled_graphs.h" -#include "multidigraph.h" -#include "open_graphs.h" -#include "undirected.h" -#include "utils/bidict.h" -#include "utils/graph/digraph_interfaces.h" -#include "utils/visitable.h" -#include -#include +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_VIEWS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_VIEWS_H + +#include "utils/graph/digraph/adjacency_digraph.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/undirected/undirected_graph_view.h" +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/open_multidigraph/open_multidigraph_view.h" +#include "utils/graph/views/join_node_key.dtg.h" namespace FlexFlow { @@ -90,14 +85,6 @@ struct NodeSource { std::size_t next_node_idx = 0; }; -enum class LRDirection { LEFT, RIGHT }; - -struct JoinNodeKey { - Node node; - req direction; -}; -FF_VISITABLE_STRUCT(JoinNodeKey, node, direction); - struct JoinedNodeView { public: JoinedNodeView() = delete; diff --git a/lib/utils/src/graph/multidiedge.cc b/lib/utils/src/graph/multidiedge.cc deleted file mode 100644 index 47d9a47023..0000000000 --- a/lib/utils/src/graph/multidiedge.cc +++ /dev/null @@ -1,155 +0,0 @@ -#include "utils/graph/multidiedge.h" - -namespace FlexFlow { - -OutputMultiDiEdgeQuery OutputMultiDiEdgeQuery::all() { - return {matchall(), matchall()}; -} - -OutputMultiDiEdgeQuery OutputMultiDiEdgeQuery::none() { - return {query_set({}), query_set({})}; -} - -InputMultiDiEdgeQuery InputMultiDiEdgeQuery::all() { - return {matchall(), matchall()}; -} - -InputMultiDiEdgeQuery InputMultiDiEdgeQuery::none() { - return {query_set({}), query_set({})}; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_src_nodes(query_set const &nodes) const { - MultiDiEdgeQuery e = *this; - // if (!is_matchall(e.srcs)) { - // throw mk_runtime_error("Expected matchall previous value"); - // } - // e.srcs = nodes; - e.srcs = query_intersection(nodes, e.srcs); - return e; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_dst_nodes(query_set const &nodes) const { - MultiDiEdgeQuery e = *this; - // if (!is_matchall(e.dsts)) { - // throw mk_runtime_error("Expected matchall previous value"); - // } - // e.dsts = nodes; - e.dsts = query_intersection(nodes, e.dsts); - return e; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_src_idxs(query_set const &idxs) const { - MultiDiEdgeQuery e{*this}; - // if (!is_matchall(e.srcIdxs)) { - // throw mk_runtime_error("Expected matchall previous value"); - // } - // e.srcIdxs = idxs; - e.srcIdxs = query_intersection(idxs, e.srcIdxs); - return e; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_dst_idxs(query_set const &idxs) const { - MultiDiEdgeQuery e = *this; - // if (!is_matchall(e.dstIdxs)) { - // throw mk_runtime_error("Expected matchall previous value"); - // } - // e.dstIdxs = idxs; - e.dstIdxs = query_intersection(idxs, e.dstIdxs); - return e; -} - -MultiDiEdgeQuery MultiDiEdgeQuery::all() { - return {matchall(), - matchall(), - matchall(), - matchall()}; -} - -MultiDiEdgeQuery MultiDiEdgeQuery::none() { - return {query_set({}), - query_set({}), - query_set({}), - query_set({})}; -} - -MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs, - MultiDiEdgeQuery const &rhs) { - std::unordered_set srcs_t1; - if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { - srcs_t1 = allowed_values(rhs.srcs); - } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { - srcs_t1 = allowed_values(lhs.srcs); - } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { - srcs_t1 = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); - } - - std::unordered_set dsts_t1; - if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { - dsts_t1 = allowed_values(rhs.dsts); - } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { - dsts_t1 = allowed_values(lhs.dsts); - } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { - dsts_t1 = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); - } - - std::unordered_set srcIdxs_t1; - if (is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { - srcIdxs_t1 = allowed_values(rhs.srcIdxs); - } else if (!is_matchall(lhs.srcIdxs) && is_matchall(rhs.srcIdxs)) { - srcIdxs_t1 = allowed_values(lhs.srcIdxs); - } else if (!is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { - srcIdxs_t1 = allowed_values(query_intersection(lhs.srcIdxs, rhs.srcIdxs)); - } - - std::unordered_set dstIdxs_t1; - if (is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { - dstIdxs_t1 = allowed_values(rhs.dstIdxs); - } else if (!is_matchall(lhs.dstIdxs) && is_matchall(rhs.dstIdxs)) { - dstIdxs_t1 = allowed_values(lhs.dstIdxs); - } else if (!is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { - dstIdxs_t1 = allowed_values(query_intersection(lhs.dstIdxs, rhs.dstIdxs)); - } - - MultiDiEdgeQuery e = MultiDiEdgeQuery::all(); - e.srcs = srcs_t1; - e.dsts = dsts_t1; - e.srcIdxs = srcIdxs_t1; - e.dstIdxs = dstIdxs_t1; - return e; -} - -OutputMultiDiEdgeQuery - OutputMultiDiEdgeQuery::with_src_nodes(query_set const &nodes) const { - OutputMultiDiEdgeQuery e = *this; - // if (!is_matchall(e.srcs)) { - // throw mk_runtime_error("Expected matchall previous value"); - // } - // e.srcs = nodes; - e.srcs = query_intersection(nodes, e.srcs); - return e; -} - -InputMultiDiEdgeQuery - InputMultiDiEdgeQuery::with_dst_nodes(query_set const &nodes) const { - InputMultiDiEdgeQuery e = *this; - // if (!is_matchall(e.dsts)) { - // throw mk_runtime_error("Expected matchall previous value"); - // } - // e.dsts = nodes; - e.dsts = query_intersection(nodes, e.dsts); - return e; -} - -InputMultiDiEdge to_inputmultidiedge(MultiDiEdge const &e) { - return InputMultiDiEdge{e.dst, e.dst_idx, e.get_uid()}; -} - -OutputMultiDiEdge to_outputmultidiedge(MultiDiEdge const &e) { - return OutputMultiDiEdge{e.src, e.src_idx, e.get_uid()}; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc deleted file mode 100644 index 72caa3136e..0000000000 --- a/lib/utils/src/graph/node.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include "utils/graph/node.h" -#include "utils/graph/cow_ptr_t.h" -#include - -namespace FlexFlow { - -NodeQuery NodeQuery::all() { - return {matchall()}; -} - -NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) { - - std::unordered_set nodes; - - if (is_matchall(lhs.nodes) && !is_matchall(rhs.nodes)) { - nodes = allowed_values(rhs.nodes); - } else if (!is_matchall(lhs.nodes) && is_matchall(rhs.nodes)) { - nodes = allowed_values(lhs.nodes); - } else if (!is_matchall(lhs.nodes) && !is_matchall(rhs.nodes)) { - nodes = allowed_values(query_intersection(lhs.nodes, rhs.nodes)); - } - - NodeQuery intersection_result = NodeQuery::all(); - intersection_result.nodes = nodes; - - return intersection_result; -} - -std::unordered_set GraphView::query_nodes(NodeQuery const &g) const { - return this->ptr->query_nodes(g); -} - -bool is_ptr_equal(GraphView const &lhs, GraphView const &rhs) { - return lhs.ptr == rhs.ptr; -} - -GraphView::GraphView(cow_ptr_t ptr) : ptr(ptr) {} - -Node Graph::add_node() { - return get_ptr().add_node(); -} - -void Graph::add_node_unsafe(Node const &node) { - get_ptr().add_node_unsafe(node); -} - -void Graph::remove_node_unsafe(Node const &node) { - get_ptr().remove_node_unsafe(node); -} - -std::unordered_set Graph::query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); -} - -IGraph const &Graph::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); -} - -IGraph &Graph::get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/graph/open_edge.cc b/lib/utils/src/graph/open_edge.cc deleted file mode 100644 index 1b571d5c6c..0000000000 --- a/lib/utils/src/graph/open_edge.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "utils/graph/open_edge.h" - -namespace FlexFlow { - -bool is_input_edge(OpenMultiDiEdge const &e) { - return std::holds_alternative(e); -} - -bool is_output_edge(OpenMultiDiEdge const &e) { - return std::holds_alternative(e); -} - -bool is_standard_edge(OpenMultiDiEdge const &e) { - return std::holds_alternative(e); -} - -OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery const &input_edge_query, - MultiDiEdgeQuery const &standard_edge_query, - OutputMultiDiEdgeQuery const &output_edge_query) - : input_edge_query(input_edge_query), - standard_edge_query(standard_edge_query), - output_edge_query(output_edge_query) {} - -OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(MultiDiEdgeQuery const &q) - : OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), q, OutputMultiDiEdgeQuery::none()) {} -OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &q) - : OpenMultiDiEdgeQuery( - q, MultiDiEdgeQuery::none(), OutputMultiDiEdgeQuery::none()) {} -OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &q) - : OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), MultiDiEdgeQuery::none(), q) {} - -OpenMultiDiEdgeQuery OpenMultiDiEdgeQuery::all() { - return {InputMultiDiEdgeQuery::all(), - MultiDiEdgeQuery::all(), - OutputMultiDiEdgeQuery::all()}; -} - -DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery( - OutputMultiDiEdgeQuery const &output_edge_query, - MultiDiEdgeQuery const &standard_edge_query) - : output_edge_query(output_edge_query), - standard_edge_query(standard_edge_query) {} -DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery( - OutputMultiDiEdgeQuery const &output_edge_query) - : DownwardOpenMultiDiEdgeQuery(output_edge_query, - MultiDiEdgeQuery::none()) {} -DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery( - MultiDiEdgeQuery const &standard_edge_query) - : DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery::all(), - standard_edge_query){}; - -DownwardOpenMultiDiEdgeQuery::operator OpenMultiDiEdgeQuery() const { - return OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), standard_edge_query, output_edge_query); -} - -UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery const &input_edge_query, - MultiDiEdgeQuery const &standard_edge_query) - : input_edge_query(input_edge_query), - standard_edge_query(standard_edge_query) {} - -UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery const &input_edge_query) - : input_edge_query(input_edge_query), - standard_edge_query(MultiDiEdgeQuery::none()) {} -UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery( - MultiDiEdgeQuery const &standard_edge_query) - : input_edge_query(InputMultiDiEdgeQuery::none()), - standard_edge_query(standard_edge_query) {} - -} // namespace FlexFlow diff --git a/lib/utils/src/graph/undirected_edge.cc b/lib/utils/src/graph/undirected_edge.cc deleted file mode 100644 index 4bae7e3d25..0000000000 --- a/lib/utils/src/graph/undirected_edge.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "utils/graph/undirected_edge.h" - -namespace FlexFlow { - -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; -} - -UndirectedEdgeQuery UndirectedEdgeQuery::all() { - return {matchall()}; -} - -UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, - UndirectedEdgeQuery const &rhs) { - return { - query_intersection(lhs.nodes, rhs.nodes), - }; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc similarity index 100% rename from lib/utils/src/graph/algorithms.cc rename to lib/utils/src/utils/graph/algorithms.cc diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc new file mode 100644 index 0000000000..794519ffed --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_edge.struct.toml +/* proj-data +{ + "generated_from": "4728f139efc6884057f39e38f44a791b" +} +*/ + +#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" + +#include "utils/graph/dataflow_graph/dataflow_input.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include + +namespace FlexFlow { +DataflowEdge::DataflowEdge(::FlexFlow::DataflowOutput const &src, + ::FlexFlow::DataflowInput const &dst) + : src(src), dst(dst) {} +bool DataflowEdge::operator==(DataflowEdge const &other) const { + return std::tie(this->src, this->dst) == std::tie(other.src, other.dst); +} +bool DataflowEdge::operator!=(DataflowEdge const &other) const { + return std::tie(this->src, this->dst) != std::tie(other.src, other.dst); +} +bool DataflowEdge::operator<(DataflowEdge const &other) const { + return std::tie(this->src, this->dst) < std::tie(other.src, other.dst); +} +bool DataflowEdge::operator>(DataflowEdge const &other) const { + return std::tie(this->src, this->dst) > std::tie(other.src, other.dst); +} +bool DataflowEdge::operator<=(DataflowEdge const &other) const { + return std::tie(this->src, this->dst) <= std::tie(other.src, other.dst); +} +bool DataflowEdge::operator>=(DataflowEdge const &other) const { + return std::tie(this->src, this->dst) >= std::tie(other.src, other.dst); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataflowOutput>{}(x.src) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataflowInput>{}(x.dst) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc new file mode 100644 index 0000000000..65ac12003c --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc @@ -0,0 +1,100 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml +/* proj-data +{ + "generated_from": "684726a7add4aa912e194335fcfe91ab" +} +*/ + +#include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" + +#include "utils/graph/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +DataflowEdgeQuery::DataflowEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &src_nodes, + ::FlexFlow::query_set const &src_idxs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dst_nodes, + ::FlexFlow::query_set const &dst_idxs) + : src_nodes(src_nodes), src_idxs(src_idxs), dst_nodes(dst_nodes), + dst_idxs(dst_idxs) {} +bool DataflowEdgeQuery::operator==(DataflowEdgeQuery const &other) const { + return std::tie(this->src_nodes, + this->src_idxs, + this->dst_nodes, + this->dst_idxs) == + std::tie( + other.src_nodes, other.src_idxs, other.dst_nodes, other.dst_idxs); +} +bool DataflowEdgeQuery::operator!=(DataflowEdgeQuery const &other) const { + return std::tie(this->src_nodes, + this->src_idxs, + this->dst_nodes, + this->dst_idxs) != + std::tie( + other.src_nodes, other.src_idxs, other.dst_nodes, other.dst_idxs); +} +bool DataflowEdgeQuery::operator<(DataflowEdgeQuery const &other) const { + return std::tie( + this->src_nodes, this->src_idxs, this->dst_nodes, this->dst_idxs) < + std::tie( + other.src_nodes, other.src_idxs, other.dst_nodes, other.dst_idxs); +} +bool DataflowEdgeQuery::operator>(DataflowEdgeQuery const &other) const { + return std::tie( + this->src_nodes, this->src_idxs, this->dst_nodes, this->dst_idxs) > + std::tie( + other.src_nodes, other.src_idxs, other.dst_nodes, other.dst_idxs); +} +bool DataflowEdgeQuery::operator<=(DataflowEdgeQuery const &other) const { + return std::tie(this->src_nodes, + this->src_idxs, + this->dst_nodes, + this->dst_idxs) <= + std::tie( + other.src_nodes, other.src_idxs, other.dst_nodes, other.dst_idxs); +} +bool DataflowEdgeQuery::operator>=(DataflowEdgeQuery const &other) const { + return std::tie(this->src_nodes, + this->src_idxs, + this->dst_nodes, + this->dst_idxs) >= + std::tie( + other.src_nodes, other.src_idxs, other.dst_nodes, other.dst_idxs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.src_nodes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set>{}(x.src_idxs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.dst_nodes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set>{}(x.dst_idxs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc new file mode 100644 index 0000000000..f3afb4a9b1 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc @@ -0,0 +1,10 @@ +#include "utils/graph/dataflow_graph/dataflow_graph.h" + +namespace FlexFlow { + +NodeAddedResult DataflowGraph::add_node(std::vector const &inputs, + int num_outputs) { + return this->get_interface().add_node(inputs, num_outputs); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph_view.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph_view.cc new file mode 100644 index 0000000000..e088e44441 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph_view.cc @@ -0,0 +1,21 @@ +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set DataflowGraphView::query_nodes(NodeQuery const &q) const { + return this->get_interface().query_nodes(q); +} + +std::unordered_set DataflowGraphView::query_edges(DataflowEdgeQuery const &q) const { + return this->get_interface().query_edges(q); +} + +std::unordered_set DataflowGraphView::query_outputs(DataflowOutputQuery const &q) const { + return this->get_interface().query_outputs(q); +} + +IDataflowGraphView const &DataflowGraphView::get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc new file mode 100644 index 0000000000..32d61f4f0b --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc @@ -0,0 +1,61 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml +/* proj-data +{ + "generated_from": "9fc7657f7fcc71fdad9e6a5040771ad7" +} +*/ + +#include "utils/graph/dataflow_graph/dataflow_input.dtg.h" + +#include + +namespace FlexFlow { +DataflowInput::DataflowInput(::FlexFlow::Node const &node, int const &idx) + : node(node), idx(idx) {} +bool DataflowInput::operator==(DataflowInput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool DataflowInput::operator!=(DataflowInput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool DataflowInput::operator<(DataflowInput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool DataflowInput::operator>(DataflowInput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool DataflowInput::operator<=(DataflowInput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool DataflowInput::operator>=(DataflowInput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowInput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowInput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowInput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc new file mode 100644 index 0000000000..8c8cf6b73a --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml +/* proj-data +{ + "generated_from": "b704f2549a69ee6bfc1c5e28df421f9c" +} +*/ + +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" + +#include "utils/graph/node.dtg.h" +#include + +namespace FlexFlow { +DataflowOutput::DataflowOutput(::FlexFlow::Node const &node, int const &idx) + : node(node), idx(idx) {} +bool DataflowOutput::operator==(DataflowOutput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool DataflowOutput::operator!=(DataflowOutput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool DataflowOutput::operator<(DataflowOutput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool DataflowOutput::operator>(DataflowOutput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool DataflowOutput::operator<=(DataflowOutput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool DataflowOutput::operator>=(DataflowOutput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowOutput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc new file mode 100644 index 0000000000..7bc200e887 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml +/* proj-data +{ + "generated_from": "6f662c3c4d285a4fd3c60713e6fc67fa" +} +*/ + +#include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" + +#include "utils/graph/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +DataflowOutputQuery::DataflowOutputQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &nodes, + ::FlexFlow::query_set const &output_idxs) + : nodes(nodes), output_idxs(output_idxs) {} +bool DataflowOutputQuery::operator==(DataflowOutputQuery const &other) const { + return std::tie(this->nodes, this->output_idxs) == + std::tie(other.nodes, other.output_idxs); +} +bool DataflowOutputQuery::operator!=(DataflowOutputQuery const &other) const { + return std::tie(this->nodes, this->output_idxs) != + std::tie(other.nodes, other.output_idxs); +} +bool DataflowOutputQuery::operator<(DataflowOutputQuery const &other) const { + return std::tie(this->nodes, this->output_idxs) < + std::tie(other.nodes, other.output_idxs); +} +bool DataflowOutputQuery::operator>(DataflowOutputQuery const &other) const { + return std::tie(this->nodes, this->output_idxs) > + std::tie(other.nodes, other.output_idxs); +} +bool DataflowOutputQuery::operator<=(DataflowOutputQuery const &other) const { + return std::tie(this->nodes, this->output_idxs) <= + std::tie(other.nodes, other.output_idxs); +} +bool DataflowOutputQuery::operator>=(DataflowOutputQuery const &other) const { + return std::tie(this->nodes, this->output_idxs) >= + std::tie(other.nodes, other.output_idxs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowOutputQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.nodes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set>{}(x.output_idxs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowOutputQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowOutputQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc b/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc new file mode 100644 index 0000000000..f29054cc2d --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc @@ -0,0 +1,20 @@ +#include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" +#include "utils/containers.h" + +namespace FlexFlow { + +std::unordered_set IDataflowGraphView::query_edges(MultiDiEdgeQuery const &q) const { + DataflowEdgeQuery dataflow_query = DataflowEdgeQuery{ + q.srcs, + matchall(), + q.dsts, + matchall(), + }; + std::unordered_set dataflow_edges = this->query_edges(dataflow_query); + + return transform(dataflow_edges, [](DataflowEdge const &e) { + return MultiDiEdge{e.src.node, e.dst.node, std::make_pair(e.src.idx, e.dst.idx)}; + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc new file mode 100644 index 0000000000..dcbe3578f2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml +/* proj-data +{ + "generated_from": "4536bb54376e2e221e0ff29347e81662" +} +*/ + +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" + +#include "utils/fmt/vector.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" +#include "utils/graph/node.dtg.h" +#include +#include + +namespace FlexFlow { +NodeAddedResult::NodeAddedResult( + ::FlexFlow::Node const &node, + std::vector<::FlexFlow::DataflowOutput> const &outputs) + : node(node), outputs(outputs) {} +bool NodeAddedResult::operator==(NodeAddedResult const &other) const { + return std::tie(this->node, this->outputs) == + std::tie(other.node, other.outputs); +} +bool NodeAddedResult::operator!=(NodeAddedResult const &other) const { + return std::tie(this->node, this->outputs) != + std::tie(other.node, other.outputs); +} +bool NodeAddedResult::operator<(NodeAddedResult const &other) const { + return std::tie(this->node, this->outputs) < + std::tie(other.node, other.outputs); +} +bool NodeAddedResult::operator>(NodeAddedResult const &other) const { + return std::tie(this->node, this->outputs) > + std::tie(other.node, other.outputs); +} +bool NodeAddedResult::operator<=(NodeAddedResult const &other) const { + return std::tie(this->node, this->outputs) <= + std::tie(other.node, other.outputs); +} +bool NodeAddedResult::operator>=(NodeAddedResult const &other) const { + return std::tie(this->node, this->outputs) >= + std::tie(other.node, other.outputs); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(NodeAddedResult const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NodeAddedResult const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/graph/adjacency_digraph.cc b/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc similarity index 92% rename from lib/utils/src/graph/adjacency_digraph.cc rename to lib/utils/src/utils/graph/digraph/adjacency_digraph.cc index 1438edc78b..705d8f6158 100644 --- a/lib/utils/src/graph/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/adjacency_digraph.h" +#include "utils/graph/digraph/adjacency_digraph.h" #include namespace FlexFlow { @@ -12,7 +12,7 @@ Node AdjacencyDiGraph::add_node() { void AdjacencyDiGraph::add_node_unsafe(Node const &node) { adjacency[node]; - this->next_node_idx = std::max(this->next_node_idx, node.value() + 1); + this->next_node_idx = std::max(this->next_node_idx, node.raw_uid + 1); } void AdjacencyDiGraph::remove_node_unsafe(Node const &n) { diff --git a/lib/utils/src/utils/graph/digraph/di_input.dtg.cc b/lib/utils/src/utils/graph/digraph/di_input.dtg.cc new file mode 100644 index 0000000000..7b44d41e97 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/di_input.dtg.cc @@ -0,0 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/di_input.struct.toml +/* proj-data +{ + "generated_from": "19ab2e465577ae9e7add8b73c63e671f" +} +*/ + +#include "utils/graph/digraph/di_input.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include + +namespace FlexFlow { +DiInput::DiInput(::FlexFlow::Node const &dst) : dst(dst) {} +bool DiInput::operator==(DiInput const &other) const { + return std::tie(this->dst) == std::tie(other.dst); +} +bool DiInput::operator!=(DiInput const &other) const { + return std::tie(this->dst) != std::tie(other.dst); +} +bool DiInput::operator<(DiInput const &other) const { + return std::tie(this->dst) < std::tie(other.dst); +} +bool DiInput::operator>(DiInput const &other) const { + return std::tie(this->dst) > std::tie(other.dst); +} +bool DiInput::operator<=(DiInput const &other) const { + return std::tie(this->dst) <= std::tie(other.dst); +} +bool DiInput::operator>=(DiInput const &other) const { + return std::tie(this->dst) >= std::tie(other.dst); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(::FlexFlow::DiInput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.dst) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DiInput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DiInput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/di_output.dtg.cc b/lib/utils/src/utils/graph/digraph/di_output.dtg.cc new file mode 100644 index 0000000000..9723a1cd84 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/di_output.dtg.cc @@ -0,0 +1,44 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/di_output.struct.toml +/* proj-data +{ + "generated_from": "a8f3fc2ad9e00f3c29a6dcd4658199ba" +} +*/ + +#include "utils/graph/digraph/di_output.dtg.h" + +#include "utils/graph/node.dtg.h" + +namespace FlexFlow { +DiOutput::DiOutput(::FlexFlow::Node const &src) : src(src) {} +bool DiOutput::operator==(DiOutput const &other) const { + return std::tie(this->src) == std::tie(other.src); +} +bool DiOutput::operator!=(DiOutput const &other) const { + return std::tie(this->src) != std::tie(other.src); +} +bool DiOutput::operator<(DiOutput const &other) const { + return std::tie(this->src) < std::tie(other.src); +} +bool DiOutput::operator>(DiOutput const &other) const { + return std::tie(this->src) > std::tie(other.src); +} +bool DiOutput::operator<=(DiOutput const &other) const { + return std::tie(this->src) <= std::tie(other.src); +} +bool DiOutput::operator>=(DiOutput const &other) const { + return std::tie(this->src) >= std::tie(other.src); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(::FlexFlow::DiOutput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.src) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/utils/graph/digraph/digraph.cc similarity index 66% rename from lib/utils/src/graph/digraph.cc rename to lib/utils/src/utils/graph/digraph/digraph.cc index bdfe5ff599..24015dc1f3 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/utils/graph/digraph/digraph.cc @@ -1,22 +1,7 @@ -#include "utils/graph/digraph.h" -#include "utils/containers.h" -#include "utils/graph/digraph_interfaces.h" +#include "utils/graph/digraph/digraph.h" namespace FlexFlow { -std::unordered_set DiGraphView::query_nodes(NodeQuery const &q) const { - return this->get_ptr().query_nodes(q); -} - -std::unordered_set - DiGraphView::query_edges(EdgeQuery const &query) const { - return get_ptr().query_edges(query); -} - -IDiGraphView const &DiGraphView::get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); -} - Node DiGraph::add_node() { return this->get_ptr().add_node(); } @@ -54,4 +39,5 @@ IDiGraph const &DiGraph::get_ptr() const { return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/digraph_view.cc b/lib/utils/src/utils/graph/digraph/digraph_view.cc new file mode 100644 index 0000000000..fb6de481d6 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/digraph_view.cc @@ -0,0 +1,18 @@ +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set DiGraphView::query_nodes(NodeQuery const &q) const { + return this->get_ptr().query_nodes(q); +} + +std::unordered_set + DiGraphView::query_edges(EdgeQuery const &query) const { + return get_ptr().query_edges(query); +} + +IDiGraphView const &DiGraphView::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc b/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc new file mode 100644 index 0000000000..79f910de69 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/directed_edge.struct.toml +/* proj-data +{ + "generated_from": "406f818eb74797f6ea07231506a56f81" +} +*/ + +#include "utils/graph/digraph/directed_edge.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include + +namespace FlexFlow { +DirectedEdge::DirectedEdge(::FlexFlow::Node const &src, + ::FlexFlow::Node const &dst) + : src(src), dst(dst) {} +bool DirectedEdge::operator==(DirectedEdge const &other) const { + return std::tie(this->src, this->dst) == std::tie(other.src, other.dst); +} +bool DirectedEdge::operator!=(DirectedEdge const &other) const { + return std::tie(this->src, this->dst) != std::tie(other.src, other.dst); +} +bool DirectedEdge::operator<(DirectedEdge const &other) const { + return std::tie(this->src, this->dst) < std::tie(other.src, other.dst); +} +bool DirectedEdge::operator>(DirectedEdge const &other) const { + return std::tie(this->src, this->dst) > std::tie(other.src, other.dst); +} +bool DirectedEdge::operator<=(DirectedEdge const &other) const { + return std::tie(this->src, this->dst) <= std::tie(other.src, other.dst); +} +bool DirectedEdge::operator>=(DirectedEdge const &other) const { + return std::tie(this->src, this->dst) >= std::tie(other.src, other.dst); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DirectedEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.src) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::Node>{}(x.dst) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DirectedEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DirectedEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc b/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc new file mode 100644 index 0000000000..3804bd3399 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml +/* proj-data +{ + "generated_from": "294ae0103df2a3c388a2ce140c271f4e" +} +*/ + +#include "utils/graph/digraph/directed_edge_query.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +DirectedEdgeQuery::DirectedEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &srcs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts) + : srcs(srcs), dsts(dsts) {} +bool DirectedEdgeQuery::operator==(DirectedEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) == std::tie(other.srcs, other.dsts); +} +bool DirectedEdgeQuery::operator!=(DirectedEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) != std::tie(other.srcs, other.dsts); +} +bool DirectedEdgeQuery::operator<(DirectedEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) < std::tie(other.srcs, other.dsts); +} +bool DirectedEdgeQuery::operator>(DirectedEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) > std::tie(other.srcs, other.dsts); +} +bool DirectedEdgeQuery::operator<=(DirectedEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) <= std::tie(other.srcs, other.dsts); +} +bool DirectedEdgeQuery::operator>=(DirectedEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) >= std::tie(other.srcs, other.dsts); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(DirectedEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DirectedEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/graph/diedge.cc b/lib/utils/src/utils/graph/directed_graph/directed_edge_query.cc similarity index 83% rename from lib/utils/src/graph/diedge.cc rename to lib/utils/src/utils/graph/directed_graph/directed_edge_query.cc index 7a21518311..2522e6aaa1 100644 --- a/lib/utils/src/graph/diedge.cc +++ b/lib/utils/src/utils/graph/directed_graph/directed_edge_query.cc @@ -1,9 +1,9 @@ -#include "utils/graph/diedge.h" +#include "utils/graph/directed_graph/directed_edge_query.h" namespace FlexFlow { -DirectedEdgeQuery DirectedEdgeQuery::all() { - return {matchall(), matchall()}; +DirectedEdgeQuery directed_edge_query_all() { + return DirectedEdgeQuery{matchall(), matchall()}; } bool matches_edge(DirectedEdgeQuery const &q, DirectedEdge const &e) { @@ -30,7 +30,7 @@ DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, result_dsts = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); } - return {result_srcs, result_dsts}; + return DirectedEdgeQuery{result_srcs, result_dsts}; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc new file mode 100644 index 0000000000..392a57e78d --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml +/* proj-data +{ + "generated_from": "a48025d66b3bdc8eec931e33694b0a22" +} +*/ + +#include "utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h" + +#include + +namespace FlexFlow { +DownwardOpenMultiDiEdge::DownwardOpenMultiDiEdge( + ::FlexFlow::OutputMultiDiEdge const &v) + : raw_variant(v) {} +DownwardOpenMultiDiEdge::DownwardOpenMultiDiEdge( + ::FlexFlow::MultiDiEdge const &v) + : raw_variant(v) {} +bool DownwardOpenMultiDiEdge::operator==( + DownwardOpenMultiDiEdge const &other) const { + return this->raw_variant == other.raw_variant; +} +bool DownwardOpenMultiDiEdge::operator!=( + DownwardOpenMultiDiEdge const &other) const { + return this->raw_variant != other.raw_variant; +} +bool DownwardOpenMultiDiEdge::operator<( + DownwardOpenMultiDiEdge const &other) const { + return this->raw_variant < other.raw_variant; +} +bool DownwardOpenMultiDiEdge::operator>( + DownwardOpenMultiDiEdge const &other) const { + return this->raw_variant > other.raw_variant; +} +bool DownwardOpenMultiDiEdge::operator<=( + DownwardOpenMultiDiEdge const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool DownwardOpenMultiDiEdge::operator>=( + DownwardOpenMultiDiEdge const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::DownwardOpenMultiDiEdge>::operator()( + ::FlexFlow::DownwardOpenMultiDiEdge const &x) const { + return std::hash< + std::variant<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::MultiDiEdge>>{}( + x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::DownwardOpenMultiDiEdge const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DownwardOpenMultiDiEdge", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::DownwardOpenMultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc new file mode 100644 index 0000000000..a707609f39 --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "396fddca0f20f2459ee9938138d3fc40" +} +*/ + +#include "utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h" + +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h" +#include + +namespace FlexFlow { +DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery( + ::FlexFlow::OutputMultiDiEdgeQuery const &output_edge_query, + ::FlexFlow::MultiDiEdgeQuery const &standard_edge_query) + : output_edge_query(output_edge_query), + standard_edge_query(standard_edge_query) {} +bool DownwardOpenMultiDiEdgeQuery::operator==( + DownwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->output_edge_query, this->standard_edge_query) == + std::tie(other.output_edge_query, other.standard_edge_query); +} +bool DownwardOpenMultiDiEdgeQuery::operator!=( + DownwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->output_edge_query, this->standard_edge_query) != + std::tie(other.output_edge_query, other.standard_edge_query); +} +bool DownwardOpenMultiDiEdgeQuery::operator<( + DownwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->output_edge_query, this->standard_edge_query) < + std::tie(other.output_edge_query, other.standard_edge_query); +} +bool DownwardOpenMultiDiEdgeQuery::operator>( + DownwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->output_edge_query, this->standard_edge_query) > + std::tie(other.output_edge_query, other.standard_edge_query); +} +bool DownwardOpenMultiDiEdgeQuery::operator<=( + DownwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->output_edge_query, this->standard_edge_query) <= + std::tie(other.output_edge_query, other.standard_edge_query); +} +bool DownwardOpenMultiDiEdgeQuery::operator>=( + DownwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->output_edge_query, this->standard_edge_query) >= + std::tie(other.output_edge_query, other.standard_edge_query); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DownwardOpenMultiDiEdgeQuery const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::OutputMultiDiEdgeQuery>{}(x.output_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::MultiDiEdgeQuery>{}(x.standard_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DownwardOpenMultiDiEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + DownwardOpenMultiDiEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc new file mode 100644 index 0000000000..01d205309d --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc @@ -0,0 +1 @@ +#include "utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h" diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc new file mode 100644 index 0000000000..d43d6af164 --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc @@ -0,0 +1,13 @@ +#include "utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h" + +namespace FlexFlow { + +std::unordered_set + IDownwardOpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { + NOT_IMPLEMENTED(); + //return widen( + // this->query_edges(DownwardOpenMultiDiEdgeQuery{q.output_edge_query, + // q.standard_edge_query})); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/graph/labelled_graphs.cc b/lib/utils/src/utils/graph/labelled_graphs.cc similarity index 100% rename from lib/utils/src/graph/labelled_graphs.cc rename to lib/utils/src/utils/graph/labelled_graphs.cc diff --git a/lib/utils/src/graph/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc similarity index 97% rename from lib/utils/src/graph/adjacency_multidigraph.cc rename to lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc index 0d5d3a70fd..671c327e51 100644 --- a/lib/utils/src/graph/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/adjacency_multidigraph.h" #include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc b/lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc new file mode 100644 index 0000000000..1b02668db9 --- /dev/null +++ b/lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc @@ -0,0 +1 @@ +#include "utils/graph/multidigraph/i_multidigraph.h" diff --git a/lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc b/lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc new file mode 100644 index 0000000000..e6e2fed587 --- /dev/null +++ b/lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc @@ -0,0 +1,15 @@ +#include "utils/graph/multidigraph/i_multidigraph_view.h" + +namespace FlexFlow { + +std::unordered_set + IMultiDiGraphView::query_edges(DirectedEdgeQuery const &q) const { + return transform( + query_edges(MultiDiEdgeQuery{ + q.srcs, q.dsts}), + [](MultiDiEdge const &e) { + return DirectedEdge{e.src, e.dst}; + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc new file mode 100644 index 0000000000..ae9070c9dd --- /dev/null +++ b/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "73b001bfb7a0b75c42cd5037bb8dc686" +} +*/ + +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include + +namespace FlexFlow { +MultiDiEdge::MultiDiEdge(::FlexFlow::Node const &src, + ::FlexFlow::Node const &dst, + std::pair const &raw_edge_uid) + : src(src), dst(dst), raw_edge_uid(raw_edge_uid) {} +bool MultiDiEdge::operator==(MultiDiEdge const &other) const { + return std::tie(this->src, this->dst, this->raw_edge_uid) == + std::tie(other.src, other.dst, other.raw_edge_uid); +} +bool MultiDiEdge::operator!=(MultiDiEdge const &other) const { + return std::tie(this->src, this->dst, this->raw_edge_uid) != + std::tie(other.src, other.dst, other.raw_edge_uid); +} +bool MultiDiEdge::operator<(MultiDiEdge const &other) const { + return std::tie(this->src, this->dst, this->raw_edge_uid) < + std::tie(other.src, other.dst, other.raw_edge_uid); +} +bool MultiDiEdge::operator>(MultiDiEdge const &other) const { + return std::tie(this->src, this->dst, this->raw_edge_uid) > + std::tie(other.src, other.dst, other.raw_edge_uid); +} +bool MultiDiEdge::operator<=(MultiDiEdge const &other) const { + return std::tie(this->src, this->dst, this->raw_edge_uid) <= + std::tie(other.src, other.dst, other.raw_edge_uid); +} +bool MultiDiEdge::operator>=(MultiDiEdge const &other) const { + return std::tie(this->src, this->dst, this->raw_edge_uid) >= + std::tie(other.src, other.dst, other.raw_edge_uid); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::MultiDiEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.src) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::Node>{}(x.dst) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash>{}(x.raw_edge_uid) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(MultiDiEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc new file mode 100644 index 0000000000..dd4b7a3cb6 --- /dev/null +++ b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc @@ -0,0 +1,46 @@ +#include "utils/graph/multidigraph/multi_di_edge_query.h" + +namespace FlexFlow { + +MultiDiEdgeQuery multidiedge_query_all() { + return MultiDiEdgeQuery{matchall(), + matchall()}; +} + +MultiDiEdgeQuery multidiedge_query_none() { + return MultiDiEdgeQuery{query_set({}), + query_set({})}; +} + +MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs, + MultiDiEdgeQuery const &rhs) { + std::unordered_set srcs_t1; + if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(rhs.srcs); + } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(lhs.srcs); + } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); + } + + std::unordered_set dsts_t1; + if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(rhs.dsts); + } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(lhs.dsts); + } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); + } + + MultiDiEdgeQuery e = multidiedge_query_all(); + e.srcs = srcs_t1; + e.dsts = dsts_t1; + return e; +} + +MultiDiEdgeQuery query_union(MultiDiEdgeQuery const &, + MultiDiEdgeQuery const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc new file mode 100644 index 0000000000..686a6f362a --- /dev/null +++ b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "bede7a523428098275e26ba89bb30eb0" +} +*/ + +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +MultiDiEdgeQuery::MultiDiEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &srcs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts) + : srcs(srcs), dsts(dsts) {} +bool MultiDiEdgeQuery::operator==(MultiDiEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) == std::tie(other.srcs, other.dsts); +} +bool MultiDiEdgeQuery::operator!=(MultiDiEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) != std::tie(other.srcs, other.dsts); +} +bool MultiDiEdgeQuery::operator<(MultiDiEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) < std::tie(other.srcs, other.dsts); +} +bool MultiDiEdgeQuery::operator>(MultiDiEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) > std::tie(other.srcs, other.dsts); +} +bool MultiDiEdgeQuery::operator<=(MultiDiEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) <= std::tie(other.srcs, other.dsts); +} +bool MultiDiEdgeQuery::operator>=(MultiDiEdgeQuery const &other) const { + return std::tie(this->srcs, this->dsts) >= std::tie(other.srcs, other.dsts); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::MultiDiEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.srcs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.dsts) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(MultiDiEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiDiEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/utils/graph/multidigraph/multidigraph.cc similarity index 70% rename from lib/utils/src/graph/multidigraph.cc rename to lib/utils/src/utils/graph/multidigraph/multidigraph.cc index 771e01e573..9e4cf84f72 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/utils/graph/multidigraph/multidigraph.cc @@ -1,18 +1,7 @@ -#include "utils/graph/multidigraph.h" -#include "utils/graph/multidigraph_interfaces.h" +#include "utils/graph/multidigraph/multidigraph.h" namespace FlexFlow { -std::unordered_set - IMultiDiGraphView::query_edges(DirectedEdgeQuery const &q) const { - return transform( - query_edges(MultiDiEdgeQuery{ - q.srcs, q.dsts, matchall(), matchall()}), - [](MultiDiEdge const &e) { - return DirectedEdge{e.src, e.dst}; - }); -} - std::unordered_set MultiDiGraphView::query_nodes(NodeQuery const &q) const { return this->get_ptr().query_nodes(q); @@ -32,14 +21,6 @@ Node MultiDiGraph::add_node() { return this->get_ptr().add_node(); } -NodePort MultiDiGraph::add_node_port() { - return this->get_ptr().add_node_port(); -} - -void MultiDiGraph::add_node_port_unsafe(NodePort const &np) { - return this->get_ptr().add_node_port_unsafe(np); -} - void MultiDiGraph::add_node_unsafe(Node const &n) { return this->get_ptr().add_node_unsafe(n); } diff --git a/lib/utils/src/utils/graph/node/graph.cc b/lib/utils/src/utils/graph/node/graph.cc new file mode 100644 index 0000000000..69a66f169d --- /dev/null +++ b/lib/utils/src/utils/graph/node/graph.cc @@ -0,0 +1,30 @@ +#include "utils/graph/undirected/graph.h" + +namespace FlexFlow { + +Node Graph::add_node() { + return get_ptr().add_node(); +} + +void Graph::add_node_unsafe(Node const &node) { + get_ptr().add_node_unsafe(node); +} + +void Graph::remove_node_unsafe(Node const &node) { + get_ptr().remove_node_unsafe(node); +} + +std::unordered_set Graph::query_nodes(NodeQuery const &q) const { + return get_ptr().query_nodes(q); +} + +IGraph const &Graph::get_ptr() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); +} + +IGraph &Graph::get_ptr() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); +} + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/node/graph_view.cc b/lib/utils/src/utils/graph/node/graph_view.cc new file mode 100644 index 0000000000..5ea0fe7b63 --- /dev/null +++ b/lib/utils/src/utils/graph/node/graph_view.cc @@ -0,0 +1,15 @@ +#include "utils/graph/undirected/graph_view.h" + +namespace FlexFlow { + +GraphView::GraphView(cow_ptr_t ptr) : ptr(ptr) {} + +std::unordered_set GraphView::query_nodes(NodeQuery const &g) const { + return this->ptr->query_nodes(g); +} + +bool is_ptr_equal(GraphView const &lhs, GraphView const &rhs) { + return lhs.ptr == rhs.ptr; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/node/i_graph_view.cc b/lib/utils/src/utils/graph/node/i_graph_view.cc new file mode 100644 index 0000000000..63c5b829cb --- /dev/null +++ b/lib/utils/src/utils/graph/node/i_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/undirected/i_graph_view.h" diff --git a/lib/utils/src/utils/graph/node/node.dtg.cc b/lib/utils/src/utils/graph/node/node.dtg.cc new file mode 100644 index 0000000000..6a314f64dd --- /dev/null +++ b/lib/utils/src/utils/graph/node/node.dtg.cc @@ -0,0 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/node/node.struct.toml +/* proj-data +{ + "generated_from": "cc4828f6a9dcc4c3435767bd6ccfc866" +} +*/ + +#include "utils/graph/node/node.dtg.h" + +#include +#include + +namespace FlexFlow { +Node::Node(size_t const &raw_uid) : raw_uid(raw_uid) {} +bool Node::operator==(Node const &other) const { + return std::tie(this->raw_uid) == std::tie(other.raw_uid); +} +bool Node::operator!=(Node const &other) const { + return std::tie(this->raw_uid) != std::tie(other.raw_uid); +} +bool Node::operator<(Node const &other) const { + return std::tie(this->raw_uid) < std::tie(other.raw_uid); +} +bool Node::operator>(Node const &other) const { + return std::tie(this->raw_uid) > std::tie(other.raw_uid); +} +bool Node::operator<=(Node const &other) const { + return std::tie(this->raw_uid) <= std::tie(other.raw_uid); +} +bool Node::operator>=(Node const &other) const { + return std::tie(this->raw_uid) >= std::tie(other.raw_uid); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(::FlexFlow::Node const &x) const { + size_t result = 0; + result ^= std::hash{}(x.raw_uid) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(Node const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Node const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/node/node_query.cc b/lib/utils/src/utils/graph/node/node_query.cc new file mode 100644 index 0000000000..c74457465c --- /dev/null +++ b/lib/utils/src/utils/graph/node/node_query.cc @@ -0,0 +1,31 @@ +#include "utils/graph/node/node_query.h" + +namespace FlexFlow { + +NodeQuery node_query_all() { + return NodeQuery{matchall()}; +} + +NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) { + + std::unordered_set nodes; + + if (is_matchall(lhs.nodes) && !is_matchall(rhs.nodes)) { + nodes = allowed_values(rhs.nodes); + } else if (!is_matchall(lhs.nodes) && is_matchall(rhs.nodes)) { + nodes = allowed_values(lhs.nodes); + } else if (!is_matchall(lhs.nodes) && !is_matchall(rhs.nodes)) { + nodes = allowed_values(query_intersection(lhs.nodes, rhs.nodes)); + } + + NodeQuery intersection_result = node_query_all(); + intersection_result.nodes = nodes; + + return intersection_result; +} + +NodeQuery query_union(NodeQuery const &lhs, NodeQuery const &rhs) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/node/node_query.dtg.cc b/lib/utils/src/utils/graph/node/node_query.dtg.cc new file mode 100644 index 0000000000..516d9f9d88 --- /dev/null +++ b/lib/utils/src/utils/graph/node/node_query.dtg.cc @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/node/node_query.struct.toml +/* proj-data +{ + "generated_from": "e3e4a13f0d1a7ca9f179ba09dd4c5735" +} +*/ + +#include "utils/graph/node/node_query.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +NodeQuery::NodeQuery(::FlexFlow::query_set<::FlexFlow::Node> const &nodes) + : nodes(nodes) {} +bool NodeQuery::operator==(NodeQuery const &other) const { + return std::tie(this->nodes) == std::tie(other.nodes); +} +bool NodeQuery::operator!=(NodeQuery const &other) const { + return std::tie(this->nodes) != std::tie(other.nodes); +} +bool NodeQuery::operator<(NodeQuery const &other) const { + return std::tie(this->nodes) < std::tie(other.nodes); +} +bool NodeQuery::operator>(NodeQuery const &other) const { + return std::tie(this->nodes) > std::tie(other.nodes); +} +bool NodeQuery::operator<=(NodeQuery const &other) const { + return std::tie(this->nodes) <= std::tie(other.nodes); +} +bool NodeQuery::operator>=(NodeQuery const &other) const { + return std::tie(this->nodes) >= std::tie(other.nodes); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::NodeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.nodes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(NodeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NodeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc new file mode 100644 index 0000000000..673bf3b010 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc @@ -0,0 +1,9 @@ +#include "utils/graph/open_multidigraph/input_multi_di_edge.h" + +namespace FlexFlow { + +InputMultiDiEdge input_multidiedge_from_multidiedge(MultiDiEdge const &e) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc new file mode 100644 index 0000000000..922da92415 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "d779a19c1f8f096dc1dfabf95633b115" +} +*/ + +#include "utils/graph/open_multidigraph/input_multi_di_edge.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include +#include + +namespace FlexFlow { +InputMultiDiEdge::InputMultiDiEdge(::FlexFlow::Node const &dst, + size_t const &raw_uid) + : dst(dst), raw_uid(raw_uid) {} +bool InputMultiDiEdge::operator==(InputMultiDiEdge const &other) const { + return std::tie(this->dst, this->raw_uid) == + std::tie(other.dst, other.raw_uid); +} +bool InputMultiDiEdge::operator!=(InputMultiDiEdge const &other) const { + return std::tie(this->dst, this->raw_uid) != + std::tie(other.dst, other.raw_uid); +} +bool InputMultiDiEdge::operator<(InputMultiDiEdge const &other) const { + return std::tie(this->dst, this->raw_uid) < + std::tie(other.dst, other.raw_uid); +} +bool InputMultiDiEdge::operator>(InputMultiDiEdge const &other) const { + return std::tie(this->dst, this->raw_uid) > + std::tie(other.dst, other.raw_uid); +} +bool InputMultiDiEdge::operator<=(InputMultiDiEdge const &other) const { + return std::tie(this->dst, this->raw_uid) <= + std::tie(other.dst, other.raw_uid); +} +bool InputMultiDiEdge::operator>=(InputMultiDiEdge const &other) const { + return std::tie(this->dst, this->raw_uid) >= + std::tie(other.dst, other.raw_uid); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::InputMultiDiEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.dst) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.raw_uid) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(InputMultiDiEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InputMultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc new file mode 100644 index 0000000000..6f56af308d --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc @@ -0,0 +1,13 @@ +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.h" + +namespace FlexFlow { + +InputMultiDiEdgeQuery input_multidiedge_query_all() { + return InputMultiDiEdgeQuery{query_set::matchall()}; +} + +InputMultiDiEdgeQuery input_multidiedge_query_none() { + return InputMultiDiEdgeQuery{query_set::match_none()}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc new file mode 100644 index 0000000000..d3f0568f0a --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "c42e43b28fae9a63d94e54f244dd3ee0" +} +*/ + +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +InputMultiDiEdgeQuery::InputMultiDiEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts) + : dsts(dsts) {} +bool InputMultiDiEdgeQuery::operator==( + InputMultiDiEdgeQuery const &other) const { + return std::tie(this->dsts) == std::tie(other.dsts); +} +bool InputMultiDiEdgeQuery::operator!=( + InputMultiDiEdgeQuery const &other) const { + return std::tie(this->dsts) != std::tie(other.dsts); +} +bool InputMultiDiEdgeQuery::operator<( + InputMultiDiEdgeQuery const &other) const { + return std::tie(this->dsts) < std::tie(other.dsts); +} +bool InputMultiDiEdgeQuery::operator>( + InputMultiDiEdgeQuery const &other) const { + return std::tie(this->dsts) > std::tie(other.dsts); +} +bool InputMultiDiEdgeQuery::operator<=( + InputMultiDiEdgeQuery const &other) const { + return std::tie(this->dsts) <= std::tie(other.dsts); +} +bool InputMultiDiEdgeQuery::operator>=( + InputMultiDiEdgeQuery const &other) const { + return std::tie(this->dsts) >= std::tie(other.dsts); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::InputMultiDiEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.dsts) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(InputMultiDiEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InputMultiDiEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc new file mode 100644 index 0000000000..9110e6d6b1 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc @@ -0,0 +1,18 @@ +#include "utils/graph/open_multidigraph/open_multi_di_edge.h" + +namespace FlexFlow { + +bool is_input_edge(OpenMultiDiEdge const &e) { + return std::holds_alternative(e); +} + +bool is_output_edge(OpenMultiDiEdge const &e) { + return std::holds_alternative(e); +} + +bool is_standard_edge(OpenMultiDiEdge const &e) { + return std::holds_alternative(e); +} + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc new file mode 100644 index 0000000000..f3c19f2786 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc @@ -0,0 +1,79 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml +/* proj-data +{ + "generated_from": "f7a6881be7d51ba916f3740828c23d91" +} +*/ + +#include "utils/graph/open_multidigraph/open_multi_di_edge.dtg.h" + +#include + +namespace FlexFlow { +OpenMultiDiEdge::OpenMultiDiEdge(::FlexFlow::InputMultiDiEdge const &v) + : raw_variant(v) {} +OpenMultiDiEdge::OpenMultiDiEdge(::FlexFlow::OutputMultiDiEdge const &v) + : raw_variant(v) {} +OpenMultiDiEdge::OpenMultiDiEdge(::FlexFlow::MultiDiEdge const &v) + : raw_variant(v) {} +bool OpenMultiDiEdge::operator==(OpenMultiDiEdge const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OpenMultiDiEdge::operator!=(OpenMultiDiEdge const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OpenMultiDiEdge::operator<(OpenMultiDiEdge const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OpenMultiDiEdge::operator>(OpenMultiDiEdge const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OpenMultiDiEdge::operator<=(OpenMultiDiEdge const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OpenMultiDiEdge::operator>=(OpenMultiDiEdge const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OpenMultiDiEdge>::operator()( + ::FlexFlow::OpenMultiDiEdge const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OpenMultiDiEdge const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type OpenMultiDiEdge", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OpenMultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc new file mode 100644 index 0000000000..448ef9f884 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc @@ -0,0 +1,16 @@ +#include "utils/graph/open_multidigraph/open_multi_di_edge_query.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.h" +#include "utils/graph/multidigraph/multi_di_edge_query.h" + +namespace FlexFlow { + +OpenMultiDiEdgeQuery open_multidiedge_query_all() { + return OpenMultiDiEdgeQuery{ + input_multidiedge_query_all(), + multidiedge_query_all(), + output_multidiedge_query_all(), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc new file mode 100644 index 0000000000..d0708ebd4c --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "86fc384b53b6b27982dfe6ab8fff2d04" +} +*/ + +#include "utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h" + +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h" +#include + +namespace FlexFlow { +OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery( + ::FlexFlow::InputMultiDiEdgeQuery const &input_edge_query, + ::FlexFlow::MultiDiEdgeQuery const &standard_edge_query, + ::FlexFlow::OutputMultiDiEdgeQuery const &output_edge_query) + : input_edge_query(input_edge_query), + standard_edge_query(standard_edge_query), + output_edge_query(output_edge_query) {} +bool OpenMultiDiEdgeQuery::operator==(OpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, + this->standard_edge_query, + this->output_edge_query) == + std::tie(other.input_edge_query, + other.standard_edge_query, + other.output_edge_query); +} +bool OpenMultiDiEdgeQuery::operator!=(OpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, + this->standard_edge_query, + this->output_edge_query) != + std::tie(other.input_edge_query, + other.standard_edge_query, + other.output_edge_query); +} +bool OpenMultiDiEdgeQuery::operator<(OpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, + this->standard_edge_query, + this->output_edge_query) < std::tie(other.input_edge_query, + other.standard_edge_query, + other.output_edge_query); +} +bool OpenMultiDiEdgeQuery::operator>(OpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, + this->standard_edge_query, + this->output_edge_query) > std::tie(other.input_edge_query, + other.standard_edge_query, + other.output_edge_query); +} +bool OpenMultiDiEdgeQuery::operator<=(OpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, + this->standard_edge_query, + this->output_edge_query) <= + std::tie(other.input_edge_query, + other.standard_edge_query, + other.output_edge_query); +} +bool OpenMultiDiEdgeQuery::operator>=(OpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, + this->standard_edge_query, + this->output_edge_query) >= + std::tie(other.input_edge_query, + other.standard_edge_query, + other.output_edge_query); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::OpenMultiDiEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::InputMultiDiEdgeQuery>{}(x.input_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::MultiDiEdgeQuery>{}(x.standard_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash<::FlexFlow::OutputMultiDiEdgeQuery>{}(x.output_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OpenMultiDiEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OpenMultiDiEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc new file mode 100644 index 0000000000..05654ca72b --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc @@ -0,0 +1,9 @@ +#include "utils/graph/open_multidigraph/output_multi_di_edge.h" + +namespace FlexFlow { + +OutputMultiDiEdge output_multidiedge_from_multidiedge(MultiDiEdge const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc new file mode 100644 index 0000000000..d93ef8fa74 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "2ec351b641e8ecfd79fd7df2ec13dbd4" +} +*/ + +#include "utils/graph/open_multidigraph/output_multi_di_edge.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include +#include + +namespace FlexFlow { +OutputMultiDiEdge::OutputMultiDiEdge(::FlexFlow::Node const &src, + size_t const &raw_uid) + : src(src), raw_uid(raw_uid) {} +bool OutputMultiDiEdge::operator==(OutputMultiDiEdge const &other) const { + return std::tie(this->src, this->raw_uid) == + std::tie(other.src, other.raw_uid); +} +bool OutputMultiDiEdge::operator!=(OutputMultiDiEdge const &other) const { + return std::tie(this->src, this->raw_uid) != + std::tie(other.src, other.raw_uid); +} +bool OutputMultiDiEdge::operator<(OutputMultiDiEdge const &other) const { + return std::tie(this->src, this->raw_uid) < + std::tie(other.src, other.raw_uid); +} +bool OutputMultiDiEdge::operator>(OutputMultiDiEdge const &other) const { + return std::tie(this->src, this->raw_uid) > + std::tie(other.src, other.raw_uid); +} +bool OutputMultiDiEdge::operator<=(OutputMultiDiEdge const &other) const { + return std::tie(this->src, this->raw_uid) <= + std::tie(other.src, other.raw_uid); +} +bool OutputMultiDiEdge::operator>=(OutputMultiDiEdge const &other) const { + return std::tie(this->src, this->raw_uid) >= + std::tie(other.src, other.raw_uid); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::OutputMultiDiEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.src) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.raw_uid) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputMultiDiEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OutputMultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc new file mode 100644 index 0000000000..803cadf653 --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc @@ -0,0 +1,13 @@ +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.h" + +namespace FlexFlow { + +OutputMultiDiEdgeQuery output_multidiedge_query_all() { + return OutputMultiDiEdgeQuery{query_set::matchall()}; +} + +OutputMultiDiEdgeQuery output_multidiedge_query_none() { + return OutputMultiDiEdgeQuery{query_set::match_none()}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc new file mode 100644 index 0000000000..f9a017d60a --- /dev/null +++ b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "4833874bcc5268ec7a7f8fe92186ba17" +} +*/ + +#include "utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +OutputMultiDiEdgeQuery::OutputMultiDiEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &srcs) + : srcs(srcs) {} +bool OutputMultiDiEdgeQuery::operator==( + OutputMultiDiEdgeQuery const &other) const { + return std::tie(this->srcs) == std::tie(other.srcs); +} +bool OutputMultiDiEdgeQuery::operator!=( + OutputMultiDiEdgeQuery const &other) const { + return std::tie(this->srcs) != std::tie(other.srcs); +} +bool OutputMultiDiEdgeQuery::operator<( + OutputMultiDiEdgeQuery const &other) const { + return std::tie(this->srcs) < std::tie(other.srcs); +} +bool OutputMultiDiEdgeQuery::operator>( + OutputMultiDiEdgeQuery const &other) const { + return std::tie(this->srcs) > std::tie(other.srcs); +} +bool OutputMultiDiEdgeQuery::operator<=( + OutputMultiDiEdgeQuery const &other) const { + return std::tie(this->srcs) <= std::tie(other.srcs); +} +bool OutputMultiDiEdgeQuery::operator>=( + OutputMultiDiEdgeQuery const &other) const { + return std::tie(this->srcs) >= std::tie(other.srcs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::OutputMultiDiEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.srcs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputMultiDiEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OutputMultiDiEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/utils/graph/serialparallel.cc similarity index 100% rename from lib/utils/src/graph/serialparallel.cc rename to lib/utils/src/utils/graph/serialparallel.cc diff --git a/lib/utils/src/graph/serialparallel_internal.h b/lib/utils/src/utils/graph/serialparallel_internal.h similarity index 100% rename from lib/utils/src/graph/serialparallel_internal.h rename to lib/utils/src/utils/graph/serialparallel_internal.h diff --git a/lib/utils/src/graph/traversal.cc b/lib/utils/src/utils/graph/traversal.cc similarity index 100% rename from lib/utils/src/graph/traversal.cc rename to lib/utils/src/utils/graph/traversal.cc diff --git a/lib/utils/src/graph/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/undirected/hashmap_undirected_graph.cc similarity index 93% rename from lib/utils/src/graph/hashmap_undirected_graph.cc rename to lib/utils/src/utils/graph/undirected/hashmap_undirected_graph.cc index 2d80c31f92..78788a6454 100644 --- a/lib/utils/src/graph/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/undirected/hashmap_undirected_graph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/hashmap_undirected_graph.h" +#include "utils/graph/undirected/hashmap_undirected_graph.h" #include "utils/containers.h" #include "utils/exception.h" @@ -13,7 +13,7 @@ Node HashmapUndirectedGraph::add_node() { void HashmapUndirectedGraph::add_node_unsafe(Node const &node) { adjacency[node]; - this->next_node_idx = std::max(this->next_node_idx, node.value() + 1); + this->next_node_idx = std::max(this->next_node_idx, node.raw_uid + 1); } void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc new file mode 100644 index 0000000000..7af1dc8fcc --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -0,0 +1,40 @@ +#include "utils/graph/undirected/undirected_edge.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) + : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} + +static std::tuple tie(UndirectedEdge const &e) { + return std::tie(e.smaller, e.bigger); +} + +bool UndirectedEdge::operator==(UndirectedEdge const &other) const { + return tie(*this) == tie(other); +} + +bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { + return tie(*this) != tie(other); +} + +bool UndirectedEdge::operator<(UndirectedEdge const &other) const { + return tie(*this) < tie(other); +} + +bool is_connected_to(UndirectedEdge const &e, Node const &n) { + return e.bigger == n || e.smaller == n; +} + +} // namespace FlexFlow + +namespace std { + +using namespace FlexFlow; + +size_t hash::operator()(UndirectedEdge const &e) const { + std::tuple members = ::FlexFlow::tie(e); + return std::hash{}(members); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc new file mode 100644 index 0000000000..5c41eef7da --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -0,0 +1,16 @@ +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +UndirectedEdgeQuery undirected_edge_query_all() { + return UndirectedEdgeQuery{matchall()}; +} + +UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, + UndirectedEdgeQuery const &rhs) { + return UndirectedEdgeQuery{ + query_intersection(lhs.nodes, rhs.nodes), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc new file mode 100644 index 0000000000..f67e39519c --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc @@ -0,0 +1,61 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/undirected/undirected_edge_query.struct.toml +/* proj-data +{ + "generated_from": "10df85f620b0fb6e70496d6585be6b43" +} +*/ + +#include "utils/graph/undirected/undirected_edge_query.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/query_set.h" +#include + +namespace FlexFlow { +UndirectedEdgeQuery::UndirectedEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::Node> const &nodes) + : nodes(nodes) {} +bool UndirectedEdgeQuery::operator==(UndirectedEdgeQuery const &other) const { + return std::tie(this->nodes) == std::tie(other.nodes); +} +bool UndirectedEdgeQuery::operator!=(UndirectedEdgeQuery const &other) const { + return std::tie(this->nodes) != std::tie(other.nodes); +} +bool UndirectedEdgeQuery::operator<(UndirectedEdgeQuery const &other) const { + return std::tie(this->nodes) < std::tie(other.nodes); +} +bool UndirectedEdgeQuery::operator>(UndirectedEdgeQuery const &other) const { + return std::tie(this->nodes) > std::tie(other.nodes); +} +bool UndirectedEdgeQuery::operator<=(UndirectedEdgeQuery const &other) const { + return std::tie(this->nodes) <= std::tie(other.nodes); +} +bool UndirectedEdgeQuery::operator>=(UndirectedEdgeQuery const &other) const { + return std::tie(this->nodes) >= std::tie(other.nodes); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::UndirectedEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.nodes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(UndirectedEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, UndirectedEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/utils/graph/undirected/undirected_graph.cc similarity index 66% rename from lib/utils/src/graph/undirected.cc rename to lib/utils/src/utils/graph/undirected/undirected_graph.cc index b1e8be7f14..32c9468ec3 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_graph.cc @@ -1,7 +1,4 @@ -#include "utils/graph/undirected.h" -#include "utils/containers.h" -#include "utils/graph/node.h" -#include +#include "utils/graph/undirected/undirected_graph.h" namespace FlexFlow { @@ -45,19 +42,4 @@ std::unordered_set return this->get_ptr().query_nodes(q); } -std::unordered_set - UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { - return this->get_ptr().query_edges(q); -} - -std::unordered_set - UndirectedGraphView::query_nodes(NodeQuery const &q) const { - return this->get_ptr().query_nodes(q); -} - -IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::dynamic_pointer_cast( - GraphView::ptr.get()); -} - } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_graph_view.cc b/lib/utils/src/utils/graph/undirected/undirected_graph_view.cc new file mode 100644 index 0000000000..b74ccfa322 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/undirected_graph_view.cc @@ -0,0 +1,20 @@ +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { + return this->get_ptr().query_edges(q); +} + +std::unordered_set + UndirectedGraphView::query_nodes(NodeQuery const &q) const { + return this->get_ptr().query_nodes(q); +} + +IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { + return *std::dynamic_pointer_cast( + GraphView::ptr.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc new file mode 100644 index 0000000000..0b1e46e1f4 --- /dev/null +++ b/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc @@ -0,0 +1,11 @@ +#include "utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h" + +namespace FlexFlow { + +std::unordered_set + IUpwardOpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { + return widen(this->query_edges( + UpwardOpenMultiDiEdgeQuery{q.input_edge_query, q.standard_edge_query})); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc new file mode 100644 index 0000000000..d3533450ab --- /dev/null +++ b/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc @@ -0,0 +1,16 @@ +#include "utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h" +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +std::unordered_set + IUpwardOpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { + + std::unordered_set queried = this->query_edges( + UpwardOpenMultiDiEdgeQuery{q.input_edge_query, q.standard_edge_query}); + + return transform(queried, [](UpwardOpenMultiDiEdge const &upward_e) { return open_multidiedge_from_upward_open(upward_e); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc new file mode 100644 index 0000000000..0ab741229f --- /dev/null +++ b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc @@ -0,0 +1,13 @@ +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OpenMultiDiEdge open_multidiedge_from_upward_open(UpwardOpenMultiDiEdge const &upward_e) { + return upward_e.visit(overload { + [](MultiDiEdge const &e) { return OpenMultiDiEdge{e}; }, + [](OpenMultiDiEdge const &e) { return OpenMultiDiEdge{e}; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc new file mode 100644 index 0000000000..02d9864acd --- /dev/null +++ b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc @@ -0,0 +1,79 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml +/* proj-data +{ + "generated_from": "fbb0c70b77edf2b92ceb84523c67c2ad" +} +*/ + +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h" + +#include + +namespace FlexFlow { +UpwardOpenMultiDiEdge::UpwardOpenMultiDiEdge( + ::FlexFlow::InputMultiDiEdge const &v) + : raw_variant(v) {} +UpwardOpenMultiDiEdge::UpwardOpenMultiDiEdge(::FlexFlow::MultiDiEdge const &v) + : raw_variant(v) {} +bool UpwardOpenMultiDiEdge::operator==( + UpwardOpenMultiDiEdge const &other) const { + return this->raw_variant == other.raw_variant; +} +bool UpwardOpenMultiDiEdge::operator!=( + UpwardOpenMultiDiEdge const &other) const { + return this->raw_variant != other.raw_variant; +} +bool UpwardOpenMultiDiEdge::operator<( + UpwardOpenMultiDiEdge const &other) const { + return this->raw_variant < other.raw_variant; +} +bool UpwardOpenMultiDiEdge::operator>( + UpwardOpenMultiDiEdge const &other) const { + return this->raw_variant > other.raw_variant; +} +bool UpwardOpenMultiDiEdge::operator<=( + UpwardOpenMultiDiEdge const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool UpwardOpenMultiDiEdge::operator>=( + UpwardOpenMultiDiEdge const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::UpwardOpenMultiDiEdge>::operator()( + ::FlexFlow::UpwardOpenMultiDiEdge const &x) const { + return std::hash< + std::variant<::FlexFlow::InputMultiDiEdge, ::FlexFlow::MultiDiEdge>>{}( + x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::UpwardOpenMultiDiEdge const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type UpwardOpenMultiDiEdge", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::UpwardOpenMultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc new file mode 100644 index 0000000000..f8aac56490 --- /dev/null +++ b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml +/* proj-data +{ + "generated_from": "45db44200f5b0ff7d80004f783ce1464" +} +*/ + +#include "utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h" + +#include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h" +#include + +namespace FlexFlow { +UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery( + ::FlexFlow::InputMultiDiEdgeQuery const &input_edge_query, + ::FlexFlow::MultiDiEdgeQuery const &standard_edge_query) + : input_edge_query(input_edge_query), + standard_edge_query(standard_edge_query) {} +bool UpwardOpenMultiDiEdgeQuery::operator==( + UpwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) == + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool UpwardOpenMultiDiEdgeQuery::operator!=( + UpwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) != + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool UpwardOpenMultiDiEdgeQuery::operator<( + UpwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) < + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool UpwardOpenMultiDiEdgeQuery::operator>( + UpwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) > + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool UpwardOpenMultiDiEdgeQuery::operator<=( + UpwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) <= + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool UpwardOpenMultiDiEdgeQuery::operator>=( + UpwardOpenMultiDiEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) >= + std::tie(other.input_edge_query, other.standard_edge_query); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::UpwardOpenMultiDiEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::InputMultiDiEdgeQuery>{}(x.input_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::MultiDiEdgeQuery>{}(x.standard_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(UpwardOpenMultiDiEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, UpwardOpenMultiDiEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/utils/graph/views.cc similarity index 80% rename from lib/utils/src/graph/views.cc rename to lib/utils/src/utils/graph/views.cc index af15b0d6aa..567914a249 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/utils/graph/views.cc @@ -1,13 +1,15 @@ -#include "utils/graph/views.h" +#include "utils/graph/views/views.h" #include "utils/containers.h" #include "utils/disjoint_set.h" #include "utils/exception.h" #include "utils/graph/algorithms.h" -#include "utils/graph/digraph.h" -#include "utils/graph/digraph_interfaces.h" -#include "utils/graph/query_set.h" -#include "utils/graph/undirected.h" -#include +#include "utils/graph/open_multidigraph/input_multi_di_edge.h" +#include "utils/graph/open_multidigraph/input_multi_di_edge_query.h" +#include "utils/graph/open_multidigraph/output_multi_di_edge.h" +#include "utils/graph/undirected/undirected_edge_query.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/multidigraph/multi_di_edge_query.h" namespace FlexFlow { @@ -16,7 +18,7 @@ FlippedView::FlippedView(DiGraphView const &g) : g(g) {} std::unordered_set FlippedView::query_edges(DirectedEdgeQuery const &query) const { std::unordered_set result = - this->g.query_edges({query.dsts, query.srcs}); + this->g.query_edges(DirectedEdgeQuery{query.dsts, query.srcs}); return transform(result, [](DirectedEdge const &e) { return flipped(e); }); } @@ -44,7 +46,7 @@ ContractNodeView *ContractNodeView::clone() const { } DirectedEdge flipped(DirectedEdge const &e) { - return {e.src, e.dst}; + return DirectedEdge{e.src, e.dst}; } UndirectedSubgraphView::UndirectedSubgraphView( @@ -58,13 +60,13 @@ UndirectedSubgraphView *UndirectedSubgraphView::clone() const { std::unordered_set UndirectedSubgraphView::query_edges( UndirectedEdgeQuery const &query) const { - UndirectedEdgeQuery subgraph_query = {this->subgraph_nodes}; + UndirectedEdgeQuery subgraph_query = UndirectedEdgeQuery{this->subgraph_nodes}; return this->g.query_edges(query_intersection(query, subgraph_query)); } std::unordered_set UndirectedSubgraphView::query_nodes(NodeQuery const &query) const { - return this->g.query_nodes(query_intersection(query, {this->subgraph_nodes})); + return this->g.query_nodes(query_intersection(query, NodeQuery{this->subgraph_nodes})); } DiSubgraphView::DiSubgraphView(DiGraphView const &g, @@ -73,14 +75,14 @@ DiSubgraphView::DiSubgraphView(DiGraphView const &g, std::unordered_set DiSubgraphView::query_edges(DirectedEdgeQuery const &query) const { - DirectedEdgeQuery subgraph_query = {this->subgraph_nodes, + DirectedEdgeQuery subgraph_query = DirectedEdgeQuery{this->subgraph_nodes, this->subgraph_nodes}; return this->g.query_edges(query_intersection(query, subgraph_query)); } std::unordered_set DiSubgraphView::query_nodes(NodeQuery const &query) const { - return this->g.query_nodes(query_intersection(query, {this->subgraph_nodes})); + return this->g.query_nodes(query_intersection(query, NodeQuery{this->subgraph_nodes})); } DiSubgraphView *DiSubgraphView::clone() const { @@ -93,15 +95,13 @@ MultiDiSubgraphView::MultiDiSubgraphView( std::unordered_set MultiDiSubgraphView::query_edges(MultiDiEdgeQuery const &query) const { - MultiDiEdgeQuery subgraph_query = MultiDiEdgeQuery::all() - .with_src_nodes(this->subgraph_nodes) - .with_dst_nodes(this->subgraph_nodes); + MultiDiEdgeQuery subgraph_query = MultiDiEdgeQuery{this->subgraph_nodes, this->subgraph_nodes}; return this->g.query_edges(query_intersection(query, subgraph_query)); } std::unordered_set MultiDiSubgraphView::query_nodes(NodeQuery const &query) const { - return this->g.query_nodes(query_intersection(query, {this->subgraph_nodes})); + return this->g.query_nodes(query_intersection(query, NodeQuery{this->subgraph_nodes})); } UndirectedGraphView @@ -128,11 +128,11 @@ Node NodeSource::fresh_node() { JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { for (Node const &n : get_nodes(lhs)) { - this->mapping.equate({n, LRDirection::LEFT}, + this->mapping.equate(JoinNodeKey{n, LRDirection::LEFT}, this->node_source.fresh_node()); } for (Node const &n : get_nodes(rhs)) { - this->mapping.equate({n, LRDirection::RIGHT}, + this->mapping.equate(JoinNodeKey{n, LRDirection::RIGHT}, this->node_source.fresh_node()); } } @@ -180,7 +180,7 @@ std::unordered_set std::unordered_set JoinedUndirectedGraphView::query_edges( UndirectedEdgeQuery const &query) const { - std::unordered_set nodes = this->query_nodes({query.nodes}); + std::unordered_set nodes = this->query_nodes(NodeQuery{query.nodes}); std::unordered_set left_nodes, right_nodes; for (Node const &n : nodes) { JoinNodeKey k = this->joined_nodes.at_node(n); @@ -193,10 +193,10 @@ std::unordered_set JoinedUndirectedGraphView::query_edges( } std::unordered_set result; - for (UndirectedEdge const &e : this->lhs.query_edges({left_nodes})) { + for (UndirectedEdge const &e : this->lhs.query_edges(UndirectedEdgeQuery{left_nodes})) { result.insert(this->fix_lhs_edge(e)); } - for (UndirectedEdge const &e : this->rhs.query_edges({right_nodes})) { + for (UndirectedEdge const &e : this->rhs.query_edges(UndirectedEdgeQuery{right_nodes})) { result.insert(this->fix_rhs_edge(e)); } @@ -205,14 +205,14 @@ std::unordered_set JoinedUndirectedGraphView::query_edges( UndirectedEdge JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key({e.smaller, LRDirection::LEFT}), - this->joined_nodes.at_join_key({e.bigger, LRDirection::LEFT})}; + return {this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; } UndirectedEdge JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key({e.smaller, LRDirection::RIGHT}), - this->joined_nodes.at_join_key({e.bigger, LRDirection::RIGHT})}; + return {this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::RIGHT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::RIGHT})}; } JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, @@ -231,12 +231,12 @@ std::unordered_set std::unordered_set JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { - std::unordered_set srcs = this->query_nodes(query.srcs); - std::unordered_set dsts = this->query_nodes(query.dsts); + std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); + std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); auto traced_srcs = this->joined_nodes.trace_nodes(srcs); auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - DirectedEdgeQuery left_query = {traced_srcs.first, traced_dsts.first}; - DirectedEdgeQuery right_query = {traced_srcs.second, traced_dsts.second}; + DirectedEdgeQuery left_query = DirectedEdgeQuery{traced_srcs.first, traced_dsts.first}; + DirectedEdgeQuery right_query = DirectedEdgeQuery{traced_srcs.second, traced_dsts.second}; std::unordered_set result; for (DirectedEdge const &e : this->lhs.query_edges(left_query)) { @@ -250,13 +250,13 @@ std::unordered_set } DirectedEdge JoinedDigraphView::fix_lhs_edge(DirectedEdge const &e) const { - return {this->joined_nodes.at_join_key({e.src, LRDirection::LEFT}), - this->joined_nodes.at_join_key({e.dst, LRDirection::LEFT})}; + return DirectedEdge{this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT})}; } DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { - return {this->joined_nodes.at_join_key({e.src, LRDirection::RIGHT}), - this->joined_nodes.at_join_key({e.dst, LRDirection::RIGHT})}; + return DirectedEdge{this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT})}; } JoinedMultiDigraphView::JoinedMultiDigraphView(MultiDiGraphView const &lhs, @@ -270,15 +270,15 @@ std::unordered_set std::unordered_set JoinedMultiDigraphView::query_edges(MultiDiEdgeQuery const &query) const { - std::unordered_set srcs = this->query_nodes(query.srcs); - std::unordered_set dsts = this->query_nodes(query.dsts); + std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); + std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); auto traced_srcs = this->joined_nodes.trace_nodes(srcs); auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - MultiDiEdgeQuery left_query = { - traced_srcs.first, traced_dsts.first, query.srcIdxs, query.dstIdxs}; - MultiDiEdgeQuery right_query = { - traced_srcs.second, traced_dsts.second, query.srcIdxs, query.dstIdxs}; + MultiDiEdgeQuery left_query = MultiDiEdgeQuery{ + traced_srcs.first, traced_dsts.first}; + MultiDiEdgeQuery right_query = MultiDiEdgeQuery{ + traced_srcs.second, traced_dsts.second}; return set_union( transform(this->lhs.query_edges(left_query), @@ -292,17 +292,17 @@ JoinedMultiDigraphView *JoinedMultiDigraphView::clone() const { } MultiDiEdge JoinedMultiDigraphView::fix_lhs_edge(MultiDiEdge const &e) const { - return {this->joined_nodes.at_join_key({e.dst, LRDirection::LEFT}), - e.dst_idx, - this->joined_nodes.at_join_key({e.src, LRDirection::LEFT}), - e.src_idx}; + return MultiDiEdge{ + this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}) + }; } MultiDiEdge JoinedMultiDigraphView::fix_rhs_edge(MultiDiEdge const &e) const { - return {this->joined_nodes.at_join_key({e.dst, LRDirection::RIGHT}), - e.dst_idx, - this->joined_nodes.at_join_key({e.src, LRDirection::RIGHT}), - e.src_idx}; + return MultiDiEdge{ + this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}) + }; } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { @@ -325,7 +325,10 @@ std::unordered_set } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return {{e.smaller, e.bigger}, {e.bigger, e.smaller}}; + return std::unordered_set{ + DirectedEdge{e.smaller, e.bigger}, + DirectedEdge{e.bigger, e.smaller} + }; } std::unordered_set to_directed_edges( @@ -334,7 +337,7 @@ std::unordered_set to_directed_edges( } DirectedEdge to_directed_edge(MultiDiEdge const &e) { - return {e.src, e.dst}; + return DirectedEdge{e.src, e.dst}; } std::unordered_set @@ -373,7 +376,7 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - intersection(g.query_edges({q.srcs}), g.query_edges({q.dsts})); + intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), g.query_edges(UndirectedEdgeQuery{q.dsts})); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); @@ -397,7 +400,7 @@ std::unordered_set ViewDiGraphAsMultiDiGraph::query_edges( this->g.query_edges(directed_query); return transform(directed_edges, [](DirectedEdge const &e) { - return MultiDiEdge{e.dst, NodePort(0), e.src, NodePort(0)}; + return MultiDiEdge{e.dst, e.src}; }); } @@ -446,8 +449,8 @@ std::unordered_set OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); - this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); + this->inputs = transform(get_cut_set(g, nodes), input_multidiedge_from_multidiedge); + this->outputs = transform(get_cut_set(g, nodes), output_multidiedge_from_multidiedge); } std::unordered_set @@ -471,7 +474,7 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->inputs = transform(get_cut_set(g, nodes), input_multidiedge_from_multidiedge); } UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { @@ -497,19 +500,20 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) : g(g), nodes(nodes) { - this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), output_multidiedge_from_multidiedge); } std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - OpenMultiDiEdgeQuery subgraph_query( - InputMultiDiEdgeQuery::none(), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - q.output_edge_query.with_src_nodes(nodes)); + OpenMultiDiEdgeQuery subgraph_query{ + input_multidiedge_query_none(), + MultiDiEdgeQuery{nodes, nodes}, + OutputMultiDiEdgeQuery{nodes}, + }; std::unordered_set result = g.query_edges(subgraph_query); extend(result, - query_edge(outputs, q.output_edge_query.with_src_nodes(nodes))); + query_edge(outputs, OutputMultiDiEdgeQuery{nodes})); return result; } diff --git a/lib/utils/src/utils/graph/views/join_node_key.dtg.cc b/lib/utils/src/utils/graph/views/join_node_key.dtg.cc new file mode 100644 index 0000000000..0139d3974f --- /dev/null +++ b/lib/utils/src/utils/graph/views/join_node_key.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/views/join_node_key.struct.toml +/* proj-data +{ + "generated_from": "d18ad1216e748a6af1a1a132f18a2284" +} +*/ + +#include "utils/graph/views/join_node_key.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/views/lr_direction.dtg.h" +#include + +namespace FlexFlow { +JoinNodeKey::JoinNodeKey(::FlexFlow::Node const &node, + ::FlexFlow::LRDirection const &direction) + : node(node), direction(direction) {} +bool JoinNodeKey::operator==(JoinNodeKey const &other) const { + return std::tie(this->node, this->direction) == + std::tie(other.node, other.direction); +} +bool JoinNodeKey::operator!=(JoinNodeKey const &other) const { + return std::tie(this->node, this->direction) != + std::tie(other.node, other.direction); +} +bool JoinNodeKey::operator<(JoinNodeKey const &other) const { + return std::tie(this->node, this->direction) < + std::tie(other.node, other.direction); +} +bool JoinNodeKey::operator>(JoinNodeKey const &other) const { + return std::tie(this->node, this->direction) > + std::tie(other.node, other.direction); +} +bool JoinNodeKey::operator<=(JoinNodeKey const &other) const { + return std::tie(this->node, this->direction) <= + std::tie(other.node, other.direction); +} +bool JoinNodeKey::operator>=(JoinNodeKey const &other) const { + return std::tie(this->node, this->direction) >= + std::tie(other.node, other.direction); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::JoinNodeKey const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::LRDirection>{}(x.direction) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(JoinNodeKey const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, JoinNodeKey const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/views/lr_direction.dtg.cc b/lib/utils/src/utils/graph/views/lr_direction.dtg.cc new file mode 100644 index 0000000000..4a2a7576f8 --- /dev/null +++ b/lib/utils/src/utils/graph/views/lr_direction.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/views/lr_direction.enum.toml +/* proj-data +{ + "generated_from": "0fef027ec69f92967f3171795ae9ddd2" +} +*/ + +#include "utils/graph/views/lr_direction.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::LRDirection x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LRDirection x) { + switch (x) { + case LRDirection::LEFT: + return "LEFT"; + case LRDirection::RIGHT: + return "RIGHT"; + default: + std::ostringstream oss; + oss << "Unknown LRDirection value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LRDirection x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LRDirection x) { + switch (x) { + case LRDirection::LEFT: + j = "LEFT"; + break; + case LRDirection::RIGHT: + j = "RIGHT"; + break; + default: + std::ostringstream oss; + oss << "Unknown LRDirection value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LRDirection &x) { + std::string as_str = j.get(); + if (as_str == "LEFT") { + x = LRDirection::LEFT; + } else if (as_str == "RIGHT") { + x = LRDirection::RIGHT; + } else { + std::ostringstream oss; + oss << "Unknown LRDirection value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::LRDirection::LEFT, + FlexFlow::LRDirection::RIGHT); +} +} // namespace rc From 3be816ffc802fdf271531bf87da55fcd568d84b0 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 20 Jun 2024 23:14:02 -0700 Subject: [PATCH 09/71] Removing visitable from sp code --- flake.lock | 15 +- flake.nix | 3 +- .../multihead_attention_inputs.dtg.cc | 2 - ...multihead_attention_parallel_inputs.dtg.cc | 5 - .../src/op-attrs/ops/broadcast.dtg.cc | 1 - .../src/op-attrs/ops/cast_attrs.dtg.cc | 1 - .../src/op-attrs/ops/combine_attrs.dtg.cc | 2 - .../src/op-attrs/ops/concat_attrs.dtg.cc | 2 - .../ops/conv_2d/conv_2d_input_shape.dtg.cc | 2 - .../conv_2d_parallel_input_shape.dtg.cc | 2 - .../src/op-attrs/ops/conv_2d_attrs.dtg.cc | 3 - .../op-attrs/ops/element_binary_attrs.dtg.cc | 2 - .../op-attrs/ops/element_unary_attrs.dtg.cc | 2 - .../src/op-attrs/ops/embedding_attrs.dtg.cc | 3 - .../src/op-attrs/ops/gather_attrs.dtg.cc | 2 - .../src/op-attrs/ops/layer_norm_attrs.dtg.cc | 3 - .../src/op-attrs/ops/linear_attrs.dtg.cc | 4 - .../ops/parallel_attention_inputs.dtg.cc | 1 - .../src/op-attrs/ops/pool_2d_attrs.dtg.cc | 2 - .../src/op-attrs/ops/reduce_attrs.dtg.cc | 4 - .../src/op-attrs/ops/repartition_attrs.dtg.cc | 2 - .../src/op-attrs/ops/reshape_attrs.dtg.cc | 1 - .../src/op-attrs/ops/reverse_attrs.dtg.cc | 2 - .../src/op-attrs/ops/softmax_attrs.dtg.cc | 2 - .../src/op-attrs/ops/split_attrs.dtg.cc | 3 - .../src/op-attrs/ops/transpose_attrs.dtg.cc | 3 - .../src/op-attrs/parallel_tensor_dims.dtg.cc | 6 - .../src/op-attrs/parallel_tensor_shape.dtg.cc | 2 - .../src/op-attrs/replica_parallel_dim.dtg.cc | 1 - .../op-attrs/replica_parallel_dim_set.dtg.cc | 2 - lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc | 1 - lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc | 2 - lib/pcg/src/pcg/computation_graph.dtg.cc | 4 - .../layer_added_result.dtg.cc | 3 - .../operator_added_result.dtg.cc | 3 - .../v1/graphs/v1_multidigraph.dtg.cc | 5 - .../v1/graphs/v1_operator_graph.dtg.cc | 5 - .../constant_initializer_attrs.dtg.cc | 2 - lib/pcg/src/pcg/layer_attrs.dtg.cc | 4 - lib/pcg/src/pcg/layer_guid_t.dtg.cc | 1 - lib/pcg/src/pcg/machine_view.dtg.cc | 2 - .../operator_graph_input.dtg.cc | 1 - .../operator_graph_output.dtg.cc | 1 - .../parallel_computation_graph.dtg.cc | 4 - .../parallel_layer_added_result.dtg.cc | 4 - .../parallel_layer_attrs.dtg.cc | 3 - .../parallel_layer_guid_t.dtg.cc | 1 - .../parallel_tensor_attrs.dtg.cc | 5 - .../parallel_tensor_guid_t.dtg.cc | 1 - lib/pcg/src/pcg/strided_rectangle.dtg.cc | 2 - lib/pcg/src/pcg/strided_rectangle_side.dtg.cc | 1 - lib/pcg/src/pcg/tensor_attrs.dtg.cc | 4 - lib/pcg/src/pcg/tensor_guid_t.dtg.cc | 1 - .../operator_attribute_constraint.dtg.cc | 3 - .../operator_attribute_list_access.dtg.cc | 1 - .../operator_attribute_list_size.dtg.cc | 1 - .../operator_attribute_pattern.dtg.cc | 3 - .../output_graph/attr_constant.dtg.cc | 1 - .../output_graph/output_graph_expr.dtg.cc | 3 - .../output_operator_attr_access.dtg.cc | 2 - .../output_operator_attrs_assignment.dtg.cc | 3 - .../src/substitutions/pcg_pattern.dtg.cc | 4 - .../sub_parallel_computation_graph.dtg.cc | 4 - .../src/substitutions/substitution.dtg.cc | 3 - .../tensor_attribute_constraint.dtg.cc | 3 - .../tensor_attribute_list_access.dtg.cc | 1 - .../tensor_attribute_list_size.dtg.cc | 1 - .../tensor_attribute_pattern.dtg.cc | 3 - .../unlabelled/closed_pattern_edge.dtg.cc | 2 - .../downward_open_pattern_edge.dtg.cc | 2 - .../unlabelled/edge_splits.dtg.cc | 4 - .../unlabelled/input_pattern_edge.dtg.cc | 2 - .../match_additional_criterion.dtg.cc | 5 - .../unlabelled/match_split.dtg.cc | 2 - .../multidigraph_pattern_match.dtg.cc | 5 - .../unlabelled/output_pattern_edge.dtg.cc | 2 - .../unlabelled/pattern_edge.dtg.cc | 2 - .../unlabelled/pattern_node.dtg.cc | 2 - .../unlabelled/pattern_split.dtg.cc | 3 - .../unlabelled_graph_pattern.dtg.cc | 2 - .../upward_open_pattern_edge.dtg.cc | 2 - lib/utils/include/utils/fmt/variant.h | 35 ++ lib/utils/include/utils/graph/algorithms.h | 173 +++--- .../dataflow_graph/dataflow_edge_query.dtg.h | 5 +- .../dataflow_graph/dataflow_edge_query.h | 13 + .../dataflow_edge_query.struct.toml | 3 +- .../graph/dataflow_graph/dataflow_graph.h | 19 +- .../dataflow_graph/dataflow_graph_view.h | 4 +- .../graph/dataflow_graph/dataflow_input.dtg.h | 3 +- .../dataflow_graph/dataflow_input.struct.toml | 4 + .../dataflow_graph/dataflow_output.dtg.h | 4 +- .../dataflow_output.struct.toml | 2 +- .../dataflow_output_query.dtg.h | 5 +- .../dataflow_output_query.struct.toml | 3 +- .../dataflow_graph/i_dataflow_graph_view.h | 6 +- .../dataflow_graph/node_added_result.dtg.h | 5 +- .../node_added_result.struct.toml | 3 +- .../include/utils/graph/digraph/algorithms.h | 17 + .../utils/graph/digraph/di_output.dtg.h | 4 +- .../utils/graph/digraph/di_output.struct.toml | 2 +- .../graph/digraph/directed_edge_query.dtg.h | 10 +- .../digraph/directed_edge_query.struct.toml | 1 + .../downward_open_multi_di_edge.dtg.h.old} | 0 ...nward_open_multi_di_edge.variant.toml.old} | 0 ...wnward_open_multi_di_edge_query.dtg.h.old} | 0 ..._open_multi_di_edge_query.struct.toml.old} | 0 .../downward_open_multidigraph.h.old} | 0 .../downward_open_multidigraph_view.h.old} | 0 .../i_downward_open_multidigraph.h.old} | 0 .../i_downward_open_multidigraph_view.h.old} | 0 .../{algorithms.h => algorithms.h.old} | 0 ...el_interfaces.h => label_interfaces.h.old} | 0 .../{node_labelled.h => node_labelled.h.old} | 0 ...faces.h => node_labelled_interfaces.h.old} | 0 ...belled_open.h => node_labelled_open.h.old} | 0 ...pen_algorithms.h => open_algorithms.h.old} | 0 .../{open_views.h => open_views.h.old} | 0 .../utils/graph/labelled/output_labelled.h | 142 ----- .../labelled/output_labelled_interfaces.h | 39 -- ...lled_open.h => output_labelled_open.h.old} | 0 ... => output_labelled_open_interfaces.h.old} | 0 ...ard_labelled.h => standard_labelled.h.old} | 0 ...s.h => standard_labelled_interfaces.h.old} | 0 ...nordered_label.h => unordered_label.h.old} | 0 ...aphs.h => unordered_labelled_graphs.h.old} | 0 .../graph/labelled/{views.h => views.h.old} | 0 ...{labelled_graphs.h => labelled_graphs.old} | 0 .../adjacency_multidigraph.h.old} | 0 .../i_multidigraph.h.old} | 0 .../i_multidigraph_view.h.old} | 0 .../multi_di_edge.dtg.h.old} | 10 +- .../multi_di_edge.struct.toml.old} | 8 - .../multi_di_edge_query.dtg.h.old} | 7 +- .../multi_di_edge_query.h.old} | 0 .../multi_di_edge_query.struct.toml.old} | 5 + .../multidigraph.old/multi_di_output.h.old | 14 + .../multidigraph.h.old} | 2 +- .../multidigraph_view.h.old} | 0 .../include/utils/graph/node/node_source.h | 19 + .../adjacency_openmultidigraph.h.old} | 0 .../i_open_multidigraph.h.old} | 0 .../i_open_multidigraph_view.h.old} | 0 .../input_multi_di_edge.dtg.h.old} | 0 .../input_multi_di_edge.h.old} | 0 .../input_multi_di_edge.struct.toml.old} | 0 .../input_multi_di_edge_query.dtg.h.old} | 0 .../input_multi_di_edge_query.h.old} | 0 ...input_multi_di_edge_query.struct.toml.old} | 0 .../open_multi_di_edge.dtg.h.old} | 0 .../open_multi_di_edge.h.old} | 0 .../open_multi_di_edge.variant.toml.old} | 0 .../open_multi_di_edge_query.dtg.h.old} | 0 .../open_multi_di_edge_query.h.old} | 0 .../open_multi_di_edge_query.struct.toml.old} | 0 .../open_multidigraph.h.old} | 0 .../open_multidigraph_view.h.old} | 0 .../output_multi_di_edge.dtg.h.old} | 0 .../output_multi_di_edge.h.old} | 0 .../output_multi_di_edge.struct.toml.old} | 0 .../output_multi_di_edge_query.dtg.h.old} | 0 .../output_multi_di_edge_query.h.old} | 0 ...utput_multi_di_edge_query.struct.toml.old} | 0 lib/utils/include/utils/graph/query_set.h | 37 +- .../graph/serial_parallel/parallel.dtg.h | 52 ++ .../graph/serial_parallel/parallel.fwd.h | 10 + .../serial_parallel/parallel.struct.toml | 29 + .../utils/graph/serial_parallel/serial.dtg.h | 52 ++ .../utils/graph/serial_parallel/serial.fwd.h | 10 + .../graph/serial_parallel/serial.struct.toml | 29 + .../serial_parallel_decomposition.dtg.h | 127 ++++ ...serial_parallel_decomposition.variant.toml | 23 + .../graph/serial_parallel/serialparallel.h | 40 ++ .../graph/serial_parallel/split_type.dtg.h | 40 ++ .../serial_parallel/split_type.enum.toml | 14 + .../include/utils/graph/serialparallel.h | 53 -- .../i_upward_open_multidigraph.h.old} | 0 .../i_upward_open_multidigraph_view.h.old} | 0 .../upward_open_multi_di_edge.dtg.h.old} | 0 .../upward_open_multi_di_edge.h.old} | 0 ...pward_open_multi_di_edge.variant.toml.old} | 0 ...upward_open_multi_di_edge_query.dtg.h.old} | 0 ..._open_multi_di_edge_query.struct.toml.old} | 0 .../upward_open_multidigraph.h.old} | 0 .../upward_open_multidigraph_view.h.old} | 0 lib/utils/include/utils/graph/views/views.h | 292 ++++----- lib/utils/include/utils/hash-utils-core.h | 69 +-- lib/utils/include/utils/hash-utils.h | 32 +- lib/utils/include/utils/hash/pair.h | 23 + lib/utils/include/utils/hash/tuple.h | 43 ++ lib/utils/include/utils/hash/vector.h | 20 + ...h.cc => adjacency_openmultidigraph.cc.old} | 0 .../{open_graphs.cc => open_graphs.cc.old} | 0 lib/utils/src/utils/fmt/variant.cc | 1 + lib/utils/src/utils/graph/algorithms.cc | 584 +++++++++--------- .../graph/dataflow_graph/dataflow_edge.dtg.cc | 2 - .../dataflow_graph/dataflow_edge_query.cc | 23 + .../dataflow_graph/dataflow_edge_query.dtg.cc | 4 +- .../graph/dataflow_graph/dataflow_graph.cc | 20 + .../dataflow_graph/dataflow_input.dtg.cc | 2 +- .../dataflow_graph/dataflow_output.dtg.cc | 3 +- .../dataflow_output_query.dtg.cc | 4 +- .../dataflow_graph/i_dataflow_graph_view.cc | 4 +- .../dataflow_graph/node_added_result.dtg.cc | 7 +- .../utils/graph/digraph/adjacency_digraph.cc | 2 +- .../src/utils/graph/digraph/di_input.dtg.cc | 1 - .../src/utils/graph/digraph/di_output.dtg.cc | 4 +- .../utils/graph/digraph/directed_edge.dtg.cc | 1 - .../directed_edge_query.cc | 2 +- .../graph/digraph/directed_edge_query.dtg.cc | 16 +- ...=> downward_open_multi_di_edge.dtg.cc.old} | 0 ...nward_open_multi_di_edge_query.dtg.cc.old} | 0 ...cc => i_downward_open_multidigraph.cc.old} | 0 ... i_downward_open_multidigraph_view.cc.old} | 0 lib/utils/src/utils/graph/labelled_graphs.cc | 3 - lib/utils/src/utils/graph/multidiedge.cc | 17 - ...graph.cc => adjacency_multidigraph.cc.old} | 0 ..._multidigraph.cc => i_multidigraph.cc.old} | 0 ...aph_view.cc => i_multidigraph_view.cc.old} | 0 .../graph/multidigraph/multi_di_edge.dtg.cc | 73 --- .../multidigraph/multi_di_edge.dtg.cc.old | 59 ++ ...ge_query.cc => multi_di_edge_query.cc.old} | 0 ....dtg.cc => multi_di_edge_query.dtg.cc.old} | 29 +- .../{multidigraph.cc => multidigraph.cc.old} | 0 lib/utils/src/utils/graph/node/graph.cc | 2 +- lib/utils/src/utils/graph/node/graph_view.cc | 2 +- .../src/utils/graph/node/i_graph_view.cc | 2 +- lib/utils/src/utils/graph/node/node.dtg.cc | 1 - .../src/utils/graph/node/node_query.dtg.cc | 2 - lib/utils/src/utils/graph/node/node_source.cc | 15 + ..._di_edge.cc => input_multi_di_edge.cc.old} | 0 ....dtg.cc => input_multi_di_edge.dtg.cc.old} | 0 ...ry.cc => input_multi_di_edge_query.cc.old} | 0 ...c => input_multi_di_edge_query.dtg.cc.old} | 0 ...i_di_edge.cc => open_multi_di_edge.cc.old} | 0 ...e.dtg.cc => open_multi_di_edge.dtg.cc.old} | 0 ...ery.cc => open_multi_di_edge_query.cc.old} | 0 ...cc => open_multi_di_edge_query.dtg.cc.old} | 0 ...di_edge.cc => output_multi_di_edge.cc.old} | 0 ...dtg.cc => output_multi_di_edge.dtg.cc.old} | 0 ...y.cc => output_multi_di_edge_query.cc.old} | 0 ... => output_multi_di_edge_query.dtg.cc.old} | 0 .../intermediate_sp_decomposition_tree.dtg.cc | 65 ++ .../intermediate_sp_decomposition_tree.dtg.h | 48 ++ ...rmediate_sp_decomposition_tree.struct.toml | 26 + .../graph/serial_parallel/parallel.dtg.cc | 66 ++ .../utils/graph/serial_parallel/serial.dtg.cc | 65 ++ .../serial_parallel_decomposition.dtg.cc | 88 +++ .../graph/serial_parallel/serialparallel.cc | 194 ++++++ .../serialparallel_internal.cc | 123 ++++ .../serial_parallel/serialparallel_internal.h | 29 + .../serial_parallel/sink_settings.dtg.cc | 72 +++ .../graph/serial_parallel/sink_settings.dtg.h | 40 ++ .../serial_parallel/sink_settings.enum.toml | 14 + .../serial_parallel/source_settings.dtg.cc | 72 +++ .../serial_parallel/source_settings.dtg.h | 40 ++ .../serial_parallel/source_settings.enum.toml | 14 + .../graph/serial_parallel/split_type.dtg.cc | 70 +++ lib/utils/src/utils/graph/serialparallel.cc | 325 ---------- .../src/utils/graph/serialparallel_internal.h | 42 -- .../undirected/undirected_edge_query.dtg.cc | 2 - ...h.cc => i_upward_open_multidigraph.cc.old} | 0 ...=> i_upward_open_multidigraph_view.cc.old} | 0 ...ge.cc => upward_open_multi_di_edge.cc.old} | 0 ...c => upward_open_multi_di_edge.dtg.cc.old} | 0 ...pward_open_multi_di_edge_query.dtg.cc.old} | 0 .../utils/graph/views/join_node_key.dtg.cc | 2 - .../src/utils/graph/{ => views}/views.cc | 482 +++++++-------- 267 files changed, 2714 insertions(+), 1828 deletions(-) create mode 100644 lib/utils/include/utils/fmt/variant.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms.h rename lib/utils/include/utils/graph/{downward_open_multidigraph/downward_open_multi_di_edge.dtg.h => downward_open_multidigraph.old/downward_open_multi_di_edge.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/downward_open_multi_di_edge.variant.toml => downward_open_multidigraph.old/downward_open_multi_di_edge.variant.toml.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h => downward_open_multidigraph.old/downward_open_multi_di_edge_query.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml => downward_open_multidigraph.old/downward_open_multi_di_edge_query.struct.toml.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/downward_open_multidigraph.h => downward_open_multidigraph.old/downward_open_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/downward_open_multidigraph_view.h => downward_open_multidigraph.old/downward_open_multidigraph_view.h.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/i_downward_open_multidigraph.h => downward_open_multidigraph.old/i_downward_open_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{downward_open_multidigraph/i_downward_open_multidigraph_view.h => downward_open_multidigraph.old/i_downward_open_multidigraph_view.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{algorithms.h => algorithms.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{label_interfaces.h => label_interfaces.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{node_labelled.h => node_labelled.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{node_labelled_interfaces.h => node_labelled_interfaces.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{node_labelled_open.h => node_labelled_open.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{open_algorithms.h => open_algorithms.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{open_views.h => open_views.h.old} (100%) delete mode 100644 lib/utils/include/utils/graph/labelled/output_labelled.h delete mode 100644 lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h rename lib/utils/include/utils/graph/labelled/{output_labelled_open.h => output_labelled_open.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{output_labelled_open_interfaces.h => output_labelled_open_interfaces.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{standard_labelled.h => standard_labelled.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{standard_labelled_interfaces.h => standard_labelled_interfaces.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{unordered_label.h => unordered_label.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{unordered_labelled_graphs.h => unordered_labelled_graphs.h.old} (100%) rename lib/utils/include/utils/graph/labelled/{views.h => views.h.old} (100%) rename lib/utils/include/utils/graph/{labelled_graphs.h => labelled_graphs.old} (100%) rename lib/utils/include/utils/graph/{multidigraph/adjacency_multidigraph.h => multidigraph.old/adjacency_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{multidigraph/i_multidigraph.h => multidigraph.old/i_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{multidigraph/i_multidigraph_view.h => multidigraph.old/i_multidigraph_view.h.old} (100%) rename lib/utils/include/utils/graph/{multidigraph/multi_di_edge.dtg.h => multidigraph.old/multi_di_edge.dtg.h.old} (80%) rename lib/utils/include/utils/graph/{multidigraph/multi_di_edge.struct.toml => multidigraph.old/multi_di_edge.struct.toml.old} (65%) rename lib/utils/include/utils/graph/{multidigraph/multi_di_edge_query.dtg.h => multidigraph.old/multi_di_edge_query.dtg.h.old} (84%) rename lib/utils/include/utils/graph/{multidigraph/multi_di_edge_query.h => multidigraph.old/multi_di_edge_query.h.old} (100%) rename lib/utils/include/utils/graph/{multidigraph/multi_di_edge_query.struct.toml => multidigraph.old/multi_di_edge_query.struct.toml.old} (72%) create mode 100644 lib/utils/include/utils/graph/multidigraph.old/multi_di_output.h.old rename lib/utils/include/utils/graph/{multidigraph/multidigraph.h => multidigraph.old/multidigraph.h.old} (96%) rename lib/utils/include/utils/graph/{multidigraph/multidigraph_view.h => multidigraph.old/multidigraph_view.h.old} (100%) create mode 100644 lib/utils/include/utils/graph/node/node_source.h rename lib/utils/include/utils/graph/{open_multidigraph/adjacency_openmultidigraph.h => open_multidigraph.old/adjacency_openmultidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/i_open_multidigraph.h => open_multidigraph.old/i_open_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/i_open_multidigraph_view.h => open_multidigraph.old/i_open_multidigraph_view.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/input_multi_di_edge.dtg.h => open_multidigraph.old/input_multi_di_edge.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/input_multi_di_edge.h => open_multidigraph.old/input_multi_di_edge.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/input_multi_di_edge.struct.toml => open_multidigraph.old/input_multi_di_edge.struct.toml.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/input_multi_di_edge_query.dtg.h => open_multidigraph.old/input_multi_di_edge_query.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/input_multi_di_edge_query.h => open_multidigraph.old/input_multi_di_edge_query.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/input_multi_di_edge_query.struct.toml => open_multidigraph.old/input_multi_di_edge_query.struct.toml.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multi_di_edge.dtg.h => open_multidigraph.old/open_multi_di_edge.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multi_di_edge.h => open_multidigraph.old/open_multi_di_edge.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multi_di_edge.variant.toml => open_multidigraph.old/open_multi_di_edge.variant.toml.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multi_di_edge_query.dtg.h => open_multidigraph.old/open_multi_di_edge_query.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multi_di_edge_query.h => open_multidigraph.old/open_multi_di_edge_query.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multi_di_edge_query.struct.toml => open_multidigraph.old/open_multi_di_edge_query.struct.toml.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multidigraph.h => open_multidigraph.old/open_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/open_multidigraph_view.h => open_multidigraph.old/open_multidigraph_view.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/output_multi_di_edge.dtg.h => open_multidigraph.old/output_multi_di_edge.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/output_multi_di_edge.h => open_multidigraph.old/output_multi_di_edge.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/output_multi_di_edge.struct.toml => open_multidigraph.old/output_multi_di_edge.struct.toml.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/output_multi_di_edge_query.dtg.h => open_multidigraph.old/output_multi_di_edge_query.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/output_multi_di_edge_query.h => open_multidigraph.old/output_multi_di_edge_query.h.old} (100%) rename lib/utils/include/utils/graph/{open_multidigraph/output_multi_di_edge_query.struct.toml => open_multidigraph.old/output_multi_di_edge_query.struct.toml.old} (100%) create mode 100644 lib/utils/include/utils/graph/serial_parallel/parallel.dtg.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml create mode 100644 lib/utils/include/utils/graph/serial_parallel/serial.dtg.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/serial.fwd.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/serial.struct.toml create mode 100644 lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml create mode 100644 lib/utils/include/utils/graph/serial_parallel/serialparallel.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/split_type.dtg.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml delete mode 100644 lib/utils/include/utils/graph/serialparallel.h rename lib/utils/include/utils/graph/{upward_open_multidigraph/i_upward_open_multidigraph.h => upward_open_multidigraph.old/i_upward_open_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/i_upward_open_multidigraph_view.h => upward_open_multidigraph.old/i_upward_open_multidigraph_view.h.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multi_di_edge.dtg.h => upward_open_multidigraph.old/upward_open_multi_di_edge.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multi_di_edge.h => upward_open_multidigraph.old/upward_open_multi_di_edge.h.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multi_di_edge.variant.toml => upward_open_multidigraph.old/upward_open_multi_di_edge.variant.toml.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h => upward_open_multidigraph.old/upward_open_multi_di_edge_query.dtg.h.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml => upward_open_multidigraph.old/upward_open_multi_di_edge_query.struct.toml.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multidigraph.h => upward_open_multidigraph.old/upward_open_multidigraph.h.old} (100%) rename lib/utils/include/utils/graph/{upward_open_multidigraph/upward_open_multidigraph_view.h => upward_open_multidigraph.old/upward_open_multidigraph_view.h.old} (100%) create mode 100644 lib/utils/include/utils/hash/pair.h create mode 100644 lib/utils/include/utils/hash/tuple.h create mode 100644 lib/utils/include/utils/hash/vector.h rename lib/utils/src/graph/{adjacency_openmultidigraph.cc => adjacency_openmultidigraph.cc.old} (100%) rename lib/utils/src/graph/{open_graphs.cc => open_graphs.cc.old} (100%) create mode 100644 lib/utils/src/utils/fmt/variant.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc rename lib/utils/src/utils/graph/{directed_graph => digraph}/directed_edge_query.cc (95%) rename lib/utils/src/utils/graph/downward_open_multidigraph/{downward_open_multi_di_edge.dtg.cc => downward_open_multi_di_edge.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/downward_open_multidigraph/{downward_open_multi_di_edge_query.dtg.cc => downward_open_multi_di_edge_query.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/downward_open_multidigraph/{i_downward_open_multidigraph.cc => i_downward_open_multidigraph.cc.old} (100%) rename lib/utils/src/utils/graph/downward_open_multidigraph/{i_downward_open_multidigraph_view.cc => i_downward_open_multidigraph_view.cc.old} (100%) delete mode 100644 lib/utils/src/utils/graph/labelled_graphs.cc delete mode 100644 lib/utils/src/utils/graph/multidiedge.cc rename lib/utils/src/utils/graph/multidigraph/{adjacency_multidigraph.cc => adjacency_multidigraph.cc.old} (100%) rename lib/utils/src/utils/graph/multidigraph/{i_multidigraph.cc => i_multidigraph.cc.old} (100%) rename lib/utils/src/utils/graph/multidigraph/{i_multidigraph_view.cc => i_multidigraph_view.cc.old} (100%) delete mode 100644 lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc.old rename lib/utils/src/utils/graph/multidigraph/{multi_di_edge_query.cc => multi_di_edge_query.cc.old} (100%) rename lib/utils/src/utils/graph/multidigraph/{multi_di_edge_query.dtg.cc => multi_di_edge_query.dtg.cc.old} (61%) rename lib/utils/src/utils/graph/multidigraph/{multidigraph.cc => multidigraph.cc.old} (100%) create mode 100644 lib/utils/src/utils/graph/node/node_source.cc rename lib/utils/src/utils/graph/open_multidigraph/{input_multi_di_edge.cc => input_multi_di_edge.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{input_multi_di_edge.dtg.cc => input_multi_di_edge.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{input_multi_di_edge_query.cc => input_multi_di_edge_query.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{input_multi_di_edge_query.dtg.cc => input_multi_di_edge_query.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{open_multi_di_edge.cc => open_multi_di_edge.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{open_multi_di_edge.dtg.cc => open_multi_di_edge.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{open_multi_di_edge_query.cc => open_multi_di_edge_query.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{open_multi_di_edge_query.dtg.cc => open_multi_di_edge_query.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{output_multi_di_edge.cc => output_multi_di_edge.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{output_multi_di_edge.dtg.cc => output_multi_di_edge.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{output_multi_di_edge_query.cc => output_multi_di_edge_query.cc.old} (100%) rename lib/utils/src/utils/graph/open_multidigraph/{output_multi_di_edge_query.dtg.cc => output_multi_di_edge_query.dtg.cc.old} (100%) create mode 100644 lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h create mode 100644 lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/src/utils/graph/serial_parallel/parallel.dtg.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/serial.dtg.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/serialparallel.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h create mode 100644 lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.h create mode 100644 lib/utils/src/utils/graph/serial_parallel/sink_settings.enum.toml create mode 100644 lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.h create mode 100644 lib/utils/src/utils/graph/serial_parallel/source_settings.enum.toml create mode 100644 lib/utils/src/utils/graph/serial_parallel/split_type.dtg.cc delete mode 100644 lib/utils/src/utils/graph/serialparallel.cc delete mode 100644 lib/utils/src/utils/graph/serialparallel_internal.h rename lib/utils/src/utils/graph/upward_open_multidigraph/{i_upward_open_multidigraph.cc => i_upward_open_multidigraph.cc.old} (100%) rename lib/utils/src/utils/graph/upward_open_multidigraph/{i_upward_open_multidigraph_view.cc => i_upward_open_multidigraph_view.cc.old} (100%) rename lib/utils/src/utils/graph/upward_open_multidigraph/{upward_open_multi_di_edge.cc => upward_open_multi_di_edge.cc.old} (100%) rename lib/utils/src/utils/graph/upward_open_multidigraph/{upward_open_multi_di_edge.dtg.cc => upward_open_multi_di_edge.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/upward_open_multidigraph/{upward_open_multi_di_edge_query.dtg.cc => upward_open_multi_di_edge_query.dtg.cc.old} (100%) rename lib/utils/src/utils/graph/{ => views}/views.cc (52%) diff --git a/flake.lock b/flake.lock index 3a9fffbdd1..1c5d00bbae 100644 --- a/flake.lock +++ b/flake.lock @@ -43,17 +43,16 @@ ] }, "locked": { + "dirtyRev": "5dc9d970f0fe67e65146b2ba1d7aa44d11324d48-dirty", + "dirtyShortRev": "5dc9d97-dirty", "lastModified": 1718643207, - "narHash": "sha256-VhPjZi4Zl4XgaagzqI0Z2bgFoJhF2SblwUq4eZR08DU=", - "owner": "lockshaw", - "repo": "proj", - "rev": "5dc9d970f0fe67e65146b2ba1d7aa44d11324d48", - "type": "github" + "narHash": "sha256-tZwDcHotcUrvQlyBxavhQRpMCFAiu100V/q9YeHJhdM=", + "type": "git", + "url": "file:///home/lockshaw/x/proj/proj" }, "original": { - "owner": "lockshaw", - "repo": "proj", - "type": "github" + "type": "git", + "url": "file:///home/lockshaw/x/proj/proj" } }, "root": { diff --git a/flake.nix b/flake.nix index 2dc005b113..e6cbb2d5d8 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,8 @@ flake-utils.url = "github:numtide/flake-utils"; proj-repo = { - url = "github:lockshaw/proj"; + # url = "github:lockshaw/proj"; + url = "git+file:///home/lockshaw/x/proj/proj"; inputs.nixpkgs.follows = "nixpkgs"; inputs.flake-utils.follows = "flake-utils"; }; diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc index a5a66b1a77..3ed7b70f63 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc index be4507677b..39353ed13e 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc @@ -9,11 +9,6 @@ #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" -#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" -#include "op-attrs/shard_parallel_dim.dtg.h" -#include #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc index 85fff2518c..cd80d2b99a 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc @@ -9,7 +9,6 @@ #include "op-attrs/ops/broadcast.dtg.h" -#include "utils/stack_vector.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc index 423fc2e046..661aca32a9 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -9,7 +9,6 @@ #include "op-attrs/ops/cast_attrs.dtg.h" -#include "op-attrs/datatype.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc index 198da728bf..d29ef9885c 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/combine_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc index 2bbd9ba50e..85eaca3c34 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/concat_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc index 90df5ae1a3..e3841b5093 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc index efb73dba1b..def7ccd81e 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include "op-attrs/shard_parallel_dim.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc index 696fe08a6f..3a0df0dfb4 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc @@ -9,9 +9,6 @@ #include "op-attrs/ops/conv_2d_attrs.dtg.h" -#include "op-attrs/activation.dtg.h" -#include "utils/json.h" -#include #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc index 568371c4fe..d567b56f13 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/element_binary_attrs.dtg.h" -#include "op-attrs/datatype.h" -#include "op-attrs/operator_type.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc index 4c246906eb..150edabb3d 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/element_unary_attrs.dtg.h" -#include "op-attrs/operator_type.h" -#include "utils/json.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc index 8f5778d794..f14e2966ea 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -9,9 +9,6 @@ #include "op-attrs/ops/embedding_attrs.dtg.h" -#include "op-attrs/aggregate_op.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include "utils/stack_vector.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc index a056d812ca..9386bd0cc3 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/gather_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc index 66db8e278a..8da5cd5e98 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc @@ -9,9 +9,6 @@ #include "op-attrs/ops/layer_norm_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc index 3099a6c7e4..1cd41522e9 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc @@ -9,10 +9,6 @@ #include "op-attrs/ops/linear_attrs.dtg.h" -#include "op-attrs/activation.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include "op-attrs/regularizer_attrs.dtg.h" -#include "utils/json.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc index 67a46ef5fb..f81f792cb5 100644 --- a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -9,7 +9,6 @@ #include "op-attrs/ops/parallel_attention_inputs.dtg.h" -#include "op-attrs/parallel_tensor_shape.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc index 057b030a96..4565cba760 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/pool_2d_attrs.dtg.h" -#include "op-attrs/activation.dtg.h" -#include "op-attrs/pool_op.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc index c365819440..650d36ba36 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc @@ -9,10 +9,6 @@ #include "op-attrs/ops/reduce_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/operator_type.dtg.h" -#include "utils/stack_vector.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc index 110e16c36a..ded4b6b050 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/repartition_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc index de18a192ff..948062fe76 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc @@ -9,7 +9,6 @@ #include "op-attrs/ops/reshape_attrs.dtg.h" -#include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc index 9e8079d666..be615f5cc2 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/reverse_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc index 1d4d396ef3..a1649de5ec 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/ops/softmax_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc index bdae47681e..32f1fe16e2 100644 --- a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc @@ -9,9 +9,6 @@ #include "op-attrs/ops/split_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc index 23e78beb7a..60f26d2c30 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc @@ -9,9 +9,6 @@ #include "op-attrs/ops/transpose_attrs.dtg.h" -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc index 3cad12b4fa..3e92b3a457 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -9,13 +9,7 @@ #include "op-attrs/parallel_tensor_dims.dtg.h" -#include "op-attrs/dim_ordered.h" -#include "op-attrs/replica_parallel_dim_set.dtg.h" -#include "op-attrs/shard_parallel_dim.dtg.h" -#include "utils/fmt/pair.h" -#include "utils/fmt/unordered_map.h" #include -#include namespace FlexFlow { ParallelTensorDims::ParallelTensorDims( diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc index 3a509de7f0..88089c828f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/datatype.h" -#include "op-attrs/parallel_tensor_dims.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc index ed45115c77..945672b44b 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc @@ -9,7 +9,6 @@ #include "op-attrs/replica_parallel_dim.dtg.h" -#include "op-attrs/replica_type.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc index 1d11006523..851d82a42e 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/replica_parallel_dim_set.dtg.h" -#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" -#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc index ab78d44805..cbde658f0b 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc @@ -9,7 +9,6 @@ #include "op-attrs/tensor_dims.dtg.h" -#include "op-attrs/dim_ordered.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc index 0c725dc443..7f448e3b53 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -9,8 +9,6 @@ #include "op-attrs/tensor_shape.dtg.h" -#include "op-attrs/datatype.dtg.h" -#include "op-attrs/tensor_dims.dtg.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc index 799cf55908..327cdc964a 100644 --- a/lib/pcg/src/pcg/computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -9,10 +9,6 @@ #include "pcg/computation_graph.dtg.h" -#include "pcg/dataflow_graph/dataflow_graph.h" -#include "pcg/layer_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" - namespace FlexFlow { ComputationGraph::ComputationGraph( ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, diff --git a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc index 1d00b4f32e..ac7e16538e 100644 --- a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc @@ -9,9 +9,6 @@ #include "pcg/computation_graph/layer_added_result.dtg.h" -#include "pcg/layer_guid_t.dtg.h" -#include "pcg/tensor_guid_t.dtg.h" -#include "utils/fmt/vector.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc b/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc index d4b926c0a6..6cb8f8fa83 100644 --- a/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc +++ b/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc @@ -9,10 +9,7 @@ #include "pcg/dataflow_graph/operator_added_result.dtg.h" -#include "utils/fmt/vector.h" -#include "utils/graph.h" #include -#include namespace FlexFlow { OperatorAddedResult::OperatorAddedResult( diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc index 41ad9e4e63..626cca4f95 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc @@ -9,12 +9,7 @@ #include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" #include -#include -#include namespace FlexFlow { V1MultiDiGraph::V1MultiDiGraph( diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc index 4c908ae2f1..d80e433b24 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc @@ -9,12 +9,7 @@ #include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" #include -#include -#include namespace FlexFlow { V1OperatorGraph::V1OperatorGraph( diff --git a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc index 6c1ae1dfac..2d685165bb 100644 --- a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc @@ -9,8 +9,6 @@ #include "pcg/initializers/constant_initializer_attrs.dtg.h" -#include "op-attrs/datatype.h" -#include "utils/json.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/layer_attrs.dtg.cc b/lib/pcg/src/pcg/layer_attrs.dtg.cc index 4497d849e6..27f2125b12 100644 --- a/lib/pcg/src/pcg/layer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/layer_attrs.dtg.cc @@ -9,10 +9,6 @@ #include "pcg/layer_attrs.dtg.h" -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "utils/json.h" -#include "utils/stack_string.h" -#include #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/layer_guid_t.dtg.cc b/lib/pcg/src/pcg/layer_guid_t.dtg.cc index 706de4e376..91343f704f 100644 --- a/lib/pcg/src/pcg/layer_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/layer_guid_t.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/layer_guid_t.dtg.h" -#include "utils/graph.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/machine_view.dtg.cc b/lib/pcg/src/pcg/machine_view.dtg.cc index de577fe409..081473c0a7 100644 --- a/lib/pcg/src/pcg/machine_view.dtg.cc +++ b/lib/pcg/src/pcg/machine_view.dtg.cc @@ -9,8 +9,6 @@ #include "pcg/machine_view.dtg.h" -#include "pcg/device_id_t.dtg.h" -#include "pcg/strided_rectangle.dtg.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc index 7d31197f9d..8417b77b5b 100644 --- a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/operator_graph/operator_graph_input.dtg.h" -#include "utils/graph.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc index 2b5a2abbcd..8891eae5c3 100644 --- a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/operator_graph/operator_graph_output.dtg.h" -#include "utils/graph.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc index cdc9130979..6a1fb33193 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc @@ -9,10 +9,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/dataflow_graph/dataflow_graph.h" -#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" - namespace FlexFlow { ParallelComputationGraph::ParallelComputationGraph( ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc index 7b2dbf8de1..bf7bb7ccc2 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_added_result.dtg.cc @@ -9,11 +9,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include "utils/fmt/vector.h" #include -#include namespace FlexFlow { ParallelLayerAddedResult::ParallelLayerAddedResult( diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc index 5a982b13ab..9f2585f4e3 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_attrs.dtg.cc @@ -9,9 +9,6 @@ #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc index df575ebc98..794a2078e7 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" -#include "utils/graph.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc index 88f7ed4d3c..7ac977bc94 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.cc @@ -9,11 +9,6 @@ #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/param_sync.dtg.h" -#include "pcg/create_grad.dtg.h" -#include "pcg/initializer_attrs.dtg.h" -#include #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc index 38c2970225..bc10f450c2 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include "utils/graph/multidiedge.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc index d50c5861ea..b933957e71 100644 --- a/lib/pcg/src/pcg/strided_rectangle.dtg.cc +++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc @@ -9,8 +9,6 @@ #include "pcg/strided_rectangle.dtg.h" -#include "op-attrs/dim_ordered.h" -#include "pcg/strided_rectangle_side.dtg.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc index e2533f7a21..ea278b950f 100644 --- a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc +++ b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/strided_rectangle_side.dtg.h" -#include "pcg/num_points_t.dtg.h" #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/tensor_attrs.dtg.cc b/lib/pcg/src/pcg/tensor_attrs.dtg.cc index e75fe506f6..2b92c68d35 100644 --- a/lib/pcg/src/pcg/tensor_attrs.dtg.cc +++ b/lib/pcg/src/pcg/tensor_attrs.dtg.cc @@ -9,10 +9,6 @@ #include "pcg/tensor_attrs.dtg.h" -#include "op-attrs/param_sync.dtg.h" -#include "op-attrs/tensor_shape.dtg.h" -#include "pcg/initializer_attrs.dtg.h" -#include #include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc index c8fbb7299b..096c9b4374 100644 --- a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc @@ -9,7 +9,6 @@ #include "pcg/tensor_guid_t.dtg.h" -#include "utils/graph.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc index 2956dad2c4..0e9ab62c69 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc @@ -9,9 +9,6 @@ #include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" -#include "substitutions/constraint_type.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc index 67e3761515..41fd4f7868 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc @@ -9,7 +9,6 @@ #include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc index 2879aca911..243075f250 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc @@ -9,7 +9,6 @@ #include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc index 7aca1e75fc..8caa7bd720 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc @@ -9,10 +9,7 @@ #include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" -#include "utils/fmt.h" #include -#include namespace FlexFlow { OperatorAttributePattern::OperatorAttributePattern( diff --git a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc index c0dc667822..47b5fcdcda 100644 --- a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc @@ -9,7 +9,6 @@ #include "substitutions/output_graph/attr_constant.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc index 7d07bf9218..3e945beded 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc @@ -9,9 +9,6 @@ #include "substitutions/output_graph/output_graph_expr.dtg.h" -#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { OutputGraphExpr::OutputGraphExpr( ::FlexFlow::NodeLabelledOpenMultiDiGraph< diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc index 2864ccbfac..952ed375dc 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/output_graph/output_operator_attr_access.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" -#include "utils/graph.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc index 98183c9a14..04901657e0 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc @@ -9,10 +9,7 @@ #include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" -#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" #include -#include namespace FlexFlow { OutputOperatorAttrsAssignment::OutputOperatorAttrsAssignment( diff --git a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc index 7133ab42a7..9056a5ebdd 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc @@ -9,10 +9,6 @@ #include "substitutions/pcg_pattern.dtg.h" -#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { PCGPattern::PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< ::FlexFlow::OperatorAttributePattern, diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc index 83baef2cfc..eabee4a906 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc @@ -9,10 +9,6 @@ #include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "pcg/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { SubParallelComputationGraph::SubParallelComputationGraph( ::FlexFlow::OutputLabelledOpenMultiDiGraph< diff --git a/lib/substitutions/src/substitutions/substitution.dtg.cc b/lib/substitutions/src/substitutions/substitution.dtg.cc index 67d39d6ff7..81c8a572df 100644 --- a/lib/substitutions/src/substitutions/substitution.dtg.cc +++ b/lib/substitutions/src/substitutions/substitution.dtg.cc @@ -9,9 +9,6 @@ #include "substitutions/substitution.dtg.h" -#include "substitutions/output_graph/output_graph_expr.dtg.h" -#include "substitutions/pcg_pattern.dtg.h" - namespace FlexFlow { Substitution::Substitution( ::FlexFlow::PCGPattern const &pcg_pattern, diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc index 17147b3681..3e808a405b 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc @@ -9,9 +9,6 @@ #include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" -#include "substitutions/constraint_type.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc index c7e81718ed..107ec7d531 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc @@ -9,7 +9,6 @@ #include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc index 52a61a8a87..9e663eed39 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc @@ -9,7 +9,6 @@ #include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc index 8f96fd49b8..da2fbd7cfc 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc @@ -9,10 +9,7 @@ #include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" -#include "utils/hash-utils.h" #include -#include namespace FlexFlow { TensorAttributePattern::TensorAttributePattern( diff --git a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc index 401c738d88..c4dc578c0a 100644 --- a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/closed_pattern_edge.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { ClosedPatternEdge::ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge) : raw_edge(raw_edge) {} diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc index 65c87db0e4..4983d4b91c 100644 --- a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { DownwardOpenPatternEdge::DownwardOpenPatternEdge( ::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge) diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc index 4da15179da..30e7b78725 100644 --- a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc @@ -9,10 +9,6 @@ #include "substitutions/unlabelled/edge_splits.dtg.h" -#include "utils/bidict.h" -#include "utils/graph.h" -#include - namespace FlexFlow { UnlabelledPatternEdgeSplits::UnlabelledPatternEdgeSplits( ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc index e46becf4be..b74579dadd 100644 --- a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/input_pattern_edge.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { InputPatternEdge::InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge) : raw_edge(raw_edge) {} diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc index 613159ad83..650bc0ec68 100644 --- a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc @@ -9,11 +9,6 @@ #include "substitutions/unlabelled/match_additional_criterion.dtg.h" -#include "substitutions/unlabelled/pattern_edge.dtg.h" -#include "substitutions/unlabelled/pattern_node.dtg.h" -#include "utils/graph.h" -#include - namespace FlexFlow { MatchAdditionalCriterion::MatchAdditionalCriterion( std::function const diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc index 152115d52a..b1be52cffc 100644 --- a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/output_pattern_edge.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { OutputPatternEdge::OutputPatternEdge( ::FlexFlow::OutputMultiDiEdge const &raw_edge) diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc index a19e5bb6d1..51ea760af3 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/pattern_edge.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { PatternEdge::PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge) : raw_edge(raw_edge) {} diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc index b2cd557c06..9eb9e2bfbc 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/pattern_node.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { PatternNode::PatternNode(::FlexFlow::Node const &raw_node) : raw_node(raw_node) {} diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc index d678a1edfe..b0e9795d93 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc @@ -9,10 +9,7 @@ #include "substitutions/unlabelled/pattern_split.dtg.h" -#include "substitutions/unlabelled/pattern_node.dtg.h" -#include "utils/graph.h" #include -#include namespace FlexFlow { PatternSplit::PatternSplit( diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc index 019209ee86..0bebd8dd91 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { UnlabelledGraphPattern::UnlabelledGraphPattern( ::FlexFlow::OpenMultiDiGraphView const &raw_graph) diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc index 1fe34ed778..fd16559b09 100644 --- a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc @@ -9,8 +9,6 @@ #include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" -#include "utils/graph.h" - namespace FlexFlow { UpwardOpenPatternEdge::UpwardOpenPatternEdge( ::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge) diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h new file mode 100644 index 0000000000..8ec4d4b210 --- /dev/null +++ b/lib/utils/include/utils/fmt/variant.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H + +#include +#include + +namespace fmt { + +template +struct formatter, Char> + /* std::enable_if_t>::value>> */ + : formatter<::std::string> { + template + auto format(std::variant const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result = std::visit([&](auto &&x) { return fmt::to_string(x); }, m); + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::variant const &v) { + return s << fmt::to_string(v); +} + +} // namespace FlexFlow + + +#endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 52a1f71d31..6f64c3459d 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -2,12 +2,11 @@ #define _FLEXFLOW_UTILS_GRAPH_ALGORITHMS_H #include "utils/graph/digraph/digraph.h" -#include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/node/graph.h" #include "utils/graph/undirected/undirected_graph.h" -#include "utils/graph/open_multidigraph/open_multidigraph.h" -#include "utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h" -#include "utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h" +// #include "utils/graph/open_multidigraph/open_multidigraph.h" +// #include "utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h" +// #include "utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h" #include "utils/dot_file.h" namespace FlexFlow { @@ -15,48 +14,48 @@ namespace FlexFlow { std::vector add_nodes(Graph &, int); std::vector add_nodes(UndirectedGraph &, int); std::vector add_nodes(DiGraph &, int); -std::vector add_nodes(MultiDiGraph &, int); -std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); +// std::vector add_nodes(MultiDiGraph &, int); +// std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); std::unordered_set get_nodes(GraphView const &); -std::unordered_set get_nodes(OpenMultiDiEdge const &); +// std::unordered_set get_nodes(OpenMultiDiEdge const &); std::unordered_set query_nodes(GraphView const &, std::unordered_set const &); -void remove_node(MultiDiGraph &, Node const &); +// void remove_node(MultiDiGraph &, Node const &); void remove_node(DiGraph &, Node const &); void remove_node(UndirectedGraph &, Node const &); -void remove_node_if_unused(MultiDiGraph &, Node const &); +// void remove_node_if_unused(MultiDiGraph &, Node const &); void remove_node_if_unused(DiGraph &, Node const &); void remove_node_if_unused(UndirectedGraph &, Node const &); -void contract_node_inplace(MultiDiGraph &, Node const &from, Node const &into); +// void contract_node_inplace(MultiDiGraph &, Node const &from, Node const &into); void contract_node_inplace(DiGraph &, Node const &from, Node const &into); void contract_node_inplace(UndirectedGraph &, Node const &from, Node const &into); -void contract_out_node_inplace(MultiDiGraph &, Node const &); +// void contract_out_node_inplace(MultiDiGraph &, Node const &); void contract_out_node_inplace(DiGraph &, Node const &); void contract_out_node_inplace(UndirectedGraph &, Node const &); -MultiDiGraphView contract_out_node(MultiDiGraphView const &, Node const &); +// MultiDiGraphView contract_out_node(MultiDiGraphView const &, Node const &); DiGraphView contract_out_node(DiGraphView const &, Node const &); UndirectedGraphView contract_out_node(UndirectedGraphView const &, Node const &); -MultiDiGraphView - contract_node(MultiDiGraphView const &, Node const &from, Node const &into); +// MultiDiGraphView +// contract_node(MultiDiGraphView const &, Node const &from, Node const &into); DiGraphView contract_node(DiGraphView const &, Node const &from, Node const &into); UndirectedGraphView contract_node(UndirectedGraphView const &, Node const &from, Node const &into); -MultiDiGraphView apply_contraction(MultiDiGraphView const &, - std::unordered_map const &); +// MultiDiGraphView apply_contraction(MultiDiGraphView const &, +// std::unordered_map const &); DiGraphView apply_contraction(DiGraphView const &, std::unordered_map const &); UndirectedGraphView apply_contraction(UndirectedGraphView const &, @@ -65,69 +64,69 @@ UndirectedGraphView apply_contraction(UndirectedGraphView const &, std::size_t num_nodes(GraphView const &); bool empty(GraphView const &); -void add_edges(MultiDiGraph &, std::vector const &); +// void add_edges(MultiDiGraph &, std::vector const &); void add_edges(DiGraph &, std::vector const &); void add_edges(UndirectedGraph &, std::vector const &); bool contains_node(GraphView const &, Node const &); -bool contains_edge(MultiDiGraphView const &, MultiDiEdge const &); +// bool contains_edge(MultiDiGraphView const &, MultiDiEdge const &); bool contains_edge(DiGraphView const &, DirectedEdge const &); bool contains_edge(UndirectedGraphView const &, UndirectedEdge const &); -void remove_edges(MultiDiGraph &, std::unordered_set const &); +// void remove_edges(MultiDiGraph &, std::unordered_set const &); void remove_edges(DiGraph &, std::unordered_set const &); void remove_edges(UndirectedGraph &, std::vector const &); std::unordered_set get_endpoints(UndirectedEdge const &); -std::unordered_set get_edges(MultiDiGraphView const &); +// std::unordered_set get_edges(MultiDiGraphView const &); std::unordered_set get_edges(DiGraphView const &); std::unordered_set get_edges(UndirectedGraphView const &); -std::unordered_set - get_edges(UpwardOpenMultiDiGraphView const &); -std::unordered_set - get_edges(DownwardOpenMultiDiGraphView const &); -std::unordered_set get_edges(OpenMultiDiGraphView const &); +// std::unordered_set +// get_edges(UpwardOpenMultiDiGraphView const &); +// std::unordered_set +// get_edges(DownwardOpenMultiDiGraphView const &); +// std::unordered_set get_edges(OpenMultiDiGraphView const &); std::unordered_set get_node_edges(UndirectedGraphView const &, Node const &); -std::unordered_set - get_open_outputs(OpenMultiDiGraphView const &); -std::unordered_set - get_open_inputs(OpenMultiDiGraphView const &); +// std::unordered_set +// get_open_outputs(OpenMultiDiGraphView const &); +// std::unordered_set +// get_open_inputs(OpenMultiDiGraphView const &); -std::unordered_set get_incoming_edges(MultiDiGraphView const &, - Node const &); +// std::unordered_set get_incoming_edges(MultiDiGraphView const &, +// Node const &); std::unordered_set get_incoming_edges(DiGraphView const &, Node const &); -std::unordered_set - get_incoming_edges(UpwardOpenMultiDiGraphView const &, Node const &); -std::unordered_set - get_incoming_edges(DownwardOpenMultiDiGraphView const &, Node const &); -std::unordered_set - get_incoming_edges(OpenMultiDiGraphView const &, Node const &); - -std::unordered_set get_incoming_edges(MultiDiGraphView const &, - std::unordered_set); +// std::unordered_set +// get_incoming_edges(UpwardOpenMultiDiGraphView const &, Node const &); +// std::unordered_set +// get_incoming_edges(DownwardOpenMultiDiGraphView const &, Node const &); +// std::unordered_set +// get_incoming_edges(OpenMultiDiGraphView const &, Node const &); + +// std::unordered_set get_incoming_edges(MultiDiGraphView const &, +// std::unordered_set); std::unordered_set get_incoming_edges(DiGraphView const &, std::unordered_set const &); -std::unordered_set get_outgoing_edges(MultiDiGraphView const &, - Node const &); +// std::unordered_set get_outgoing_edges(MultiDiGraphView const &, +// Node const &); std::unordered_set get_outgoing_edges(DiGraphView const &, Node const &); -std::unordered_set - get_outgoing_edges(UpwardOpenMultiDiGraphView const &, Node const &); -std::unordered_set - get_outgoing_edges(DownwardOpenMultiDiGraphView const &, Node const &); -std::unordered_set - get_outgoing_edges(OpenMultiDiGraphView const &, Node const &); - -std::unordered_set - get_outgoing_edges(MultiDiGraphView const &, - std::unordered_set const &); +// std::unordered_set +// get_outgoing_edges(UpwardOpenMultiDiGraphView const &, Node const &); +// std::unordered_set +// get_outgoing_edges(DownwardOpenMultiDiGraphView const &, Node const &); +// std::unordered_set +// get_outgoing_edges(OpenMultiDiGraphView const &, Node const &); + +// std::unordered_set +// get_outgoing_edges(MultiDiGraphView const &, +// std::unordered_set const &); std::unordered_set get_outgoing_edges(DiGraphView const &, std::unordered_set const &); @@ -141,15 +140,15 @@ std::unordered_set get_predecessors(DiGraphView const &, Node const &); std::unordered_map> get_predecessors(DiGraphView const &, std::unordered_set const &); -Node get_src_node(MultiDiEdge const &); -Node get_dst_node(MultiDiEdge const &); -Node get_dst_node(InputMultiDiEdge const &); -Node get_src_node(OutputMultiDiEdge const &); +// Node get_src_node(MultiDiEdge const &); +// Node get_dst_node(MultiDiEdge const &); +// Node get_dst_node(InputMultiDiEdge const &); +// Node get_src_node(OutputMultiDiEdge const &); std::unordered_set get_neighbors(UndirectedGraphView const &, Node const &); std::unordered_set get_neighbors(DiGraphView const &, Node const &); -std::unordered_set get_neighbors(MultiDiGraphView const &, Node const &); +// std::unordered_set get_neighbors(MultiDiGraphView const &, Node const &); // return the set of nodes without incoming edges std::unordered_set get_sources(DiGraphView const &); @@ -157,14 +156,14 @@ std::unordered_set get_sources(DiGraphView const &); // return the set of nodes without outgoing edges std::unordered_set get_sinks(DiGraphView const &); -std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); -std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); -std::unordered_set get_open_sources(OpenMultiDiGraphView const &g); -std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g); +// std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); +// std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); +// std::unordered_set get_open_sources(OpenMultiDiGraphView const &g); +// std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g); -bool is_acyclic(MultiDiGraphView const &, std::unordered_set const &); +// bool is_acyclic(MultiDiGraphView const &, std::unordered_set const &); std::optional is_acyclic(DiGraphView const &); -std::optional is_acyclic(MultiDiGraphView const &); +// std::optional is_acyclic(MultiDiGraphView const &); std::unordered_map> get_dominators(DiGraphView const &); @@ -179,8 +178,8 @@ std::unordered_map> std::unordered_map> get_imm_post_dominators(DiGraphView const &); std::optional get_imm_post_dominator(DiGraphView const &, Node const &); -std::optional get_imm_post_dominator(MultiDiGraphView const &, - Node const &); +// std::optional get_imm_post_dominator(MultiDiGraphView const &, +// Node const &); std::optional get_imm_post_dominator(DiGraphView const &, std::unordered_set const &); @@ -197,11 +196,11 @@ std::vector get_topological_ordering(DiGraphView const &); std::vector get_unchecked_topological_ordering(DiGraphView const &); std::vector get_edge_topological_ordering(DiGraphView const &); -std::vector - get_edge_topological_ordering(MultiDiGraphView const &); +// std::vector +// get_edge_topological_ordering(MultiDiGraphView const &); -std::unordered_set> - get_weakly_connected_components(MultiDiGraphView const &); +// std::unordered_set> +// get_weakly_connected_components(MultiDiGraphView const &); std::unordered_set> get_weakly_connected_components(DiGraphView const &); std::unordered_set> @@ -213,35 +212,35 @@ std::unordered_set using GraphSplit = std::pair, std::unordered_set>; -std::pair split_edge(MultiDiEdge const &e); -MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &); +// std::pair split_edge(MultiDiEdge const &e); +// MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &); -std::unordered_set get_cut_set(MultiDiGraphView const &, - GraphSplit const &); +// std::unordered_set get_cut_set(MultiDiGraphView const &, +// GraphSplit const &); -std::unordered_set get_cut_set(MultiDiGraphView const &, - std::unordered_set const &); +// std::unordered_set get_cut_set(MultiDiGraphView const &, +// std::unordered_set const &); -bidict> - get_edge_splits(MultiDiGraphView const &, GraphSplit const &); +// bidict> +// get_edge_splits(MultiDiGraphView const &, GraphSplit const &); UndirectedGraphView get_subgraph(UndirectedGraphView const &, std::unordered_set const &); DiGraphView get_subgraph(DiGraphView const &, std::unordered_set const &); -MultiDiGraphView get_subgraph(MultiDiGraphView const &, - std::unordered_set const &); +// MultiDiGraphView get_subgraph(MultiDiGraphView const &, +// std::unordered_set const &); -template -OpenMultiDiGraphView get_subgraph(OpenMultiDiGraphView const &g, - std::unordered_set const &nodes) { - return OpenMultiDiGraphView::create(g, nodes); -} +// template +// OpenMultiDiGraphView get_subgraph(OpenMultiDiGraphView const &g, +// std::unordered_set const &nodes) { +// return OpenMultiDiGraphView::create(g, nodes); +// } std::unordered_map calculate_topo_rank(DiGraphView const &); Node get_node_with_greatest_topo_rank(std::unordered_set const &, DiGraphView const &); -MultiDiGraphView join(MultiDiGraphView const &lhs, MultiDiGraphView const &rhs); +// MultiDiGraphView join(MultiDiGraphView const &lhs, MultiDiGraphView const &rhs); DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs); UndirectedGraphView join(UndirectedGraphView const &lhs, UndirectedGraphView const &rhs); @@ -252,9 +251,9 @@ DiGraphView with_added_edges(DiGraphView const &, std::unordered_set const &); UndirectedGraphView as_undirected(DiGraphView const &); -MultiDiGraphView as_multidigraph(DiGraphView const &); +// MultiDiGraphView as_multidigraph(DiGraphView const &); DiGraphView as_digraph(UndirectedGraphView const &); -OpenMultiDiGraphView as_openmultidigraph(MultiDiGraphView const &); +// OpenMultiDiGraphView as_openmultidigraph(MultiDiGraphView const &); void export_as_dot( DotFile &, diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h index 49fdb24992..aa4f20e575 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml /* proj-data { - "generated_from": "684726a7add4aa912e194335fcfe91ab" + "generated_from": "111e640382a80b659bc33dd86a416ded" } */ @@ -11,7 +11,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_DTG_H #include "fmt/format.h" -#include "utils/graph/node.dtg.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/node/node.dtg.h" #include "utils/graph/query_set.h" #include #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h new file mode 100644 index 0000000000..d6d44ce49a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_H + +#include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" + +namespace FlexFlow { + +DataflowEdgeQuery dataflow_edge_query_all(); +DataflowEdgeQuery dataflow_edge_query_none(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml index c941bbf985..6957a87863 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml @@ -9,7 +9,8 @@ features = [ includes = [ "utils/graph/query_set.h", - "utils/graph/node.dtg.h", + "utils/graph/node/node.dtg.h", + "utils/fmt/unordered_set.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index f5f0e669b4..d79983d8ec 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -11,9 +11,26 @@ struct DataflowGraph : virtual DataflowGraphView { public: NodeAddedResult add_node(std::vector const &inputs, int num_outputs); + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(DataflowEdgeQuery const &) const; + std::unordered_set query_outputs(DataflowOutputQuery const &) const; + + template + static typename std::enable_if::value, + DataflowGraph>::type + create() { + return DataflowGraph(make_cow_ptr()); + } + +protected: + using DataflowGraphView::DataflowGraphView; + private: - IDataflowGraph const &get_interface() const; IDataflowGraph &get_interface(); + IDataflowGraph const &get_interface() const; + + friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h index bdbf204882..dd07355e48 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_GRAPH_VIEW_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_GRAPH_VIEW_H -#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/digraph/digraph_view.h" #include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" #include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" namespace FlexFlow { -struct DataflowGraphView : virtual MultiDiGraphView { +struct DataflowGraphView : virtual DiGraphView { DataflowGraphView(DataflowGraphView const &) = default; DataflowGraphView &operator=(DataflowGraphView const &) = default; diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h index e98994ecf1..90ef014bde 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml /* proj-data { - "generated_from": "9fc7657f7fcc71fdad9e6a5040771ad7" + "generated_from": "d43532deb325bcf8a502efbe90cd287b" } */ @@ -11,6 +11,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_INPUT_DTG_H #include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" #include #include #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml index 19da01ab9f..f322fa63fe 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml @@ -7,6 +7,10 @@ features = [ "fmt", ] +includes = [ + "utils/graph/node/node.dtg.h", +] + [[fields]] name = "node" type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h index 4938220290..3e821d7dd8 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml /* proj-data { - "generated_from": "b704f2549a69ee6bfc1c5e28df421f9c" + "generated_from": "3f4ea6635782f141cc593291132c4064" } */ @@ -11,7 +11,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_DTG_H #include "fmt/format.h" -#include "utils/graph/node.dtg.h" +#include "utils/graph/node/node.dtg.h" #include #include #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml index 6f2ce25f2b..f3ccebe046 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml @@ -8,7 +8,7 @@ features = [ ] includes = [ - "utils/graph/node.dtg.h", + "utils/graph/node/node.dtg.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h index 5a122c6d51..011d62bd7a 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml /* proj-data { - "generated_from": "6f662c3c4d285a4fd3c60713e6fc67fa" + "generated_from": "de957a7524bf0423dcfb68f70b2e6815" } */ @@ -11,7 +11,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_DTG_H #include "fmt/format.h" -#include "utils/graph/node.dtg.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/node/node.dtg.h" #include "utils/graph/query_set.h" #include #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml index a61edbcdb0..0701855ba6 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml @@ -9,7 +9,8 @@ features = [ includes = [ "utils/graph/query_set.h", - "utils/graph/node.dtg.h", + "utils/graph/node/node.dtg.h", + "utils/fmt/unordered_set.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h index 5a1d29f9dc..9ae7ada0a6 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph_view.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_VIEW_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_VIEW_H -#include "utils/graph/multidigraph/i_multidigraph_view.h" +#include "utils/graph/digraph/i_digraph_view.h" #include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" #include "utils/graph/dataflow_graph/dataflow_output.dtg.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" @@ -9,11 +9,11 @@ namespace FlexFlow { -struct IDataflowGraphView : virtual public IMultiDiGraphView { +struct IDataflowGraphView : virtual public IDiGraphView { virtual std::unordered_set query_edges(DataflowEdgeQuery const &) const = 0; virtual std::unordered_set query_outputs(DataflowOutputQuery const &) const = 0; - std::unordered_set query_edges(MultiDiEdgeQuery const &) const override final; + std::unordered_set query_edges(DirectedEdgeQuery const &) const override final; virtual ~IDataflowGraphView() = default; }; diff --git a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h index 2a8159576c..edaac13f7e 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h +++ b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml /* proj-data { - "generated_from": "4536bb54376e2e221e0ff29347e81662" + "generated_from": "6e5dc11e71c895683bd5bb9c30c1e42d" } */ @@ -13,8 +13,7 @@ #include "fmt/format.h" #include "utils/fmt/vector.h" #include "utils/graph/dataflow_graph/dataflow_output.dtg.h" -#include "utils/graph/multidigraph/multi_di_edge.dtg.h" -#include "utils/graph/node.dtg.h" +#include "utils/graph/node/node.dtg.h" #include #include #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml index 515541eb71..df0d601530 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml @@ -9,8 +9,7 @@ features = [ includes = [ "", - "utils/graph/node.dtg.h", - "utils/graph/multidigraph/multi_di_edge.dtg.h", + "utils/graph/node/node.dtg.h", "utils/fmt/vector.h", "utils/graph/dataflow_graph/dataflow_output.dtg.h", ] diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h new file mode 100644 index 0000000000..ed615fd7f8 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_H + +#include "utils/graph/digraph/digraph.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(DiGraph const &); +std::unordered_set get_edges(DirectedEdge const &); +std::unordered_set get_incoming_edges(DiGraph const &, Node const &); +std::unordered_set get_outgoing_edges(DiGraph const &, Node const &); +std::unordered_set get_sources(DiGraph const &); +std::vector get_topological_ordering(DiGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/di_output.dtg.h b/lib/utils/include/utils/graph/digraph/di_output.dtg.h index e88d929a5d..1d785cd64b 100644 --- a/lib/utils/include/utils/graph/digraph/di_output.dtg.h +++ b/lib/utils/include/utils/graph/digraph/di_output.dtg.h @@ -3,14 +3,14 @@ // lib/utils/include/utils/graph/digraph/di_output.struct.toml /* proj-data { - "generated_from": "a8f3fc2ad9e00f3c29a6dcd4658199ba" + "generated_from": "61e6ee4a13c7608bf6df0a549b94b2bc" } */ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_OUTPUT_DTG_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_DI_OUTPUT_DTG_H -#include "utils/graph/node.dtg.h" +#include "utils/graph/node/node.dtg.h" #include #include diff --git a/lib/utils/include/utils/graph/digraph/di_output.struct.toml b/lib/utils/include/utils/graph/digraph/di_output.struct.toml index f678af132a..27a71743f6 100644 --- a/lib/utils/include/utils/graph/digraph/di_output.struct.toml +++ b/lib/utils/include/utils/graph/digraph/di_output.struct.toml @@ -7,7 +7,7 @@ features = [ ] includes = [ - "utils/graph/node.dtg.h", + "utils/graph/node/node.dtg.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h b/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h index 716a3c5fc6..88db834947 100644 --- a/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h +++ b/lib/utils/include/utils/graph/digraph/directed_edge_query.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml /* proj-data { - "generated_from": "294ae0103df2a3c388a2ce140c271f4e" + "generated_from": "4d7f3398fb178b272a4230d2db24c0d5" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "utils/graph/node/node.dtg.h" #include "utils/graph/query_set.h" +#include #include #include @@ -34,6 +35,13 @@ struct DirectedEdgeQuery { }; } // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::DirectedEdgeQuery> { + size_t operator()(::FlexFlow::DirectedEdgeQuery const &) const; +}; +} // namespace std + namespace FlexFlow { std::string format_as(DirectedEdgeQuery const &); std::ostream &operator<<(std::ostream &, DirectedEdgeQuery const &); diff --git a/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml b/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml index 2ede557642..3447cdb4b6 100644 --- a/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml +++ b/lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml @@ -3,6 +3,7 @@ name = "DirectedEdgeQuery" features = [ "eq", "ord", + "hash", "fmt", ] diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.h rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge.dtg.h.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml b/lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge.variant.toml.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.variant.toml rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge.variant.toml.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge_query.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.h rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge_query.dtg.h.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge_query.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.struct.toml rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multi_di_edge_query.struct.toml.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph.h b/lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph.h rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h b/lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/downward_open_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h b/lib/utils/include/utils/graph/downward_open_multidigraph.old/i_downward_open_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.h rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/i_downward_open_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h b/lib/utils/include/utils/graph/downward_open_multidigraph.old/i_downward_open_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.h rename to lib/utils/include/utils/graph/downward_open_multidigraph.old/i_downward_open_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/labelled/algorithms.h b/lib/utils/include/utils/graph/labelled/algorithms.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/algorithms.h rename to lib/utils/include/utils/graph/labelled/algorithms.h.old diff --git a/lib/utils/include/utils/graph/labelled/label_interfaces.h b/lib/utils/include/utils/graph/labelled/label_interfaces.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/label_interfaces.h rename to lib/utils/include/utils/graph/labelled/label_interfaces.h.old diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/node_labelled.h rename to lib/utils/include/utils/graph/labelled/node_labelled.h.old diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h rename to lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h.old diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/node_labelled_open.h rename to lib/utils/include/utils/graph/labelled/node_labelled_open.h.old diff --git a/lib/utils/include/utils/graph/labelled/open_algorithms.h b/lib/utils/include/utils/graph/labelled/open_algorithms.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/open_algorithms.h rename to lib/utils/include/utils/graph/labelled/open_algorithms.h.old diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/open_views.h rename to lib/utils/include/utils/graph/labelled/open_views.h.old diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h deleted file mode 100644 index ac5648c2e1..0000000000 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ /dev/null @@ -1,142 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H - -#include "node_labelled.h" -#include "output_labelled_interfaces.h" - -namespace FlexFlow { - -template -struct OutputLabelledMultiDiGraphView - : virtual public NodeLabelledMultiDiGraphView { -private: - using Interface = IOutputLabelledMultiDiGraphView; - -public: - OutputLabelledMultiDiGraphView(OutputLabelledMultiDiGraphView const &) = - default; - OutputLabelledMultiDiGraphView & - operator=(OutputLabelledMultiDiGraphView const &) = default; - - NodeLabel const &at(Node const &n) const { - return this->get_ptr().at(n); - } - - OutputLabel const &at(MultiDiOutput const &o) const { - return this->get_ptr().at(o); - } - - std::unordered_set query_nodes(NodeQuery const &q) const { - return this->get_ptr().query_nodes(q); - } - - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->get_ptr().query_edges(q); - } - - template - static typename std::enable_if::value, - OutputLabelledMultiDiGraphView>::type - create(Args &&...args) { - return OutputLabelledMultiDiGraphView( - make_cow_ptr(std::forward(args)...)); - } - -protected: - using NodeLabelledMultiDiGraphView::NodeLabelledMultiDiGraphView; - -private: - Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); - } -}; - -template -struct OutputLabelledMultiDiGraph - : virtual OutputLabelledMultiDiGraphView { -private: - using Interface = IOutputLabelledMultiDiGraph; - -public: - OutputLabelledMultiDiGraph(OutputLabelledMultiDiGraph const &other) = default; - OutputLabelledMultiDiGraph & - operator=(OutputLabelledMultiDiGraph const &other) = default; - - Node add_node(NodeLabel const &l) { - return this->get_ptr().add_node(l); - } - - NodePort add_node_port() { - return this->get_ptr().add_node_port(); - } - - NodeLabel &at(Node const &n) { - return this->get_ptr().at(n); - } - - NodeLabel const &at(Node const &n) const { - return this->get_ptr().at(n); - } - - void add_output(MultiDiOutput const &o, OutputLabel const &l) { - this->get_ptr().add_output(o, l); - }; - - void add_edge(MultiDiOutput const &o, MultiDiInput const &i) { - this->get_ptr().add_edge(o, i); - }; - - void add_edge(MultiDiEdge const &e) { - this->get_ptr().add_edge(e); - } - - OutputLabel &at(MultiDiOutput const &o) { - return this->get_ptr().at(o); - } - - OutputLabel const &at(MultiDiOutput const &o) const { - return this->get_ptr().at(o); - } - - std::unordered_set query_nodes(NodeQuery const &q) const { - return this->get_ptr().query_nodes(q); - } - - std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->get_ptr().query_edges(q); - } - - template - static typename std::enable_if::value, - OutputLabelledMultiDiGraph>::type - create() { - return OutputLabelledMultiDiGraph(make_cow_ptr()); - } - -private: - OutputLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} - -private: - Interface &get_ptr() { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); - } - - Interface const &get_ptr() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get()); - } -}; - -template >:: - value && - !std::is_same::value), - bool>::type = true> -NodeLabel const &at(T const &g, Node const &n) { - return g.at(n); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h deleted file mode 100644 index 1680fc4fb5..0000000000 --- a/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_GRAPH_INTERFACES_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_GRAPH_INTERFACES_H - -#include "node_labelled_open.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -template -struct IOutputLabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - - virtual OutputLabel const &at(MultiDiOutput const &) const = 0; - - using INodeLabelledMultiDiGraphView::at; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); - -template -struct IOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraphView, - public INodeLabelledMultiDiGraph { -public: - virtual IOutputLabelledMultiDiGraph *clone() const = 0; - - virtual void add_output(MultiDiOutput const &output, - OutputLabel const &label) = 0; - virtual NodePort add_node_port() = 0; - - virtual NodeLabel &at(Node const &) = 0; - virtual NodeLabel const &at(Node const &) const = 0; - virtual OutputLabel &at(MultiDiOutput const &) = 0; - virtual OutputLabel const &at(MultiDiOutput const &) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/output_labelled_open.h rename to lib/utils/include/utils/graph/labelled/output_labelled_open.h.old diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h rename to lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h.old diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/standard_labelled.h rename to lib/utils/include/utils/graph/labelled/standard_labelled.h.old diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/standard_labelled_interfaces.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/standard_labelled_interfaces.h rename to lib/utils/include/utils/graph/labelled/standard_labelled_interfaces.h.old diff --git a/lib/utils/include/utils/graph/labelled/unordered_label.h b/lib/utils/include/utils/graph/labelled/unordered_label.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/unordered_label.h rename to lib/utils/include/utils/graph/labelled/unordered_label.h.old diff --git a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h rename to lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h.old diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h.old similarity index 100% rename from lib/utils/include/utils/graph/labelled/views.h rename to lib/utils/include/utils/graph/labelled/views.h.old diff --git a/lib/utils/include/utils/graph/labelled_graphs.h b/lib/utils/include/utils/graph/labelled_graphs.old similarity index 100% rename from lib/utils/include/utils/graph/labelled_graphs.h rename to lib/utils/include/utils/graph/labelled_graphs.old diff --git a/lib/utils/include/utils/graph/multidigraph/adjacency_multidigraph.h b/lib/utils/include/utils/graph/multidigraph.old/adjacency_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/multidigraph/adjacency_multidigraph.h rename to lib/utils/include/utils/graph/multidigraph.old/adjacency_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/multidigraph/i_multidigraph.h b/lib/utils/include/utils/graph/multidigraph.old/i_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/multidigraph/i_multidigraph.h rename to lib/utils/include/utils/graph/multidigraph.old/i_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/multidigraph/i_multidigraph_view.h b/lib/utils/include/utils/graph/multidigraph.old/i_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/multidigraph/i_multidigraph_view.h rename to lib/utils/include/utils/graph/multidigraph.old/i_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge.dtg.h.old similarity index 80% rename from lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h rename to lib/utils/include/utils/graph/multidigraph.old/multi_di_edge.dtg.h.old index 0c471c8c35..26a1627ded 100644 --- a/lib/utils/include/utils/graph/multidigraph/multi_di_edge.dtg.h +++ b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge.dtg.h.old @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml /* proj-data { - "generated_from": "73b001bfb7a0b75c42cd5037bb8dc686" + "generated_from": "b7e237c6d5f55b89cb72848b20aef534" } */ @@ -19,9 +19,7 @@ namespace FlexFlow { struct MultiDiEdge { MultiDiEdge() = delete; - explicit MultiDiEdge(::FlexFlow::Node const &src, - ::FlexFlow::Node const &dst, - std::pair const &raw_edge_uid); + explicit MultiDiEdge(size_t const &raw_edge_uid); bool operator==(MultiDiEdge const &) const; bool operator!=(MultiDiEdge const &) const; @@ -29,9 +27,7 @@ struct MultiDiEdge { bool operator>(MultiDiEdge const &) const; bool operator<=(MultiDiEdge const &) const; bool operator>=(MultiDiEdge const &) const; - ::FlexFlow::Node src; - ::FlexFlow::Node dst; - std::pair raw_edge_uid; + size_t raw_edge_uid; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge.struct.toml.old similarity index 65% rename from lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml rename to lib/utils/include/utils/graph/multidigraph.old/multi_di_edge.struct.toml.old index 41b08deb18..f6733c16d7 100644 --- a/lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml +++ b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge.struct.toml.old @@ -11,14 +11,6 @@ includes = [ "utils/graph/node/node.dtg.h", ] -[[fields]] -name = "src" -type = "::FlexFlow::Node" - -[[fields]] -name = "dst" -type = "::FlexFlow::Node" - [[fields]] name = "raw_edge_uid" type = "size_t" diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.dtg.h.old similarity index 84% rename from lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h rename to lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.dtg.h.old index 47b30da97b..a081f865cd 100644 --- a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.dtg.h +++ b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.dtg.h.old @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml /* proj-data { - "generated_from": "bede7a523428098275e26ba89bb30eb0" + "generated_from": "56edb1e799c2bdf7435479ce8a483311" } */ @@ -11,6 +11,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_MULTI_DI_EDGE_QUERY_DTG_H #include "fmt/format.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" #include "utils/graph/node/node.dtg.h" #include "utils/graph/query_set.h" #include @@ -22,7 +23,8 @@ struct MultiDiEdgeQuery { MultiDiEdgeQuery() = delete; explicit MultiDiEdgeQuery( ::FlexFlow::query_set<::FlexFlow::Node> const &srcs, - ::FlexFlow::query_set<::FlexFlow::Node> const &dsts); + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts, + ::FlexFlow::query_set<::FlexFlow::Edge> const &uids); bool operator==(MultiDiEdgeQuery const &) const; bool operator!=(MultiDiEdgeQuery const &) const; @@ -32,6 +34,7 @@ struct MultiDiEdgeQuery { bool operator>=(MultiDiEdgeQuery const &) const; ::FlexFlow::query_set<::FlexFlow::Node> srcs; ::FlexFlow::query_set<::FlexFlow::Node> dsts; + ::FlexFlow::query_set<::FlexFlow::Edge> uids; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.h b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.h.old similarity index 100% rename from lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.h rename to lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.h.old diff --git a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.struct.toml.old similarity index 72% rename from lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml rename to lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.struct.toml.old index 1d555b2626..3128f3b8c9 100644 --- a/lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml +++ b/lib/utils/include/utils/graph/multidigraph.old/multi_di_edge_query.struct.toml.old @@ -10,6 +10,7 @@ features = [ includes = [ "utils/graph/query_set.h", "utils/graph/node/node.dtg.h", + "utils/graph/multidigraph/multi_di_edge.dtg.h", ] [[fields]] @@ -19,3 +20,7 @@ type = "::FlexFlow::query_set<::FlexFlow::Node>" [[fields]] name = "dsts" type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "uids" +type = "::FlexFlow::query_set<::FlexFlow::Edge>" diff --git a/lib/utils/include/utils/graph/multidigraph.old/multi_di_output.h.old b/lib/utils/include/utils/graph/multidigraph.old/multi_di_output.h.old new file mode 100644 index 0000000000..b91b6c2a44 --- /dev/null +++ b/lib/utils/include/utils/graph/multidigraph.old/multi_di_output.h.old @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_EDGES_MULTI_DI_OUTPUT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_EDGES_MULTI_DI_OUTPUT_H + +#include "utils/graph/digraph/di_output.h" +#include "utils/graph/node_port.h" + +namespace FlexFlow { +struct MultiDiOutput : DiOutput { + NodePort src_idx; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.old/multidigraph.h.old similarity index 96% rename from lib/utils/include/utils/graph/multidigraph/multidigraph.h rename to lib/utils/include/utils/graph/multidigraph.old/multidigraph.h.old index 0fc498b8ac..85552380cc 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.old/multidigraph.h.old @@ -19,7 +19,7 @@ struct MultiDiGraph : virtual MultiDiGraphView { void add_node_unsafe(Node const &); void remove_node_unsafe(Node const &); - void add_edge(Edge const &e); + Edge add_edge(Node const &src, Node const &dst); void remove_edge(Edge const &e); std::unordered_set query_nodes(NodeQuery const &) const; diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph_view.h b/lib/utils/include/utils/graph/multidigraph.old/multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/multidigraph/multidigraph_view.h rename to lib/utils/include/utils/graph/multidigraph.old/multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/node/node_source.h b/lib/utils/include/utils/graph/node/node_source.h new file mode 100644 index 0000000000..7ec11f0686 --- /dev/null +++ b/lib/utils/include/utils/graph/node/node_source.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_SOURCE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_NODE_SOURCE_H + +#include "utils/graph/node/node.dtg.h" + +namespace FlexFlow { + +struct NodeSource { +public: + NodeSource(); + + Node new_node(); +private: + static size_t next_available_node_id; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_multidigraph/adjacency_openmultidigraph.h b/lib/utils/include/utils/graph/open_multidigraph.old/adjacency_openmultidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/adjacency_openmultidigraph.h rename to lib/utils/include/utils/graph/open_multidigraph.old/adjacency_openmultidigraph.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph.h b/lib/utils/include/utils/graph/open_multidigraph.old/i_open_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph.h rename to lib/utils/include/utils/graph/open_multidigraph.old/i_open_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph_view.h b/lib/utils/include/utils/graph/open_multidigraph.old/i_open_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/i_open_multidigraph_view.h rename to lib/utils/include/utils/graph/open_multidigraph.old/i_open_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.dtg.h rename to lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge.dtg.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.h b/lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.h rename to lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml b/lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge.struct.toml rename to lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge.struct.toml.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge_query.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.h rename to lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge_query.dtg.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.h b/lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge_query.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.h rename to lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge_query.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge_query.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/input_multi_di_edge_query.struct.toml rename to lib/utils/include/utils/graph/open_multidigraph.old/input_multi_di_edge_query.struct.toml.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.dtg.h rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge.dtg.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.h b/lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.h rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml b/lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge.variant.toml.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge.variant.toml rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge.variant.toml.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge_query.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.h rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge_query.dtg.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.h b/lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge_query.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.h rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge_query.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge_query.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multi_di_edge_query.struct.toml rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multi_di_edge_query.struct.toml.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph.h b/lib/utils/include/utils/graph/open_multidigraph.old/open_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multidigraph.h rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/open_multidigraph_view.h b/lib/utils/include/utils/graph/open_multidigraph.old/open_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/open_multidigraph_view.h rename to lib/utils/include/utils/graph/open_multidigraph.old/open_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.dtg.h rename to lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge.dtg.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.h b/lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.h rename to lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml b/lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge.struct.toml rename to lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge.struct.toml.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge_query.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.h rename to lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge_query.dtg.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.h b/lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge_query.h.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.h rename to lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge_query.h.old diff --git a/lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge_query.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/open_multidigraph/output_multi_di_edge_query.struct.toml rename to lib/utils/include/utils/graph/open_multidigraph.old/output_multi_di_edge_query.struct.toml.old diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index ff65533d2a..1f2ab6757f 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -6,31 +6,37 @@ #include "utils/exception.h" #include #include +#include +#include "utils/optional.h" +#include "utils/hash-utils.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { template struct query_set { query_set() = delete; - query_set(T const &t) : query(std::unordered_set{t}) {} + query_set(T const &t) : query(std::set{t}) {} - query_set(std::unordered_set const &query) : query(query) {} + query_set(std::unordered_set const &query) : query(std::set{query.cbegin(), query.cend()}) {} - query_set(std::optional> const &query) : query(query) {} + query_set(std::optional> const &query) + : query(transform(query, [](std::unordered_set const &s) { return std::set{s.cbegin(), s.cend()}; })) + { } query_set(std::initializer_list const &l) : query_set(std::unordered_set{l}) {} friend bool operator==(query_set const &lhs, query_set const &rhs) { - return lhs.value == rhs.value; + return lhs.query == rhs.query; } friend bool operator!=(query_set const &lhs, query_set const &rhs) { - return lhs.value != rhs.value; + return lhs.query != rhs.query; } friend bool operator<(query_set const &lhs, query_set const &rhs) { - return lhs.value < rhs.value; + return lhs.query < rhs.query; } friend bool is_matchall(query_set const &q) { @@ -39,7 +45,8 @@ struct query_set { friend std::unordered_set allowed_values(query_set const &q) { assert(!is_matchall(q)); - return q.query.value(); + std::set query_value = q.query.value(); + return std::unordered_set{query_value.begin(), query_value.end()}; } static query_set matchall() { @@ -50,8 +57,11 @@ struct query_set { return {std::unordered_set{}}; } + std::optional> const &value() const { + return this->query; + } private: - std::optional> query; + std::optional> query; }; template @@ -128,4 +138,15 @@ query_set query_union(query_set const &lhs, query_set const &rhs) { } // namespace FlexFlow +namespace std { + +template +struct hash<::FlexFlow::query_set> { + size_t operator()(::FlexFlow::query_set const &q) const { + return ::FlexFlow::get_std_hash(q.value()); + } +}; + +} + #endif diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel.dtg.h b/lib/utils/include/utils/graph/serial_parallel/parallel.dtg.h new file mode 100644 index 0000000000..8a90089437 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/parallel.dtg.h @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml +/* proj-data +{ + "generated_from": "0aee46f91f8e9ae0f18e1f496aa886b4" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/serial_parallel/serial.fwd.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct Parallel { + Parallel() = delete; + explicit Parallel( + std::vector> const + &children); + + bool operator==(Parallel const &) const; + bool operator!=(Parallel const &) const; + bool operator<(Parallel const &) const; + bool operator>(Parallel const &) const; + bool operator<=(Parallel const &) const; + bool operator>=(Parallel const &) const; + std::vector> children; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::Parallel> { + size_t operator()(::FlexFlow::Parallel const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(Parallel const &); +std::ostream &operator<<(std::ostream &, Parallel const &); +} // namespace FlexFlow +#include "utils/graph/serial_parallel/serial.dtg.h" + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_DTG_H diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h b/lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h new file mode 100644 index 0000000000..c82a8ec6b3 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h @@ -0,0 +1,10 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_FWD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_FWD_H + +namespace FlexFlow { + +struct Parallel; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml b/lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml new file mode 100644 index 0000000000..b8358a96c2 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "Parallel" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/serial_parallel/serial.fwd.h", + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/fmt/variant.h", + "utils/hash/vector.h", +] + +trailing_includes = [ + "utils/graph/serial_parallel/serial.dtg.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial.dtg.h b/lib/utils/include/utils/graph/serial_parallel/serial.dtg.h new file mode 100644 index 0000000000..ad6f63c24f --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial.dtg.h @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/serial.struct.toml +/* proj-data +{ + "generated_from": "c5342b3e8b7dfa96c95fc171f85a0cf7" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/serial_parallel/parallel.fwd.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct Serial { + Serial() = delete; + explicit Serial( + std::vector> const + &children); + + bool operator==(Serial const &) const; + bool operator!=(Serial const &) const; + bool operator<(Serial const &) const; + bool operator>(Serial const &) const; + bool operator<=(Serial const &) const; + bool operator>=(Serial const &) const; + std::vector> children; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::Serial> { + size_t operator()(::FlexFlow::Serial const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(Serial const &); +std::ostream &operator<<(std::ostream &, Serial const &); +} // namespace FlexFlow +#include "utils/graph/serial_parallel/parallel.dtg.h" + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_DTG_H diff --git a/lib/utils/include/utils/graph/serial_parallel/serial.fwd.h b/lib/utils/include/utils/graph/serial_parallel/serial.fwd.h new file mode 100644 index 0000000000..913b81434c --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial.fwd.h @@ -0,0 +1,10 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_FWD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_FWD_H + +namespace FlexFlow { + +struct Serial; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial.struct.toml b/lib/utils/include/utils/graph/serial_parallel/serial.struct.toml new file mode 100644 index 0000000000..1a5fd2408e --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "Serial" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/serial_parallel/parallel.fwd.h", + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/fmt/variant.h", + "utils/hash/vector.h", +] + +trailing_includes = [ + "utils/graph/serial_parallel/parallel.dtg.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h new file mode 100644 index 0000000000..31ecc2dfdd --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h @@ -0,0 +1,127 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +/* proj-data +{ + "generated_from": "c019d65a059a20f13a419fa343ad0d26" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/serial_parallel/parallel.dtg.h" +#include "utils/graph/serial_parallel/serial.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct SerialParallelDecomposition { + SerialParallelDecomposition() = delete; + explicit SerialParallelDecomposition(::FlexFlow::Serial const &); + explicit SerialParallelDecomposition(::FlexFlow::Parallel const &); + explicit SerialParallelDecomposition(::FlexFlow::Node const &); + template + static constexpr bool IsPartOfSerialParallelDecomposition_v = + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::Serial>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::Parallel>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::Node>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type SerialParallelDecomposition", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::Serial>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::Parallel>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::Node>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type SerialParallelDecomposition", + this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfSerialParallelDecomposition_v, + "SerialParallelDecomposition::has() expected one of " + "[::FlexFlow::Serial, ::FlexFlow::Parallel, " + "::FlexFlow::Node], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfSerialParallelDecomposition_v, + "SerialParallelDecomposition::get() expected one of " + "[::FlexFlow::Serial, ::FlexFlow::Parallel, " + "::FlexFlow::Node], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfSerialParallelDecomposition_v, + "SerialParallelDecomposition::get() expected one of " + "[::FlexFlow::Serial, ::FlexFlow::Parallel, " + "::FlexFlow::Node], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(SerialParallelDecomposition const &) const; + bool operator!=(SerialParallelDecomposition const &) const; + bool operator<(SerialParallelDecomposition const &) const; + bool operator>(SerialParallelDecomposition const &) const; + bool operator<=(SerialParallelDecomposition const &) const; + bool operator>=(SerialParallelDecomposition const &) const; + std::variant<::FlexFlow::Serial, ::FlexFlow::Parallel, ::FlexFlow::Node> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::SerialParallelDecomposition> { + size_t operator()(::FlexFlow::SerialParallelDecomposition const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::SerialParallelDecomposition const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::SerialParallelDecomposition const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_DTG_H diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml new file mode 100644 index 0000000000..cd80f2dd3e --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "SerialParallelDecomposition" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/serial_parallel/serial.dtg.h", + "utils/graph/serial_parallel/parallel.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::Serial" + +[[values]] +type = "::FlexFlow::Parallel" + +[[values]] +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/serial_parallel/serialparallel.h b/lib/utils/include/utils/graph/serial_parallel/serialparallel.h new file mode 100644 index 0000000000..f1cf977eb3 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serialparallel.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_H +#define _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/optional.h" +#include +#include + +namespace FlexFlow { + +Node find_source_node(DiGraphView const &); +Node find_sink_node(DiGraphView const &); + +std::optional find_bottleneck_node(DiGraphView const &); + +struct Parallel; + +SerialParallelDecomposition + get_serial_parallel_decomposition(DiGraphView const &); + +std::unordered_set get_nodes(SerialParallelDecomposition const &sp); + +// std::unordered_map parallel_extend(MultiDiGraph &g, +// MultiDiGraph const &ext); + +// std::unordered_map serial_extend(MultiDiGraph &g, +// MultiDiGraph const &ext); + +// MultiDiGraph serial_composition(MultiDiGraph const &g1, MultiDiGraph const &g2); + +// MultiDiGraph parallel_composition(MultiDiGraph const &g1, +// MultiDiGraph const &g2); + +// MultiDiGraph multidigraph_from_sp_decomposition( +// SerialParallelDecomposition const &sp_decomposition); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/split_type.dtg.h b/lib/utils/include/utils/graph/serial_parallel/split_type.dtg.h new file mode 100644 index 0000000000..d239e7a096 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/split_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml +/* proj-data +{ + "generated_from": "61d75c03b0273d05eb9707f75132974e" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SPLIT_TYPE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SPLIT_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class SplitType { SERIAL, PARALLEL }; +std::string format_as(SplitType); +std::ostream &operator<<(std::ostream &, SplitType); +void to_json(::nlohmann::json &, SplitType); +void from_json(::nlohmann::json const &, SplitType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SplitType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SPLIT_TYPE_DTG_H diff --git a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml b/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml new file mode 100644 index 0000000000..96d85f0e12 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "SplitType" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "SERIAL" + +[[values]] +name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/serialparallel.h b/lib/utils/include/utils/graph/serialparallel.h deleted file mode 100644 index 47bcb4031e..0000000000 --- a/lib/utils/include/utils/graph/serialparallel.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_H -#define _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_H - -#include "digraph.h" -#include "multidigraph.h" -#include "utils/optional.h" -#include -#include - -namespace FlexFlow { - -Node find_source_node(DiGraphView const &); -Node find_sink_node(DiGraphView const &); - -std::optional find_bottleneck_node(DiGraphView const &); - -struct Parallel; - -struct Serial { - std::vector> children; -}; - -struct Parallel { - std::vector> children; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Parallel, children); -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Serial, children); - -using SerialParallelDecomposition = std::variant; - -SerialParallelDecomposition - get_serial_parallel_decomposition(DiGraphView const &); - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp); - -std::unordered_map parallel_extend(MultiDiGraph &g, - MultiDiGraph const &ext); - -std::unordered_map serial_extend(MultiDiGraph &g, - MultiDiGraph const &ext); - -MultiDiGraph serial_composition(MultiDiGraph const &g1, MultiDiGraph const &g2); - -MultiDiGraph parallel_composition(MultiDiGraph const &g1, - MultiDiGraph const &g2); - -MultiDiGraph multidigraph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/i_upward_open_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/i_upward_open_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/i_upward_open_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/i_upward_open_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge.dtg.h.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge.h.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge.variant.toml.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.variant.toml rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge.variant.toml.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge_query.dtg.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge_query.dtg.h.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge_query.struct.toml.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.struct.toml rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multi_di_edge_query.struct.toml.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multidigraph.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multidigraph.h.old diff --git a/lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h b/lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multidigraph_view.h.old similarity index 100% rename from lib/utils/include/utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h rename to lib/utils/include/utils/graph/upward_open_multidigraph.old/upward_open_multidigraph_view.h.old diff --git a/lib/utils/include/utils/graph/views/views.h b/lib/utils/include/utils/graph/views/views.h index 8330ef51bf..f251912103 100644 --- a/lib/utils/include/utils/graph/views/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -4,8 +4,8 @@ #include "utils/graph/digraph/adjacency_digraph.h" #include "utils/graph/digraph/digraph_view.h" #include "utils/graph/undirected/undirected_graph_view.h" -#include "utils/graph/multidigraph/multidigraph_view.h" -#include "utils/graph/open_multidigraph/open_multidigraph_view.h" +// #include "utils/graph/multidigraph/multidigraph_view.h" +// #include "utils/graph/open_multidigraph/open_multidigraph_view.h" #include "utils/graph/views/join_node_key.dtg.h" namespace FlexFlow { @@ -58,22 +58,22 @@ struct DiSubgraphView : public IDiGraphView { std::unordered_set subgraph_nodes; }; -struct MultiDiSubgraphView : public IMultiDiGraphView { -public: - MultiDiSubgraphView() = delete; - explicit MultiDiSubgraphView(MultiDiGraphView const &, - std::unordered_set const &); - - std::unordered_set - query_edges(MultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - MultiDiSubgraphView *clone() const override; - -private: - MultiDiGraphView g; - std::unordered_set subgraph_nodes; -}; +// struct MultiDiSubgraphView : public IMultiDiGraphView { +// public: +// MultiDiSubgraphView() = delete; +// explicit MultiDiSubgraphView(MultiDiGraphView const &, +// std::unordered_set const &); +// +// std::unordered_set +// query_edges(MultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// MultiDiSubgraphView *clone() const override; +// +// private: +// MultiDiGraphView g; +// std::unordered_set subgraph_nodes; +// }; struct NodeSource { public: @@ -147,29 +147,29 @@ struct JoinedDigraphView : virtual public IDiGraphView { JoinedNodeView joined_nodes; }; -struct JoinedMultiDigraphView : public IMultiDiGraphView { -public: - JoinedMultiDigraphView() = delete; - JoinedMultiDigraphView(MultiDiGraphView const &lhs, - MultiDiGraphView const &rhs); - - std::unordered_set - query_edges(MultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedNodeView const &joined_nodes_view() const; - - JoinedMultiDigraphView *clone() const override; - -private: - MultiDiEdge fix_lhs_edge(MultiDiEdge const &) const; - MultiDiEdge fix_rhs_edge(MultiDiEdge const &) const; - -private: - MultiDiGraphView lhs; - MultiDiGraphView rhs; - JoinedNodeView joined_nodes; -}; +// struct JoinedMultiDigraphView : public IMultiDiGraphView { +// public: +// JoinedMultiDigraphView() = delete; +// JoinedMultiDigraphView(MultiDiGraphView const &lhs, +// MultiDiGraphView const &rhs); +// +// std::unordered_set +// query_edges(MultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// JoinedNodeView const &joined_nodes_view() const; +// +// JoinedMultiDigraphView *clone() const override; +// +// private: +// MultiDiEdge fix_lhs_edge(MultiDiEdge const &) const; +// MultiDiEdge fix_rhs_edge(MultiDiEdge const &) const; +// +// private: +// MultiDiGraphView lhs; +// MultiDiGraphView rhs; +// JoinedNodeView joined_nodes; +// }; struct AddDirectedEdgesView : public IDiGraphView { public: @@ -229,88 +229,88 @@ struct ContractNodeView : public IDiGraphView { Node from, to; }; -struct OpenMultiDiSubgraphView : public IOpenMultiDiGraphView { -public: - OpenMultiDiSubgraphView() = delete; - OpenMultiDiSubgraphView(OpenMultiDiGraphView const &, - std::unordered_set const &); - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - OpenMultiDiSubgraphView *clone() const override; - -private: - OpenMultiDiGraphView g; - std::unordered_set nodes; - std::unordered_set inputs; - std::unordered_set outputs; -}; - -struct UpwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { - UpwardOpenMultiDiSubgraphView() = delete; - UpwardOpenMultiDiSubgraphView(OpenMultiDiGraphView const &, - std::unordered_set const &); - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - UpwardOpenMultiDiSubgraphView *clone() const override; - -private: - OpenMultiDiGraphView g; - std::unordered_set nodes; - std::unordered_set inputs; -}; - -struct DownwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { - DownwardOpenMultiDiSubgraphView() = delete; - DownwardOpenMultiDiSubgraphView(OpenMultiDiGraphView const &, - std::unordered_set const &); - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - DownwardOpenMultiDiSubgraphView *clone() const override; - -private: - OpenMultiDiGraphView g; - std::unordered_set nodes; - std::unordered_set outputs; -}; - -struct ClosedMultiDiSubgraphView : public IOpenMultiDiGraphView { - ClosedMultiDiSubgraphView() = delete; - ClosedMultiDiSubgraphView(OpenMultiDiGraphView const &, - std::unordered_set const &); - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - ClosedMultiDiSubgraphView *clone() const override; - -private: - OpenMultiDiGraphView g; - std::unordered_set nodes; -}; +// struct OpenMultiDiSubgraphView : public IOpenMultiDiGraphView { +// public: +// OpenMultiDiSubgraphView() = delete; +// OpenMultiDiSubgraphView(OpenMultiDiGraphView const &, +// std::unordered_set const &); +// +// std::unordered_set +// query_edges(OpenMultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// OpenMultiDiSubgraphView *clone() const override; +// +// private: +// OpenMultiDiGraphView g; +// std::unordered_set nodes; +// std::unordered_set inputs; +// std::unordered_set outputs; +// }; + +// struct UpwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { +// UpwardOpenMultiDiSubgraphView() = delete; +// UpwardOpenMultiDiSubgraphView(OpenMultiDiGraphView const &, +// std::unordered_set const &); +// +// std::unordered_set +// query_edges(OpenMultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// UpwardOpenMultiDiSubgraphView *clone() const override; +// +// private: +// OpenMultiDiGraphView g; +// std::unordered_set nodes; +// std::unordered_set inputs; +// }; + +// struct DownwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { +// DownwardOpenMultiDiSubgraphView() = delete; +// DownwardOpenMultiDiSubgraphView(OpenMultiDiGraphView const &, +// std::unordered_set const &); +// +// std::unordered_set +// query_edges(OpenMultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// DownwardOpenMultiDiSubgraphView *clone() const override; +// +// private: +// OpenMultiDiGraphView g; +// std::unordered_set nodes; +// std::unordered_set outputs; +// }; + +// struct ClosedMultiDiSubgraphView : public IOpenMultiDiGraphView { +// ClosedMultiDiSubgraphView() = delete; +// ClosedMultiDiSubgraphView(OpenMultiDiGraphView const &, +// std::unordered_set const &); +// +// std::unordered_set +// query_edges(OpenMultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// ClosedMultiDiSubgraphView *clone() const override; +// +// private: +// OpenMultiDiGraphView g; +// std::unordered_set nodes; +// }; UndirectedEdge to_undirected_edge(DirectedEdge const &); std::unordered_set to_undirected_edges(std::unordered_set const &); -UndirectedEdge to_undirected_edge(MultiDiEdge const &); -std::unordered_set - to_undirected_edges(std::unordered_set const &); +// UndirectedEdge to_undirected_edge(MultiDiEdge const &); +// std::unordered_set +// to_undirected_edges(std::unordered_set const &); std::unordered_set to_directed_edges(UndirectedEdge const &); std::unordered_set to_directed_edges(std::unordered_set const &); -DirectedEdge to_directed_edge(MultiDiEdge const &); -std::unordered_set - to_directed_edges(std::unordered_set const &); +// DirectedEdge to_directed_edge(MultiDiEdge const &); +// std::unordered_set +// to_directed_edges(std::unordered_set const &); struct ViewDiGraphAsUndirectedGraph : public IUndirectedGraphView { public: @@ -340,33 +340,33 @@ struct ViewUndirectedGraphAsDiGraph : public IDiGraphView { UndirectedGraphView g; }; -struct ViewDiGraphAsMultiDiGraph : public IMultiDiGraphView { -public: - explicit ViewDiGraphAsMultiDiGraph(DiGraphView const &); - - std::unordered_set - query_edges(MultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - ViewDiGraphAsMultiDiGraph *clone() const override; - -private: - DiGraphView g; -}; - -struct ViewMultiDiGraphAsOpenMultiDiGraph : public IOpenMultiDiGraphView { -public: - explicit ViewMultiDiGraphAsOpenMultiDiGraph(MultiDiGraphView const &); - - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - ViewMultiDiGraphAsOpenMultiDiGraph *clone() const override; - -private: - MultiDiGraphView g; -}; +// struct ViewDiGraphAsMultiDiGraph : public IMultiDiGraphView { +// public: +// explicit ViewDiGraphAsMultiDiGraph(DiGraphView const &); +// +// std::unordered_set +// query_edges(MultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// ViewDiGraphAsMultiDiGraph *clone() const override; +// +// private: +// DiGraphView g; +// }; + +// struct ViewMultiDiGraphAsOpenMultiDiGraph : public IOpenMultiDiGraphView { +// public: +// explicit ViewMultiDiGraphAsOpenMultiDiGraph(MultiDiGraphView const &); +// +// std::unordered_set +// query_edges(OpenMultiDiEdgeQuery const &) const override; +// std::unordered_set query_nodes(NodeQuery const &) const override; +// +// ViewMultiDiGraphAsOpenMultiDiGraph *clone() const override; +// +// private: +// MultiDiGraphView g; +// }; DirectedEdge flipped(DirectedEdge const &); @@ -395,10 +395,10 @@ Impl materialize_digraph_view(IDiGraphView const &g) { return materialize_view(g); } -template -Impl materialize_multidigraph_view(IMultiDiGraphView const &g) { - return materialize_view(g); -} +// template +// Impl materialize_multidigraph_view(IMultiDiGraphView const &g) { +// return materialize_view(g); +// } } // namespace FlexFlow diff --git a/lib/utils/include/utils/hash-utils-core.h b/lib/utils/include/utils/hash-utils-core.h index a16674f454..ea333563d0 100644 --- a/lib/utils/include/utils/hash-utils-core.h +++ b/lib/utils/include/utils/hash-utils-core.h @@ -7,6 +7,8 @@ #include #include +namespace FlexFlow { + template std::size_t get_std_hash(T const &v) { std::hash hasher; @@ -27,6 +29,16 @@ inline void hash_combine(std::size_t &seed, T const &v, Ts... rest) { hash_combine(seed, rest...); } +template +void unordered_container_hash(std::size_t &seed, T const &t) { + hash_combine(seed, t.size()); + size_t total = 0; + for (auto const &v : t) { + total += get_std_hash(v); + } + hash_combine(seed, total); +} + template void iter_hash(std::size_t &seed, It start, It end) { hash_combine(seed, std::distance(start, end)); @@ -35,61 +47,6 @@ void iter_hash(std::size_t &seed, It start, It end) { } } -namespace std { -template -struct hash> { -private: - // this is a termination condition - // N == sizeof...(TupleTypes) - // - template - inline typename std::enable_if::type - hash_combine_tup(size_t &seed, - std::tuple const &tup) const {} - - // this is the computation function - // continues till condition N < sizeof...(TupleTypes) holds - // - template - inline typename std::enable_if < Idx::type - hash_combine_tup(size_t &seed, - std::tuple const &tup) const { - hash_combine(seed, std::get(tup)); - - // on to next element - hash_combine_tup(seed, tup); - } - -public: - size_t operator()(std::tuple const &tupleValue) const { - size_t seed = 0; - // begin with the first iteration - hash_combine_tup<0>(seed, tupleValue); - return seed; - } -}; - -template -struct hash> { - size_t operator()(std::pair const &p) const { - size_t seed = 283746; - - hash_combine(seed, p.first); - hash_combine(seed, p.second); - - return seed; - } -}; - -template -struct hash> { - size_t operator()(std::vector const &vec) const { - size_t seed = 0; - iter_hash(seed, vec.cbegin(), vec.cend()); - return seed; - } -}; - -} // namespace std +} // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/hash-utils.h b/lib/utils/include/utils/hash-utils.h index d56ff34644..831a12e554 100644 --- a/lib/utils/include/utils/hash-utils.h +++ b/lib/utils/include/utils/hash-utils.h @@ -3,6 +3,8 @@ #include "containers.h" #include "hash-utils-core.h" +#include +#include using namespace FlexFlow; @@ -10,20 +12,40 @@ namespace std { template struct hash> { size_t operator()(std::unordered_set const &s) const { - auto sorted = sorted_by(s, ::FlexFlow::compare_by([](T const &t) { - return get_std_hash(t); - })); - return get_std_hash(sorted); + size_t result = 0; + unordered_container_hash(result, s); + return result; + } +}; + +template +struct hash> { + size_t operator()(std::set const &s) const { + size_t result = 0; + unordered_container_hash(result, s); + return result; } }; template struct hash> { size_t operator()(std::unordered_map const &m) const { - return get_std_hash(::FlexFlow::items(m)); + size_t result = 0; + unordered_container_hash(result, m); + return result; } }; +template +struct hash> { + size_t operator()(std::map const &m) const { + size_t result = 0; + unordered_container_hash(result, m); + return result; + } +}; + + } // namespace std #endif diff --git a/lib/utils/include/utils/hash/pair.h b/lib/utils/include/utils/hash/pair.h new file mode 100644 index 0000000000..5d0af39848 --- /dev/null +++ b/lib/utils/include/utils/hash/pair.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_PAIR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_PAIR_H + +#include +#include "utils/hash-utils-core.h" + +namespace std { + +template +struct hash> { + size_t operator()(std::pair const &p) const { + size_t seed = 283746; + + hash_combine(seed, p.first); + hash_combine(seed, p.second); + + return seed; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/tuple.h b/lib/utils/include/utils/hash/tuple.h new file mode 100644 index 0000000000..de64264064 --- /dev/null +++ b/lib/utils/include/utils/hash/tuple.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_TUPLE_H + +#include "utils/hash-utils-core.h" +#include + +namespace std { +template +struct hash> { +private: + // this is a termination condition + // N == sizeof...(TupleTypes) + // + template + inline typename std::enable_if::type + hash_combine_tup(size_t &seed, + std::tuple const &tup) const {} + + // this is the computation function + // continues till condition N < sizeof...(TupleTypes) holds + // + template + inline typename std::enable_if < Idx::type + hash_combine_tup(size_t &seed, + std::tuple const &tup) const { + hash_combine(seed, std::get(tup)); + + // on to next element + hash_combine_tup(seed, tup); + } + +public: + size_t operator()(std::tuple const &tupleValue) const { + size_t seed = 0; + // begin with the first iteration + hash_combine_tup<0>(seed, tupleValue); + return seed; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/vector.h b/lib/utils/include/utils/hash/vector.h new file mode 100644 index 0000000000..3785076288 --- /dev/null +++ b/lib/utils/include/utils/hash/vector.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_VECTOR_H + +#include "utils/hash-utils-core.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::vector const &vec) const { + size_t seed = 0; + iter_hash(seed, vec.cbegin(), vec.cend()); + return seed; + } +}; + +} + +#endif diff --git a/lib/utils/src/graph/adjacency_openmultidigraph.cc b/lib/utils/src/graph/adjacency_openmultidigraph.cc.old similarity index 100% rename from lib/utils/src/graph/adjacency_openmultidigraph.cc rename to lib/utils/src/graph/adjacency_openmultidigraph.cc.old diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc.old similarity index 100% rename from lib/utils/src/graph/open_graphs.cc rename to lib/utils/src/graph/open_graphs.cc.old diff --git a/lib/utils/src/utils/fmt/variant.cc b/lib/utils/src/utils/fmt/variant.cc new file mode 100644 index 0000000000..e2d387eedb --- /dev/null +++ b/lib/utils/src/utils/fmt/variant.cc @@ -0,0 +1 @@ +#include "utils/fmt/variant.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 2223b120a7..5c86ff1086 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -1,20 +1,18 @@ #include "utils/graph/algorithms.h" #include "utils/containers.h" #include "utils/exception.h" -#include "utils/graph/diedge.h" -#include "utils/graph/digraph.h" -#include "utils/graph/multidiedge.h" -#include "utils/graph/multidigraph.h" -#include "utils/graph/multidigraph_interfaces.h" +#include "utils/graph/node/node_query.h" #include "utils/graph/traversal.h" -#include "utils/graph/undirected.h" -#include "utils/graph/views.h" +#include "utils/graph/views/views.h" #include "utils/variant.h" #include #include #include #include #include +#include "utils/hash-utils.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" namespace FlexFlow { @@ -39,37 +37,37 @@ std::vector add_nodes(DiGraph &g, int num_nodes) { return add_nodes_impl(g, num_nodes); } -std::vector add_nodes(MultiDiGraph &g, int num_nodes) { - return add_nodes_impl(g, num_nodes); -} +// std::vector add_nodes(MultiDiGraph &g, int num_nodes) { +// return add_nodes_impl(g, num_nodes); +// } -std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes) { - return add_nodes_impl(g, num_nodes); -} +// std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes) { +// return add_nodes_impl(g, num_nodes); +// } -std::vector add_node_ports(MultiDiGraph &g, int num_node_ports) { - std::vector node_ports; - for (int i = 0; i < num_node_ports; i++) { - node_ports.push_back(g.add_node_port()); - } - return node_ports; -} +// std::vector add_node_ports(MultiDiGraph &g, int num_node_ports) { +// std::vector node_ports; +// for (int i = 0; i < num_node_ports; i++) { +// node_ports.push_back(g.add_node_port()); +// } +// return node_ports; +// } std::unordered_set get_nodes(GraphView const &g) { - return g.query_nodes(NodeQuery::all()); + return g.query_nodes(node_query_all()); } -std::unordered_set get_nodes(InputMultiDiEdge const &edge) { - return {edge.dst}; -} +// std::unordered_set get_nodes(InputMultiDiEdge const &edge) { +// return {edge.dst}; +// } -std::unordered_set get_nodes(OutputMultiDiEdge const &edge) { - return {edge.src}; -} +// std::unordered_set get_nodes(OutputMultiDiEdge const &edge) { +// return {edge.src}; +// } -std::unordered_set get_nodes(MultiDiEdge const &edge) { - return {edge.src, edge.dst}; -} +// std::unordered_set get_nodes(MultiDiEdge const &edge) { +// return {edge.src, edge.dst}; +// } struct GetNodesFunctor { template @@ -78,30 +76,30 @@ struct GetNodesFunctor { } }; -std::unordered_set get_nodes(OpenMultiDiEdge const &edge) { - return visit(GetNodesFunctor{}, edge); -} +// std::unordered_set get_nodes(OpenMultiDiEdge const &edge) { +// return visit(GetNodesFunctor{}, edge); +// } std::unordered_set query_nodes(GraphView const &g, std::unordered_set const &nodes) { - return g.query_nodes({nodes}); + return g.query_nodes(NodeQuery{nodes}); } -std::unordered_set get_present_node_ports(MultiDiGraphView const &g) { - return flatmap(get_edges(g), [](MultiDiEdge const &e) { - return std::unordered_set{e.src_idx, e.dst_idx}; - }); -} +// std::unordered_set get_present_node_ports(MultiDiGraphView const &g) { +// return flatmap(get_edges(g), [](MultiDiEdge const &e) { +// return std::unordered_set{e.src_idx, e.dst_idx}; +// }); +// } -void remove_node(MultiDiGraph &g, Node const &n) { - for (MultiDiEdge const &e : get_incoming_edges(g, n)) { - g.remove_edge(e); - } - for (MultiDiEdge const &e : get_outgoing_edges(g, n)) { - g.remove_edge(e); - } - g.remove_node_unsafe(n); -} +// void remove_node(MultiDiGraph &g, Node const &n) { +// for (MultiDiEdge const &e : get_incoming_edges(g, n)) { +// g.remove_edge(e); +// } +// for (MultiDiEdge const &e : get_outgoing_edges(g, n)) { +// g.remove_edge(e); +// } +// g.remove_node_unsafe(n); +// } void remove_node(DiGraph &g, Node const &n) { for (DirectedEdge const &e : get_incoming_edges(g, n)) { @@ -120,16 +118,16 @@ void remove_node(UndirectedGraph &g, Node const &n) { g.remove_node_unsafe(n); } -void remove_node_if_unused(MultiDiGraph &g, Node const &n) { - if (!get_incoming_edges(g, n).empty()) { - return; - } - if (!get_outgoing_edges(g, n).empty()) { - return; - } - - g.remove_node_unsafe(n); -} +// void remove_node_if_unused(MultiDiGraph &g, Node const &n) { +// if (!get_incoming_edges(g, n).empty()) { +// return; +// } +// if (!get_outgoing_edges(g, n).empty()) { +// return; +// } +// +// g.remove_node_unsafe(n); +// } void remove_node_if_unused(DiGraph &g, Node const &n) { if (!get_incoming_edges(g, n).empty()) { @@ -176,11 +174,11 @@ DiGraphView apply_contraction(DiGraphView const &g, return contractedView; } -void add_edges(MultiDiGraph &g, std::vector const &edges) { - for (MultiDiEdge const &e : edges) { - g.add_edge(e); - } -} +// void add_edges(MultiDiGraph &g, std::vector const &edges) { +// for (MultiDiEdge const &e : edges) { +// g.add_edge(e); +// } +// } void add_edges(DiGraph &g, std::vector const &edges) { for (DirectedEdge const &e : edges) { @@ -194,26 +192,26 @@ void add_edges(UndirectedGraph &g, std::vector const &edges) { } } -bool contains_edge(MultiDiGraphView const &g, MultiDiEdge const &e) { - return contains(g.query_edges({e.src, e.dst, e.src_idx, e.dst_idx}), e); -} +// bool contains_edge(MultiDiGraphView const &g, MultiDiEdge const &e) { +// return contains(g.query_edges({e.src, e.dst, e.src_idx, e.dst_idx}), e); +// } bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { - return contains(g.query_edges({e.src, e.dst}), e); + return contains(g.query_edges(DirectedEdgeQuery{e.src, e.dst}), e); } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = {{e.bigger, e.smaller}}; + UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; return contains(g.query_edges(q), e); } -void remove_edges(MultiDiGraph &g, - std::unordered_set const &edges) { - for (MultiDiEdge const &e : edges) { - assert(contains_edge(g, e)); - g.remove_edge(e); - } -} +// void remove_edges(MultiDiGraph &g, +// std::unordered_set const &edges) { +// for (MultiDiEdge const &e : edges) { +// assert(contains_edge(g, e)); +// g.remove_edge(e); +// } +// } void remove_edges(DiGraph &g, std::unordered_set const &edges) { for (DirectedEdge const &e : edges) { @@ -234,130 +232,134 @@ std::unordered_set get_endpoints(UndirectedEdge const &e) { return {e.smaller, e.bigger}; } -std::unordered_set get_edges(MultiDiGraphView const &g) { - return g.query_edges(MultiDiEdgeQuery::all()); -} +// std::unordered_set get_edges(MultiDiGraphView const &g) { +// return g.query_edges(MultiDiEdgeQuery::all()); +// } std::unordered_set get_edges(DiGraphView const &g) { - return g.query_edges(DirectedEdgeQuery::all()); + return g.query_edges(directed_edge_query_all()); } std::unordered_set get_edges(UndirectedGraphView const &g) { - return g.query_edges(UndirectedEdgeQuery::all()); + return g.query_edges(undirected_edge_query_all()); } -std::unordered_set get_edges(OpenMultiDiGraphView const &g) { - return g.query_edges(OpenMultiDiEdgeQuery::all()); -} +// std::unordered_set get_edges(OpenMultiDiGraphView const &g) { +// return g.query_edges(OpenMultiDiEdgeQuery::all()); +// } std::unordered_set get_node_edges(UndirectedGraphView const &g, Node const &n) { - return g.query_edges({n}); + return g.query_edges(UndirectedEdgeQuery{n}); } -std::unordered_set get_outputs(MultiDiGraphView const &g) { - return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiOutput { - return static_cast(e); - }); -} +// std::unordered_set get_outputs(MultiDiGraphView const &g) { +// return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiOutput { +// return static_cast(e); +// }); +// } -std::unordered_set get_inputs(MultiDiGraphView const &g) { - return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiInput { - return static_cast(e); - }); -} +// std::unordered_set get_inputs(MultiDiGraphView const &g) { +// return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiInput { +// return static_cast(e); +// }); +// } -std::unordered_set get_incoming_edges(MultiDiGraphView const &g, - Node const &n) { - return get_incoming_edges(g, std::unordered_set{n}); -} +// std::unordered_set get_incoming_edges(MultiDiGraphView const &g, +// Node const &n) { +// return get_incoming_edges(g, std::unordered_set{n}); +// } std::unordered_set get_incoming_edges(DiGraphView const &g, Node const &n) { return get_incoming_edges(g, std::unordered_set{n}); } -std::unordered_set - get_incoming_edges(MultiDiGraphView const &g, - std::unordered_set dsts) { - return g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(dsts)); -} +// std::unordered_set +// get_incoming_edges(MultiDiGraphView const &g, +// std::unordered_set dsts) { +// return g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(dsts)); +// } std::unordered_set get_incoming_edges(DiGraphView const &g, std::unordered_set const &dsts) { - auto multidigraph_view = as_multidigraph(g); - return to_directed_edges(get_incoming_edges(multidigraph_view, dsts)); + NOT_IMPLEMENTED(); + // auto multidigraph_view = as_multidigraph(g); + // return to_directed_edges(get_incoming_edges(multidigraph_view, dsts)); } -std::unordered_set - get_outgoing_edges(MultiDiGraphView const &g, - std::unordered_set const &srcs) { - return g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(srcs)); -} +// std::unordered_set +// get_outgoing_edges(MultiDiGraphView const &g, +// std::unordered_set const &srcs) { +// return g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(srcs)); +// } -std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, - Node const &n) { - return get_outgoing_edges(g, std::unordered_set{n}); -} +// std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, +// Node const &n) { +// return get_outgoing_edges(g, std::unordered_set{n}); +// } std::unordered_set get_outgoing_edges(DiGraphView const &g, std::unordered_set const &dsts) { - auto multidigraph_view = as_multidigraph(g); - return to_directed_edges(get_outgoing_edges(multidigraph_view, dsts)); + NOT_IMPLEMENTED(); + // auto multidigraph_view = as_multidigraph(g); + // return to_directed_edges(get_outgoing_edges(multidigraph_view, dsts)); } std::unordered_set get_outgoing_edges(DiGraphView const &g, Node const &n) { return get_outgoing_edges(g, std::unordered_set{n}); } -std::unordered_map> - get_incoming_edges_by_idx(MultiDiGraphView const &g, Node const &n) { - std::unordered_set edges = get_incoming_edges(g, n); - std::unordered_map> result; - for (MultiDiEdge const &e : edges) { - result[e.dst_idx].insert(e); - } - return result; -} - -std::unordered_map> - get_outgoing_edges_by_idx(MultiDiGraphView const &g, Node const &n) { - std::unordered_set edges = get_outgoing_edges(g, n); - std::unordered_map> result; - for (MultiDiEdge const &e : edges) { - result[e.src_idx].insert(e); - } - return result; -} - -std::unordered_set - get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { - return value_all( - narrow(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - MultiDiEdgeQuery::all().with_src_nodes({n}), - OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); -} -std::unordered_set - get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { - return value_all(narrow(g.query_edges( - OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery::all().with_dst_nodes({n}), - MultiDiEdgeQuery::all().with_dst_nodes({n}), - OutputMultiDiEdgeQuery::none())))); -} - -std::unordered_set - get_open_outputs(OpenMultiDiGraphView const &g) { - return narrow( - g.query_edges(OutputMultiDiEdgeQuery::all())); -} -std::unordered_set - get_open_inputs(OpenMultiDiGraphView const &g) { - return narrow(g.query_edges(InputMultiDiEdgeQuery::all())); -} +// std::unordered_map> +// get_incoming_edges_by_idx(MultiDiGraphView const &g, Node const &n) { +// std::unordered_set edges = get_incoming_edges(g, n); +// std::unordered_map> result; +// for (MultiDiEdge const &e : edges) { +// result[e.dst_idx].insert(e); +// } +// return result; +// } + +// std::unordered_map> +// get_outgoing_edges_by_idx(MultiDiGraphView const &g, Node const &n) { +// std::unordered_set edges = get_outgoing_edges(g, n); +// std::unordered_map> result; +// for (MultiDiEdge const &e : edges) { +// result[e.src_idx].insert(e); +// } +// return result; +// } + +// std::unordered_set +// get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { +// return value_all( +// narrow(g.query_edges(OpenMultiDiEdgeQuery( +// InputMultiDiEdgeQuery::none(), +// MultiDiEdgeQuery::all().with_src_nodes({n}), +// OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); +// } + +// std::unordered_set +// get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { +// return value_all(narrow(g.query_edges( +// OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery::all().with_dst_nodes({n}), +// MultiDiEdgeQuery::all().with_dst_nodes({n}), +// OutputMultiDiEdgeQuery::none())))); +// } + +// std::unordered_set +// get_open_outputs(OpenMultiDiGraphView const &g) { +// return narrow( +// g.query_edges(OutputMultiDiEdgeQuery::all())); +// } + +// std::unordered_set +// get_open_inputs(OpenMultiDiGraphView const &g) { +// return narrow(g.query_edges(InputMultiDiEdgeQuery::all())); +// } std::unordered_map> get_predecessors(DiGraphView const &g, @@ -436,9 +438,9 @@ std::optional is_acyclic(DiGraphView const &g) { return true; } -std::optional is_acyclic(MultiDiGraph const &g) { - return is_acyclic(g); -} +// std::optional is_acyclic(MultiDiGraph const &g) { +// return is_acyclic(g); +// } std::vector get_unchecked_topological_ordering(DiGraphView const &g) { auto dfs_view = unchecked_dfs(g, get_sources(g)); @@ -487,48 +489,48 @@ std::vector get_edge_topological_ordering(DiGraphView const &g) { return result; } -Node get_src_node(MultiDiEdge const &e) { - return e.src; -} +// Node get_src_node(MultiDiEdge const &e) { +// return e.src; +// } -Node get_dst_node(MultiDiEdge const &e) { - return e.dst; -} +// Node get_dst_node(MultiDiEdge const &e) { +// return e.dst; +// } -Node get_dst_node(InputMultiDiEdge const &e) { - return e.dst; -} +// Node get_dst_node(InputMultiDiEdge const &e) { +// return e.dst; +// } -Node get_src_node(OutputMultiDiEdge const &e) { - return e.src; -} +// Node get_src_node(OutputMultiDiEdge const &e) { +// return e.src; +// } -NodePort get_src_idx(MultiDiEdge const &e) { - return e.src_idx; -} +// NodePort get_src_idx(MultiDiEdge const &e) { +// return e.src_idx; +// } -NodePort get_dst_idx(MultiDiEdge const &e) { - return e.dst_idx; -} +// NodePort get_dst_idx(MultiDiEdge const &e) { +// return e.dst_idx; +// } -NodePort get_dst_idx(InputMultiDiEdge const &e) { - return e.dst_idx; -} +// NodePort get_dst_idx(InputMultiDiEdge const &e) { +// return e.dst_idx; +// } -NodePort get_src_idx(OutputMultiDiEdge const &e) { - return e.src_idx; -} +// NodePort get_src_idx(OutputMultiDiEdge const &e) { +// return e.src_idx; +// } std::unordered_set get_neighbors(DiGraphView const &g, Node const &n) { UndirectedGraphView undirected = as_undirected(g); return get_neighbors(undirected, n); } -std::unordered_set get_neighbors(MultiDiGraphView const &g, - Node const &n) { - UndirectedGraphView undirected = as_undirected(g); - return get_neighbors(undirected, n); -} +// std::unordered_set get_neighbors(MultiDiGraphView const &g, +// Node const &n) { +// UndirectedGraphView undirected = as_undirected(g); +// return get_neighbors(undirected, n); +// } std::unordered_set get_neighbors(UndirectedGraphView const &g, Node const &n) { @@ -537,19 +539,19 @@ std::unordered_set get_neighbors(UndirectedGraphView const &g, }); } -std::vector - get_edge_topological_ordering(MultiDiGraphView const &g) { - std::vector result; - for (Node const &n : get_topological_ordering(g)) { - for (MultiDiEdge const &e : get_outgoing_edges(g, n)) { - result.push_back(e); - } - } - - assert(result.size() == get_edges(g).size()); - - return result; -} +// std::vector +// get_edge_topological_ordering(MultiDiGraphView const &g) { +// std::vector result; +// for (Node const &n : get_topological_ordering(g)) { +// for (MultiDiEdge const &e : get_outgoing_edges(g, n)) { +// result.push_back(e); +// } +// } +// +// assert(result.size() == get_edges(g).size()); +// +// return result; +// } std::unordered_map> get_dominators(DiGraphView const &g) { @@ -656,50 +658,50 @@ std::optional } } -std::pair - split_edge(MultiDiEdge const &e) { - return {OutputMultiDiEdge{e.src, e.src_idx, e.get_uid()}, - InputMultiDiEdge{e.dst, e.dst_idx, e.get_uid()}}; -} - -MultiDiEdge unsplit_edge(OutputMultiDiEdge const &output_edge, - InputMultiDiEdge const &input_edge) { - assert(output_edge.uid.first == input_edge.dst.value()); - assert(output_edge.uid.second == input_edge.dst_idx.value()); - assert(input_edge.uid.first == output_edge.src.value()); - assert(input_edge.uid.second == output_edge.src_idx.value()); - return { - input_edge.dst, input_edge.dst_idx, output_edge.src, output_edge.src_idx}; -} - -std::unordered_set get_cut_set(MultiDiGraphView const &g, - GraphSplit const &s) { - return set_union( - g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes(s.first).with_dst_nodes( - s.second)), - g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes(s.second).with_dst_nodes( - s.first))); -} - -std::unordered_set - get_cut_set(MultiDiGraphView const &g, - std::unordered_set const &nodes) { - return get_cut_set(g, GraphSplit{nodes, set_difference(get_nodes(g), nodes)}); -} - -bidict> - get_edge_splits(MultiDiGraphView const &graph, GraphSplit const &split) { - bidict> result; - std::unordered_set cut_set = get_cut_set(graph, split); - for (MultiDiEdge const &edge : cut_set) { - result.equate(edge, split_edge(edge)); - } - return result; - return generate_bidict(get_cut_set(graph, split), - [](MultiDiEdge const &e) { return split_edge(e); }); -} +// std::pair +// split_edge(MultiDiEdge const &e) { +// return {OutputMultiDiEdge{e.src, e.src_idx, e.get_uid()}, +// InputMultiDiEdge{e.dst, e.dst_idx, e.get_uid()}}; +// } + +// MultiDiEdge unsplit_edge(OutputMultiDiEdge const &output_edge, +// InputMultiDiEdge const &input_edge) { +// assert(output_edge.uid.first == input_edge.dst.value()); +// assert(output_edge.uid.second == input_edge.dst_idx.value()); +// assert(input_edge.uid.first == output_edge.src.value()); +// assert(input_edge.uid.second == output_edge.src_idx.value()); +// return { +// input_edge.dst, input_edge.dst_idx, output_edge.src, output_edge.src_idx}; +// } + +// std::unordered_set get_cut_set(MultiDiGraphView const &g, +// GraphSplit const &s) { +// return set_union( +// g.query_edges( +// MultiDiEdgeQuery::all().with_src_nodes(s.first).with_dst_nodes( +// s.second)), +// g.query_edges( +// MultiDiEdgeQuery::all().with_src_nodes(s.second).with_dst_nodes( +// s.first))); +// } + +// std::unordered_set +// get_cut_set(MultiDiGraphView const &g, +// std::unordered_set const &nodes) { +// return get_cut_set(g, GraphSplit{nodes, set_difference(get_nodes(g), nodes)}); +// } + +// bidict> +// get_edge_splits(MultiDiGraphView const &graph, GraphSplit const &split) { +// bidict> result; +// std::unordered_set cut_set = get_cut_set(graph, split); +// for (MultiDiEdge const &edge : cut_set) { +// result.equate(edge, split_edge(edge)); +// } +// return result; +// return generate_bidict(get_cut_set(graph, split), +// [](MultiDiEdge const &e) { return split_edge(e); }); +// } UndirectedGraphView get_subgraph(UndirectedGraphView const &g, std::unordered_set const &nodes) { @@ -711,15 +713,15 @@ DiGraphView get_subgraph(DiGraphView const &g, return DiGraphView::create(g, nodes); } -MultiDiGraphView get_subgraph(MultiDiGraphView const &g, - std::unordered_set const &nodes) { - return MultiDiGraphView::create(g, nodes); -} +// MultiDiGraphView get_subgraph(MultiDiGraphView const &g, +// std::unordered_set const &nodes) { +// return MultiDiGraphView::create(g, nodes); +// } -MultiDiGraphView join(MultiDiGraphView const &lhs, - MultiDiGraphView const &rhs) { - return MultiDiGraphView::create(lhs, rhs); -} +// MultiDiGraphView join(MultiDiGraphView const &lhs, +// MultiDiGraphView const &rhs) { +// return MultiDiGraphView::create(lhs, rhs); +// } DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs) { return DiGraphView::create(lhs, rhs); @@ -734,27 +736,27 @@ UndirectedGraphView as_undirected(DiGraphView const &g) { return UndirectedGraphView::create(g); } -MultiDiGraphView as_multidigraph(DiGraphView const &g) { - return MultiDiGraphView::create(g); -} +// MultiDiGraphView as_multidigraph(DiGraphView const &g) { +// return MultiDiGraphView::create(g); +// } DiGraphView as_digraph(UndirectedGraphView const &g) { return DiGraphView::create(g); } -OpenMultiDiGraphView as_openmultidigraph(MultiDiGraphView const &g) { - return OpenMultiDiGraphView::create(g); -} +// OpenMultiDiGraphView as_openmultidigraph(MultiDiGraphView const &g) { +// return OpenMultiDiGraphView::create(g); +// } std::unordered_set> get_weakly_connected_components(DiGraphView const &g) { return get_connected_components(as_undirected(g)); } -std::unordered_set> - get_weakly_connected_components(MultiDiGraphView const &g) { - return get_connected_components(as_undirected(g)); -} +// std::unordered_set> +// get_weakly_connected_components(MultiDiGraphView const &g) { +// return get_connected_components(as_undirected(g)); +// } std::unordered_set> get_connected_components(UndirectedGraphView const &g) { @@ -770,28 +772,28 @@ std::unordered_set> return components; } -std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g) { - return filter(get_nodes(g), [&](Node const &n) { - return get_incoming_edges(g, n).size() == 0; - }); -} - -std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { - return filter(get_nodes(g), [&](Node const &n) { - return get_outgoing_edges(g, n).size() == 0; - }); -} - -std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { - return filter(get_nodes(g), [&](Node const &n) { - return !get_incoming_edges(g, n).empty(); - }); -} - -std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { - return filter(get_nodes(g), [&](Node const &n) { - return !get_outgoing_edges(g, n).empty(); - }); -} +// std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g) { +// return filter(get_nodes(g), [&](Node const &n) { +// return get_incoming_edges(g, n).size() == 0; +// }); +// } + +// std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { +// return filter(get_nodes(g), [&](Node const &n) { +// return get_outgoing_edges(g, n).size() == 0; +// }); +// } + +// std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { +// return filter(get_nodes(g), [&](Node const &n) { +// return !get_incoming_edges(g, n).empty(); +// }); +// } + +// std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { +// return filter(get_nodes(g), [&](Node const &n) { +// return !get_outgoing_edges(g, n).empty(); +// }); +// } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc index 794519ffed..2c7c8b9866 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge.dtg.cc @@ -9,8 +9,6 @@ #include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" -#include "utils/graph/dataflow_graph/dataflow_input.dtg.h" -#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc new file mode 100644 index 0000000000..03aaae8559 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc @@ -0,0 +1,23 @@ +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" + +namespace FlexFlow { + +DataflowEdgeQuery dataflow_edge_query_all() { + return DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + query_set::matchall(), + query_set::matchall(), + }; +} + +DataflowEdgeQuery dataflow_edge_query_none() { + return DataflowEdgeQuery{ + query_set::match_none(), + query_set::match_none(), + query_set::match_none(), + query_set::match_none(), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc index 65ac12003c..d80b5b7afd 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc @@ -3,14 +3,12 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml /* proj-data { - "generated_from": "684726a7add4aa912e194335fcfe91ab" + "generated_from": "111e640382a80b659bc33dd86a416ded" } */ #include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" -#include "utils/graph/node.dtg.h" -#include "utils/graph/query_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc index f3afb4a9b1..5fc3b177f2 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc @@ -7,4 +7,24 @@ NodeAddedResult DataflowGraph::add_node(std::vector const &input return this->get_interface().add_node(inputs, num_outputs); } +std::unordered_set DataflowGraph::query_nodes(NodeQuery const &q) const { + return this->get_interface().query_nodes(q); +} + +std::unordered_set DataflowGraph::query_edges(DataflowEdgeQuery const &q) const { + return this->get_interface().query_edges(q); +} + +std::unordered_set DataflowGraph::query_outputs(DataflowOutputQuery const &q) const { + return this->get_interface().query_outputs(q); +} + +IDataflowGraph &DataflowGraph::get_interface() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); +} + +IDataflowGraph const &DataflowGraph::get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc index 32d61f4f0b..06fa244ab5 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_input.dtg.cc @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_input.struct.toml /* proj-data { - "generated_from": "9fc7657f7fcc71fdad9e6a5040771ad7" + "generated_from": "d43532deb325bcf8a502efbe90cd287b" } */ diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc index 8c8cf6b73a..a0aeb56a9b 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output.dtg.cc @@ -3,13 +3,12 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_output.struct.toml /* proj-data { - "generated_from": "b704f2549a69ee6bfc1c5e28df421f9c" + "generated_from": "3f4ea6635782f141cc593291132c4064" } */ #include "utils/graph/dataflow_graph/dataflow_output.dtg.h" -#include "utils/graph/node.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc index 7bc200e887..94b6d8bcaa 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.dtg.cc @@ -3,14 +3,12 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.struct.toml /* proj-data { - "generated_from": "6f662c3c4d285a4fd3c60713e6fc67fa" + "generated_from": "de957a7524bf0423dcfb68f70b2e6815" } */ #include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" -#include "utils/graph/node.dtg.h" -#include "utils/graph/query_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc b/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc index f29054cc2d..dc645bc0b8 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph_view.cc @@ -3,7 +3,7 @@ namespace FlexFlow { -std::unordered_set IDataflowGraphView::query_edges(MultiDiEdgeQuery const &q) const { +std::unordered_set IDataflowGraphView::query_edges(DirectedEdgeQuery const &q) const { DataflowEdgeQuery dataflow_query = DataflowEdgeQuery{ q.srcs, matchall(), @@ -13,7 +13,7 @@ std::unordered_set IDataflowGraphView::query_edges(MultiDiEdgeQuery std::unordered_set dataflow_edges = this->query_edges(dataflow_query); return transform(dataflow_edges, [](DataflowEdge const &e) { - return MultiDiEdge{e.src.node, e.dst.node, std::make_pair(e.src.idx, e.dst.idx)}; + return DirectedEdge{e.src.node, e.dst.node}; }); } diff --git a/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc index dcbe3578f2..3c6eeace8a 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/node_added_result.dtg.cc @@ -3,18 +3,13 @@ // lib/utils/include/utils/graph/dataflow_graph/node_added_result.struct.toml /* proj-data { - "generated_from": "4536bb54376e2e221e0ff29347e81662" + "generated_from": "6e5dc11e71c895683bd5bb9c30c1e42d" } */ #include "utils/graph/dataflow_graph/node_added_result.dtg.h" -#include "utils/fmt/vector.h" -#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" -#include "utils/graph/multidigraph/multi_di_edge.dtg.h" -#include "utils/graph/node.dtg.h" #include -#include namespace FlexFlow { NodeAddedResult::NodeAddedResult( diff --git a/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc b/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc index 705d8f6158..4a54986832 100644 --- a/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc @@ -40,7 +40,7 @@ std::unordered_set std::unordered_set result; for (auto const &src_kv : query_keys(query.srcs, this->adjacency)) { for (auto const &dst : apply_query(query.dsts, src_kv.second)) { - result.insert({src_kv.first, dst}); + result.insert(DirectedEdge{src_kv.first, dst}); } } return result; diff --git a/lib/utils/src/utils/graph/digraph/di_input.dtg.cc b/lib/utils/src/utils/graph/digraph/di_input.dtg.cc index 7b44d41e97..506894a902 100644 --- a/lib/utils/src/utils/graph/digraph/di_input.dtg.cc +++ b/lib/utils/src/utils/graph/digraph/di_input.dtg.cc @@ -9,7 +9,6 @@ #include "utils/graph/digraph/di_input.dtg.h" -#include "utils/graph/node/node.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/di_output.dtg.cc b/lib/utils/src/utils/graph/digraph/di_output.dtg.cc index 9723a1cd84..7782a95270 100644 --- a/lib/utils/src/utils/graph/digraph/di_output.dtg.cc +++ b/lib/utils/src/utils/graph/digraph/di_output.dtg.cc @@ -3,14 +3,12 @@ // lib/utils/include/utils/graph/digraph/di_output.struct.toml /* proj-data { - "generated_from": "a8f3fc2ad9e00f3c29a6dcd4658199ba" + "generated_from": "61e6ee4a13c7608bf6df0a549b94b2bc" } */ #include "utils/graph/digraph/di_output.dtg.h" -#include "utils/graph/node.dtg.h" - namespace FlexFlow { DiOutput::DiOutput(::FlexFlow::Node const &src) : src(src) {} bool DiOutput::operator==(DiOutput const &other) const { diff --git a/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc b/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc index 79f910de69..e925e48c55 100644 --- a/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc +++ b/lib/utils/src/utils/graph/digraph/directed_edge.dtg.cc @@ -9,7 +9,6 @@ #include "utils/graph/digraph/directed_edge.dtg.h" -#include "utils/graph/node/node.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/directed_graph/directed_edge_query.cc b/lib/utils/src/utils/graph/digraph/directed_edge_query.cc similarity index 95% rename from lib/utils/src/utils/graph/directed_graph/directed_edge_query.cc rename to lib/utils/src/utils/graph/digraph/directed_edge_query.cc index 2522e6aaa1..b12098bd96 100644 --- a/lib/utils/src/utils/graph/directed_graph/directed_edge_query.cc +++ b/lib/utils/src/utils/graph/digraph/directed_edge_query.cc @@ -1,4 +1,4 @@ -#include "utils/graph/directed_graph/directed_edge_query.h" +#include "utils/graph/digraph/directed_edge_query.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc b/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc index 3804bd3399..5e4117eacd 100644 --- a/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc +++ b/lib/utils/src/utils/graph/digraph/directed_edge_query.dtg.cc @@ -3,14 +3,12 @@ // lib/utils/include/utils/graph/digraph/directed_edge_query.struct.toml /* proj-data { - "generated_from": "294ae0103df2a3c388a2ce140c271f4e" + "generated_from": "4d7f3398fb178b272a4230d2db24c0d5" } */ #include "utils/graph/digraph/directed_edge_query.dtg.h" -#include "utils/graph/node/node.dtg.h" -#include "utils/graph/query_set.h" #include namespace FlexFlow { @@ -38,6 +36,18 @@ bool DirectedEdgeQuery::operator>=(DirectedEdgeQuery const &other) const { } } // namespace FlexFlow +namespace std { +size_t hash::operator()( + ::FlexFlow::DirectedEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.srcs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.dsts) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + namespace FlexFlow { std::string format_as(DirectedEdgeQuery const &x) { std::ostringstream oss; diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc rename to lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge.dtg.cc.old diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc rename to lib/utils/src/utils/graph/downward_open_multidigraph/downward_open_multi_di_edge_query.dtg.cc.old diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc.old similarity index 100% rename from lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc rename to lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph.cc.old diff --git a/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc b/lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc.old similarity index 100% rename from lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc rename to lib/utils/src/utils/graph/downward_open_multidigraph/i_downward_open_multidigraph_view.cc.old diff --git a/lib/utils/src/utils/graph/labelled_graphs.cc b/lib/utils/src/utils/graph/labelled_graphs.cc deleted file mode 100644 index 4f73c3b6af..0000000000 --- a/lib/utils/src/utils/graph/labelled_graphs.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "utils/graph/labelled_graphs.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidiedge.cc b/lib/utils/src/utils/graph/multidiedge.cc deleted file mode 100644 index cd3655c8e6..0000000000 --- a/lib/utils/src/utils/graph/multidiedge.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "utils/graph/multidiedge.h" - -namespace FlexFlow { - -bool MultiDiOutput::operator>(MultiDiOutput const &other) const { - return !(*this < other) && !(*this == other); -} - -bool MultiDiOutput::operator>=(MultiDiOutput const &other) const { - return !(*this < other); -} - -bool MultiDiOutput::operator<=(MultiDiOutput const &other) const { - return (*this < other) || (*this == other); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc.old similarity index 100% rename from lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc rename to lib/utils/src/utils/graph/multidigraph/adjacency_multidigraph.cc.old diff --git a/lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc b/lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc.old similarity index 100% rename from lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc rename to lib/utils/src/utils/graph/multidigraph/i_multidigraph.cc.old diff --git a/lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc b/lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc.old similarity index 100% rename from lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc rename to lib/utils/src/utils/graph/multidigraph/i_multidigraph_view.cc.old diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc deleted file mode 100644 index ae9070c9dd..0000000000 --- a/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc +++ /dev/null @@ -1,73 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml -/* proj-data -{ - "generated_from": "73b001bfb7a0b75c42cd5037bb8dc686" -} -*/ - -#include "utils/graph/multidigraph/multi_di_edge.dtg.h" - -#include "utils/graph/node/node.dtg.h" -#include - -namespace FlexFlow { -MultiDiEdge::MultiDiEdge(::FlexFlow::Node const &src, - ::FlexFlow::Node const &dst, - std::pair const &raw_edge_uid) - : src(src), dst(dst), raw_edge_uid(raw_edge_uid) {} -bool MultiDiEdge::operator==(MultiDiEdge const &other) const { - return std::tie(this->src, this->dst, this->raw_edge_uid) == - std::tie(other.src, other.dst, other.raw_edge_uid); -} -bool MultiDiEdge::operator!=(MultiDiEdge const &other) const { - return std::tie(this->src, this->dst, this->raw_edge_uid) != - std::tie(other.src, other.dst, other.raw_edge_uid); -} -bool MultiDiEdge::operator<(MultiDiEdge const &other) const { - return std::tie(this->src, this->dst, this->raw_edge_uid) < - std::tie(other.src, other.dst, other.raw_edge_uid); -} -bool MultiDiEdge::operator>(MultiDiEdge const &other) const { - return std::tie(this->src, this->dst, this->raw_edge_uid) > - std::tie(other.src, other.dst, other.raw_edge_uid); -} -bool MultiDiEdge::operator<=(MultiDiEdge const &other) const { - return std::tie(this->src, this->dst, this->raw_edge_uid) <= - std::tie(other.src, other.dst, other.raw_edge_uid); -} -bool MultiDiEdge::operator>=(MultiDiEdge const &other) const { - return std::tie(this->src, this->dst, this->raw_edge_uid) >= - std::tie(other.src, other.dst, other.raw_edge_uid); -} -} // namespace FlexFlow - -namespace std { -size_t hash::operator()( - ::FlexFlow::MultiDiEdge const &x) const { - size_t result = 0; - result ^= std::hash<::FlexFlow::Node>{}(x.src) + 0x9e3779b9 + (result << 6) + - (result >> 2); - result ^= std::hash<::FlexFlow::Node>{}(x.dst) + 0x9e3779b9 + (result << 6) + - (result >> 2); - result ^= std::hash>{}(x.raw_edge_uid) + 0x9e3779b9 + - (result << 6) + (result >> 2); - return result; -} -} // namespace std - -namespace FlexFlow { -std::string format_as(MultiDiEdge const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, MultiDiEdge const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc.old b/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc.old new file mode 100644 index 0000000000..10432fdb4d --- /dev/null +++ b/lib/utils/src/utils/graph/multidigraph/multi_di_edge.dtg.cc.old @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/multidigraph/multi_di_edge.struct.toml +/* proj-data +{ + "generated_from": "b7e237c6d5f55b89cb72848b20aef534" +} +*/ + +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" + +#include "utils/graph/node/node.dtg.h" +#include + +namespace FlexFlow { +MultiDiEdge::MultiDiEdge(size_t const &raw_edge_uid) + : raw_edge_uid(raw_edge_uid) {} +bool MultiDiEdge::operator==(MultiDiEdge const &other) const { + return std::tie(this->raw_edge_uid) == std::tie(other.raw_edge_uid); +} +bool MultiDiEdge::operator!=(MultiDiEdge const &other) const { + return std::tie(this->raw_edge_uid) != std::tie(other.raw_edge_uid); +} +bool MultiDiEdge::operator<(MultiDiEdge const &other) const { + return std::tie(this->raw_edge_uid) < std::tie(other.raw_edge_uid); +} +bool MultiDiEdge::operator>(MultiDiEdge const &other) const { + return std::tie(this->raw_edge_uid) > std::tie(other.raw_edge_uid); +} +bool MultiDiEdge::operator<=(MultiDiEdge const &other) const { + return std::tie(this->raw_edge_uid) <= std::tie(other.raw_edge_uid); +} +bool MultiDiEdge::operator>=(MultiDiEdge const &other) const { + return std::tie(this->raw_edge_uid) >= std::tie(other.raw_edge_uid); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::MultiDiEdge const &x) const { + size_t result = 0; + result ^= std::hash{}(x.raw_edge_uid) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(MultiDiEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiDiEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc.old similarity index 100% rename from lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc rename to lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.cc.old diff --git a/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc.old similarity index 61% rename from lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc rename to lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc.old index 686a6f362a..ba01d48808 100644 --- a/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc +++ b/lib/utils/src/utils/graph/multidigraph/multi_di_edge_query.dtg.cc.old @@ -3,12 +3,13 @@ // lib/utils/include/utils/graph/multidigraph/multi_di_edge_query.struct.toml /* proj-data { - "generated_from": "bede7a523428098275e26ba89bb30eb0" + "generated_from": "56edb1e799c2bdf7435479ce8a483311" } */ #include "utils/graph/multidigraph/multi_di_edge_query.dtg.h" +#include "utils/graph/multidigraph/multi_di_edge.dtg.h" #include "utils/graph/node/node.dtg.h" #include "utils/graph/query_set.h" #include @@ -16,25 +17,32 @@ namespace FlexFlow { MultiDiEdgeQuery::MultiDiEdgeQuery( ::FlexFlow::query_set<::FlexFlow::Node> const &srcs, - ::FlexFlow::query_set<::FlexFlow::Node> const &dsts) - : srcs(srcs), dsts(dsts) {} + ::FlexFlow::query_set<::FlexFlow::Node> const &dsts, + ::FlexFlow::query_set<::FlexFlow::Edge> const &uids) + : srcs(srcs), dsts(dsts), uids(uids) {} bool MultiDiEdgeQuery::operator==(MultiDiEdgeQuery const &other) const { - return std::tie(this->srcs, this->dsts) == std::tie(other.srcs, other.dsts); + return std::tie(this->srcs, this->dsts, this->uids) == + std::tie(other.srcs, other.dsts, other.uids); } bool MultiDiEdgeQuery::operator!=(MultiDiEdgeQuery const &other) const { - return std::tie(this->srcs, this->dsts) != std::tie(other.srcs, other.dsts); + return std::tie(this->srcs, this->dsts, this->uids) != + std::tie(other.srcs, other.dsts, other.uids); } bool MultiDiEdgeQuery::operator<(MultiDiEdgeQuery const &other) const { - return std::tie(this->srcs, this->dsts) < std::tie(other.srcs, other.dsts); + return std::tie(this->srcs, this->dsts, this->uids) < + std::tie(other.srcs, other.dsts, other.uids); } bool MultiDiEdgeQuery::operator>(MultiDiEdgeQuery const &other) const { - return std::tie(this->srcs, this->dsts) > std::tie(other.srcs, other.dsts); + return std::tie(this->srcs, this->dsts, this->uids) > + std::tie(other.srcs, other.dsts, other.uids); } bool MultiDiEdgeQuery::operator<=(MultiDiEdgeQuery const &other) const { - return std::tie(this->srcs, this->dsts) <= std::tie(other.srcs, other.dsts); + return std::tie(this->srcs, this->dsts, this->uids) <= + std::tie(other.srcs, other.dsts, other.uids); } bool MultiDiEdgeQuery::operator>=(MultiDiEdgeQuery const &other) const { - return std::tie(this->srcs, this->dsts) >= std::tie(other.srcs, other.dsts); + return std::tie(this->srcs, this->dsts, this->uids) >= + std::tie(other.srcs, other.dsts, other.uids); } } // namespace FlexFlow @@ -46,6 +54,8 @@ size_t hash::operator()( 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.dsts) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Edge>>{}(x.uids) + + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } } // namespace std @@ -56,6 +66,7 @@ std::string format_as(MultiDiEdgeQuery const &x) { oss << ""; return oss.str(); } diff --git a/lib/utils/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/src/utils/graph/multidigraph/multidigraph.cc.old similarity index 100% rename from lib/utils/src/utils/graph/multidigraph/multidigraph.cc rename to lib/utils/src/utils/graph/multidigraph/multidigraph.cc.old diff --git a/lib/utils/src/utils/graph/node/graph.cc b/lib/utils/src/utils/graph/node/graph.cc index 69a66f169d..244716c3ce 100644 --- a/lib/utils/src/utils/graph/node/graph.cc +++ b/lib/utils/src/utils/graph/node/graph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/undirected/graph.h" +#include "utils/graph/node/graph.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/node/graph_view.cc b/lib/utils/src/utils/graph/node/graph_view.cc index 5ea0fe7b63..5404e29f23 100644 --- a/lib/utils/src/utils/graph/node/graph_view.cc +++ b/lib/utils/src/utils/graph/node/graph_view.cc @@ -1,4 +1,4 @@ -#include "utils/graph/undirected/graph_view.h" +#include "utils/graph/node/graph_view.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/node/i_graph_view.cc b/lib/utils/src/utils/graph/node/i_graph_view.cc index 63c5b829cb..d9cced00f2 100644 --- a/lib/utils/src/utils/graph/node/i_graph_view.cc +++ b/lib/utils/src/utils/graph/node/i_graph_view.cc @@ -1 +1 @@ -#include "utils/graph/undirected/i_graph_view.h" +#include "utils/graph/node/i_graph_view.h" diff --git a/lib/utils/src/utils/graph/node/node.dtg.cc b/lib/utils/src/utils/graph/node/node.dtg.cc index 6a314f64dd..9b4feb637e 100644 --- a/lib/utils/src/utils/graph/node/node.dtg.cc +++ b/lib/utils/src/utils/graph/node/node.dtg.cc @@ -9,7 +9,6 @@ #include "utils/graph/node/node.dtg.h" -#include #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/node/node_query.dtg.cc b/lib/utils/src/utils/graph/node/node_query.dtg.cc index 516d9f9d88..259f0bc52a 100644 --- a/lib/utils/src/utils/graph/node/node_query.dtg.cc +++ b/lib/utils/src/utils/graph/node/node_query.dtg.cc @@ -9,8 +9,6 @@ #include "utils/graph/node/node_query.dtg.h" -#include "utils/graph/node/node.dtg.h" -#include "utils/graph/query_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/node/node_source.cc b/lib/utils/src/utils/graph/node/node_source.cc new file mode 100644 index 0000000000..095f07e68b --- /dev/null +++ b/lib/utils/src/utils/graph/node/node_source.cc @@ -0,0 +1,15 @@ +#include "utils/graph/node/node_source.h" + +namespace FlexFlow { + +size_t NodeSource::next_available_node_id = 0; + +NodeSource::NodeSource() {} + +Node NodeSource::new_node() { + Node result = Node{NodeSource::next_available_node_id}; + NodeSource::next_available_node_id++; + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc rename to lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc rename to lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge.dtg.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc rename to lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc rename to lib/utils/src/utils/graph/open_multidigraph/input_multi_di_edge_query.dtg.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc rename to lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc rename to lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge.dtg.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc rename to lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc rename to lib/utils/src/utils/graph/open_multidigraph/open_multi_di_edge_query.dtg.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc rename to lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc rename to lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge.dtg.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc rename to lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.cc.old diff --git a/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc rename to lib/utils/src/utils/graph/open_multidigraph/output_multi_di_edge_query.dtg.cc.old diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.cc new file mode 100644 index 0000000000..7f35c633b4 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.cc @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml +/* proj-data +{ + "generated_from": "2a07baf4a649daf2ef8329fcc3bc611d" +} +*/ + +#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" + +#include "utils/hash/vector.h" + +namespace FlexFlow { +IntermediateSpDecompositionTree::IntermediateSpDecompositionTree( + ::FlexFlow::SplitType const &type, + std::vector> const &children) + : type(type), children(children) {} +bool IntermediateSpDecompositionTree::operator==( + IntermediateSpDecompositionTree const &other) const { + return std::tie(this->type, this->children) == + std::tie(other.type, other.children); +} +bool IntermediateSpDecompositionTree::operator!=( + IntermediateSpDecompositionTree const &other) const { + return std::tie(this->type, this->children) != + std::tie(other.type, other.children); +} +bool IntermediateSpDecompositionTree::operator<( + IntermediateSpDecompositionTree const &other) const { + return std::tie(this->type, this->children) < + std::tie(other.type, other.children); +} +bool IntermediateSpDecompositionTree::operator>( + IntermediateSpDecompositionTree const &other) const { + return std::tie(this->type, this->children) > + std::tie(other.type, other.children); +} +bool IntermediateSpDecompositionTree::operator<=( + IntermediateSpDecompositionTree const &other) const { + return std::tie(this->type, this->children) <= + std::tie(other.type, other.children); +} +bool IntermediateSpDecompositionTree::operator>=( + IntermediateSpDecompositionTree const &other) const { + return std::tie(this->type, this->children) >= + std::tie(other.type, other.children); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::IntermediateSpDecompositionTree const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::SplitType>{}(x.type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash< + std::vector>>{}(x.children) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h new file mode 100644 index 0000000000..c3ff1d2046 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h @@ -0,0 +1,48 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml +/* proj-data +{ + "generated_from": "2a07baf4a649daf2ef8329fcc3bc611d" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_DTG_H +#define _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_DTG_H + +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/serial_parallel/split_type.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct IntermediateSpDecompositionTree { + IntermediateSpDecompositionTree() = delete; + explicit IntermediateSpDecompositionTree( + ::FlexFlow::SplitType const &type, + std::vector> const &children); + + bool operator==(IntermediateSpDecompositionTree const &) const; + bool operator!=(IntermediateSpDecompositionTree const &) const; + bool operator<(IntermediateSpDecompositionTree const &) const; + bool operator>(IntermediateSpDecompositionTree const &) const; + bool operator<=(IntermediateSpDecompositionTree const &) const; + bool operator>=(IntermediateSpDecompositionTree const &) const; + ::FlexFlow::SplitType type; + std::vector> + children; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::IntermediateSpDecompositionTree> { + size_t operator()(::FlexFlow::IntermediateSpDecompositionTree const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_DTG_H diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..c84efea03b --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "IntermediateSpDecompositionTree" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph/serial_parallel/split_type.dtg.h", + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", +] + +[[fields]] +name = "type" +type = "::FlexFlow::SplitType" + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/src/utils/graph/serial_parallel/parallel.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/parallel.dtg.cc new file mode 100644 index 0000000000..0f200df62c --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/parallel.dtg.cc @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml +/* proj-data +{ + "generated_from": "0aee46f91f8e9ae0f18e1f496aa886b4" +} +*/ + +#include "utils/graph/serial_parallel/parallel.dtg.h" + +#include "utils/fmt/variant.h" +#include "utils/fmt/vector.h" +#include "utils/hash/vector.h" +#include + +namespace FlexFlow { +Parallel::Parallel( + std::vector> const + &children) + : children(children) {} +bool Parallel::operator==(Parallel const &other) const { + return std::tie(this->children) == std::tie(other.children); +} +bool Parallel::operator!=(Parallel const &other) const { + return std::tie(this->children) != std::tie(other.children); +} +bool Parallel::operator<(Parallel const &other) const { + return std::tie(this->children) < std::tie(other.children); +} +bool Parallel::operator>(Parallel const &other) const { + return std::tie(this->children) > std::tie(other.children); +} +bool Parallel::operator<=(Parallel const &other) const { + return std::tie(this->children) <= std::tie(other.children); +} +bool Parallel::operator>=(Parallel const &other) const { + return std::tie(this->children) >= std::tie(other.children); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(::FlexFlow::Parallel const &x) const { + size_t result = 0; + result ^= + std::hash< + std::vector>>{}( + x.children) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(Parallel const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Parallel const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/serial.dtg.cc new file mode 100644 index 0000000000..1cfa222f78 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serial.dtg.cc @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/serial.struct.toml +/* proj-data +{ + "generated_from": "c5342b3e8b7dfa96c95fc171f85a0cf7" +} +*/ + +#include "utils/graph/serial_parallel/serial.dtg.h" + +#include "utils/fmt/variant.h" +#include "utils/fmt/vector.h" +#include "utils/hash/vector.h" +#include + +namespace FlexFlow { +Serial::Serial( + std::vector> const + &children) + : children(children) {} +bool Serial::operator==(Serial const &other) const { + return std::tie(this->children) == std::tie(other.children); +} +bool Serial::operator!=(Serial const &other) const { + return std::tie(this->children) != std::tie(other.children); +} +bool Serial::operator<(Serial const &other) const { + return std::tie(this->children) < std::tie(other.children); +} +bool Serial::operator>(Serial const &other) const { + return std::tie(this->children) > std::tie(other.children); +} +bool Serial::operator<=(Serial const &other) const { + return std::tie(this->children) <= std::tie(other.children); +} +bool Serial::operator>=(Serial const &other) const { + return std::tie(this->children) >= std::tie(other.children); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(::FlexFlow::Serial const &x) const { + size_t result = 0; + result ^= + std::hash< + std::vector>>{}( + x.children) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(Serial const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Serial const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.cc new file mode 100644 index 0000000000..c8e774e6e6 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.dtg.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +/* proj-data +{ + "generated_from": "c019d65a059a20f13a419fa343ad0d26" +} +*/ + +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" + +#include + +namespace FlexFlow { +SerialParallelDecomposition::SerialParallelDecomposition( + ::FlexFlow::Serial const &v) + : raw_variant(v) {} +SerialParallelDecomposition::SerialParallelDecomposition( + ::FlexFlow::Parallel const &v) + : raw_variant(v) {} +SerialParallelDecomposition::SerialParallelDecomposition( + ::FlexFlow::Node const &v) + : raw_variant(v) {} +bool SerialParallelDecomposition::operator==( + SerialParallelDecomposition const &other) const { + return this->raw_variant == other.raw_variant; +} +bool SerialParallelDecomposition::operator!=( + SerialParallelDecomposition const &other) const { + return this->raw_variant != other.raw_variant; +} +bool SerialParallelDecomposition::operator<( + SerialParallelDecomposition const &other) const { + return this->raw_variant < other.raw_variant; +} +bool SerialParallelDecomposition::operator>( + SerialParallelDecomposition const &other) const { + return this->raw_variant > other.raw_variant; +} +bool SerialParallelDecomposition::operator<=( + SerialParallelDecomposition const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool SerialParallelDecomposition::operator>=( + SerialParallelDecomposition const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::SerialParallelDecomposition>::operator()( + ::FlexFlow::SerialParallelDecomposition const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::SerialParallelDecomposition const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type SerialParallelDecomposition", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::SerialParallelDecomposition const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc new file mode 100644 index 0000000000..233eb028e2 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc @@ -0,0 +1,194 @@ +#include "utils/graph/serial_parallel/serialparallel.h" +#include "./serialparallel_internal.h" +#include "./sink_settings.dtg.h" +#include "./source_settings.dtg.h" +#include "utils/containers.h" +#include "utils/graph/algorithms.h" + +namespace FlexFlow { + +Node find_source_node(DiGraphView const &g) { + std::unordered_set srcs = get_sources(g); + return get_only(srcs); +} + +Node find_sink_node(DiGraphView const &g) { + std::unordered_set sinks = get_sinks(g); + return get_only(sinks); +} + +std::optional find_bottleneck_node(DiGraphView const &g) { + std::unordered_set sources = get_sources(g); + std::unordered_set sinks = get_sinks(g); + + std::optional maybe_bottleneck = get_imm_post_dominator(g, sources); + if (maybe_bottleneck.has_value()) { + assert(contains(get_dominators(g, sinks), maybe_bottleneck.value())); + } + return maybe_bottleneck; +} + +std::unordered_set from_source_to_sink(DiGraphView const &g, + Node const &src, + Node const &sink) { + assert(contains(get_dominators(g, sink), src)); + + std::vector bfs = get_bfs_ordering(g, {src}); + auto end = find(bfs, sink); + assert(end != bfs.end()); + + std::unordered_set result(bfs.cbegin(), ++end); + return result; +} + +struct FlattenAST { + void add_flattened_child_to_parent(IntermediateSpDecompositionTree &parent, + SplitAST const &child) { + if (std::holds_alternative(child)) { + parent.children.push_back(child); + return; + } + + IntermediateSpDecompositionTree child_node = get(child); + + if (parent.type == child_node.type) { + extend(parent.children, child_node.children); + } else { + parent.children.push_back(child); + } + } + + SplitAST operator()(IntermediateSpDecompositionTree const &ast_node) { + IntermediateSpDecompositionTree result(ast_node.type); + for (SplitAST const &child : ast_node.children) { + SplitAST flattened_child = flatten_ast(child); + add_flattened_child_to_parent(result, flattened_child); + } + return result; + } + + SplitAST operator()(Node const &ast_node) { + return ast_node; + } +}; + +SerialParallelDecomposition + get_serial_parallel_decomposition(DiGraphView const &g) { + SplitAST ast = sp_decomposition(g); + return to_final_ast(ast); +} + +std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { + return sp.visit([](auto &&t) { return get_nodes(t); }); +} + +std::unordered_set get_nodes(Serial const &serial) { + return set_union(transform( + serial.children, + [](std::variant const &child) -> std::unordered_set { + return std::visit([](auto &&t) { return get_nodes(t); }, child); + })); +} + +std::unordered_set get_nodes(Parallel const ¶llel) { + return set_union( + transform(parallel.children, [](std::variant const &child) { + return visit(GetNodes{}, child); + })); +} + +std::unordered_set get_nodes(Node const &node) { + return {node}; +} + +// std::unordered_map parallel_extend(MultiDiGraph &g, +// MultiDiGraph const &ext) { +// std::unordered_map node_map; +// std::unordered_map node_port_map; +// for (Node const &node : get_nodes(MultiDiGraphView(ext))) { +// node_map.emplace(node, g.add_node()); +// } +// for (NodePort const &node_port : get_present_node_ports(ext)) { +// node_port_map.emplace(node_port, g.add_node_port()); +// } +// for (MultiDiEdge const &edge : get_edges(ext)) { +// g.add_edge(MultiDiEdge{node_map.at(edge.dst), +// node_port_map.at(edge.dst_idx), +// node_map.at(edge.src), +// node_port_map.at(edge.src_idx)}); +// } +// return node_map; +// } + +// std::unordered_map serial_extend(MultiDiGraph &g, +// MultiDiGraph const &ext) { +// std::unordered_set original_sinks = get_sinks(g); +// std::unordered_map node_map = parallel_extend(g, ext); +// for (Node const &node1 : original_sinks) { +// for (Node const &node2 : get_sources(ext)) { +// g.add_edge(MultiDiEdge{ +// node_map.at(node2), g.add_node_port(), node1, g.add_node_port()}); +// } +// } +// return node_map; +// } + +// MultiDiGraph serial_composition(MultiDiGraph const &g1, +// MultiDiGraph const &g2) { +// MultiDiGraph g = g1; +// serial_extend(g, g2); +// return g; +// } + +// MultiDiGraph parallel_composition(MultiDiGraph const &g1, +// MultiDiGraph const &g2) { +// MultiDiGraph g = g1; +// parallel_extend(g, g2); +// return g; +// } + +// struct MultiDiGraphFromSPDecompositionFunctor { +// template +// MultiDiGraph operator()(T const &t) { +// return multidigraph_from_sp_decomposition(t); +// } +// }; + +// MultiDiGraph multidigraph_from_sp_decomposition( +// SerialParallelDecomposition const &sp_decomposition) { +// return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); +// } + +// MultiDiGraph multidigraph_from_sp_decomposition( +// std::variant const &sp_decomposition) { +// return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); +// } + +// MultiDiGraph multidigraph_from_sp_decomposition( +// std::variant const &sp_decomposition) { +// return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); +// } + +// MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { +// MultiDiGraph g = MultiDiGraph::create(); +// for (std::variant const &child : serial.children) { +// serial_extend(g, multidigraph_from_sp_decomposition(child)); +// } +// return g; +// } + +// MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { +// MultiDiGraph g = MultiDiGraph::create(); +// for (std::variant const &child : parallel.children) { +// parallel_extend(g, multidigraph_from_sp_decomposition(child)); +// } +// return g; +// } + +// MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { +// MultiDiGraph g = MultiDiGraph::create(); +// g.add_node(); +// return g; +// } + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc new file mode 100644 index 0000000000..2fd1e81aa0 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc @@ -0,0 +1,123 @@ +#include "utils/graph/serial_parallel/serialparallel_internal.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/serial_parallel/sink_settings.dtg.h" +#include "utils/graph/serial_parallel/source_settings.dtg.h" +#include "utils/graph/serial_parallel/split_ast_node.dtg.h" + +namespace FlexFlow { + +std::unordered_set + from_source_to_sink(DiGraphView const &g, + std::unordered_set const &srcs, + std::unordered_set const &sinks, + SourceSettings include_src, + SinkSettings include_sink) { + assert(is_acyclic(g)); + + Node contracted_src = get_first(srcs); + Node contracted_sink = get_first(sinks); + std::unordered_map contraction; + for (Node const &src : srcs) { + contraction.insert({src, contracted_src}); + } + for (Node const &sink : sinks) { + contraction.insert({sink, contracted_sink}); + } + auto contracted_view = apply_contraction(g, contraction); + + std::unordered_set result = + from_source_to_sink(contracted_view, contracted_src, contracted_sink); + result.erase(contracted_src); + result.erase(contracted_sink); + + if (include_src == SourceSettings::INCLUDE_SOURCE_NODES) { + result = set_union(result, srcs); + } + if (include_sink == SinkSettings::INCLUDE_SINK_NODES) { + result = set_union(result, sinks); + } + return result; +} + +DiGraphView source_to_sink_subgraph(DiGraphView const &g, + std::unordered_set const &srcs, + std::unordered_set const &sinks, + SourceSettings include_src, + SinkSettings include_sink) { + return get_subgraph( + g, from_source_to_sink(g, srcs, sinks, include_src, include_sink)); +} + +IntermediateSpDecompositionTree sp_decomposition(DiGraphView const &g) { + if (num_nodes(g) == 1) { + return get_only(get_nodes(g)); + } + + std::unordered_set sources = get_sources(g); + std::unordered_set sinks = get_sinks(g); + + std::optional bottleneck = find_bottleneck_node(g); + if (bottleneck.has_value()) { + return IntermediateSpDecompositionTree{ + SplitType::SERIAL, + { + sp_decomposition(source_to_sink_subgraph(g, + sources, + {bottleneck.value()}, + SourceSettings::INCLUDE_SOURCE_NODES, + SinkSettings::EXCLUDE_SINK_NODES)), + sp_decomposition(source_to_sink_subgraph(g, + {bottleneck.value()}, + sinks, + SourceSettings::INCLUDE_SOURCE_NODES, + SinkSettings::INCLUDE_SINK_NODES))}}; + } else { + return parallel_decomposition(g); + } +} + +IntermediateSpDecompositionTree parallel_decomposition(DiGraphView const &g) { + std::unordered_set> weakly_connected_components = + get_weakly_connected_components(g); + assert(weakly_connected_components.size() > 1); + + IntermediateSpDecompositionTree split(SplitType::PARALLEL, {}); + for (auto const &component : weakly_connected_components) { + split.children.push_back(sp_decomposition(get_subgraph(g, component))); + } + + return split; +} + +SplitAST flatten_ast(SplitAST const &ast) { + return visit(FlattenAST{}, ast); +} + +struct ToFinalAST { + std::variant operator()(IntermediateSpDecompositionTree const &node) { + if (node.type == SplitType::SERIAL) { + return Serial{transform(node.children, [](SplitAST const &s) { + return narrow>(internal_to_final_ast(s)).value(); + })}; + } else { + return Parallel{transform(node.children, [](SplitAST const &s) { + return narrow>(internal_to_final_ast(s)).value(); + })}; + } + } + + std::variant operator()(Node const &node) { + return node; + } +}; + +std::variant internal_to_final_ast(SplitAST const &ast) { + return visit(ToFinalAST{}, ast); +} + +SerialParallelDecomposition to_final_ast(SplitAST const &ast) { + return std::visit([](auto &&x) { return SerialParallelDecomposition{x}; }, + internal_to_final_ast(ast)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h new file mode 100644 index 0000000000..927f50cff1 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_INTERNAL_H +#define _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_INTERNAL_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/serial_parallel/serialparallel.h" +#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" +#include "utils/visitable.h" +#include +#include + +namespace FlexFlow { + +struct ParallelInternal; + +using SplitAST = std::variant; + +IntermediateSpDecompositionTree sp_decomposition(DiGraphView const &g); +IntermediateSpDecompositionTree parallel_decomposition(DiGraphView const &g); + +std::unordered_set + from_source_to_sink(DiGraphView const &, Node const &src, Node const &sink); + +std::variant internal_to_final_ast(SplitAST const &); +SerialParallelDecomposition to_final_ast(SplitAST const &); +SplitAST flatten_ast(SplitAST const &ast); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.cc new file mode 100644 index 0000000000..9cb2d794a5 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.cc @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/src/utils/graph/serial_parallel/sink_settings.enum.toml +/* proj-data +{ + "generated_from": "547a817b201dd109c446a3bb8afe5343" +} +*/ + +#include "utils/graph/serial_parallel/sink_settings.dtg.h" + +#include +#include + +namespace std { +size_t + hash::operator()(FlexFlow::SinkSettings x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(SinkSettings x) { + switch (x) { + case SinkSettings::INCLUDE_SINK_NODES: + return "INCLUDE_SINK_NODES"; + case SinkSettings::EXCLUDE_SINK_NODES: + return "EXCLUDE_SINK_NODES"; + default: + std::ostringstream oss; + oss << "Unknown SinkSettings value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, SinkSettings x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, SinkSettings x) { + switch (x) { + case SinkSettings::INCLUDE_SINK_NODES: + j = "INCLUDE_SINK_NODES"; + break; + case SinkSettings::EXCLUDE_SINK_NODES: + j = "EXCLUDE_SINK_NODES"; + break; + default: + std::ostringstream oss; + oss << "Unknown SinkSettings value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, SinkSettings &x) { + std::string as_str = j.get(); + if (as_str == "INCLUDE_SINK_NODES") { + x = SinkSettings::INCLUDE_SINK_NODES; + } else if (as_str == "EXCLUDE_SINK_NODES") { + x = SinkSettings::EXCLUDE_SINK_NODES; + } else { + std::ostringstream oss; + oss << "Unknown SinkSettings value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::SinkSettings::INCLUDE_SINK_NODES, + FlexFlow::SinkSettings::EXCLUDE_SINK_NODES); +} +} // namespace rc diff --git a/lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.h b/lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.h new file mode 100644 index 0000000000..0b0399239b --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/sink_settings.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/src/utils/graph/serial_parallel/sink_settings.enum.toml +/* proj-data +{ + "generated_from": "547a817b201dd109c446a3bb8afe5343" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_SINK_SETTINGS_DTG_H +#define _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_SINK_SETTINGS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class SinkSettings { INCLUDE_SINK_NODES, EXCLUDE_SINK_NODES }; +std::string format_as(SinkSettings); +std::ostream &operator<<(std::ostream &, SinkSettings); +void to_json(::nlohmann::json &, SinkSettings); +void from_json(::nlohmann::json const &, SinkSettings &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SinkSettings) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_SINK_SETTINGS_DTG_H diff --git a/lib/utils/src/utils/graph/serial_parallel/sink_settings.enum.toml b/lib/utils/src/utils/graph/serial_parallel/sink_settings.enum.toml new file mode 100644 index 0000000000..5668556543 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/sink_settings.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "SinkSettings" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "INCLUDE_SINK_NODES" + +[[values]] +name = "EXCLUDE_SINK_NODES" diff --git a/lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.cc new file mode 100644 index 0000000000..a3a295bfae --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.cc @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/src/utils/graph/serial_parallel/source_settings.enum.toml +/* proj-data +{ + "generated_from": "56c4dd8f16ee7756372801aa91f619ea" +} +*/ + +#include "utils/graph/serial_parallel/source_settings.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::SourceSettings x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(SourceSettings x) { + switch (x) { + case SourceSettings::INCLUDE_SOURCE_NODES: + return "INCLUDE_SOURCE_NODES"; + case SourceSettings::EXCLUDE_SOURCE_NODES: + return "EXCLUDE_SOURCE_NODES"; + default: + std::ostringstream oss; + oss << "Unknown SourceSettings value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, SourceSettings x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, SourceSettings x) { + switch (x) { + case SourceSettings::INCLUDE_SOURCE_NODES: + j = "INCLUDE_SOURCE_NODES"; + break; + case SourceSettings::EXCLUDE_SOURCE_NODES: + j = "EXCLUDE_SOURCE_NODES"; + break; + default: + std::ostringstream oss; + oss << "Unknown SourceSettings value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, SourceSettings &x) { + std::string as_str = j.get(); + if (as_str == "INCLUDE_SOURCE_NODES") { + x = SourceSettings::INCLUDE_SOURCE_NODES; + } else if (as_str == "EXCLUDE_SOURCE_NODES") { + x = SourceSettings::EXCLUDE_SOURCE_NODES; + } else { + std::ostringstream oss; + oss << "Unknown SourceSettings value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::SourceSettings::INCLUDE_SOURCE_NODES, + FlexFlow::SourceSettings::EXCLUDE_SOURCE_NODES); +} +} // namespace rc diff --git a/lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.h b/lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.h new file mode 100644 index 0000000000..8ee8176016 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/source_settings.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/src/utils/graph/serial_parallel/source_settings.enum.toml +/* proj-data +{ + "generated_from": "56c4dd8f16ee7756372801aa91f619ea" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_SOURCE_SETTINGS_DTG_H +#define _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_SOURCE_SETTINGS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class SourceSettings { INCLUDE_SOURCE_NODES, EXCLUDE_SOURCE_NODES }; +std::string format_as(SourceSettings); +std::ostream &operator<<(std::ostream &, SourceSettings); +void to_json(::nlohmann::json &, SourceSettings); +void from_json(::nlohmann::json const &, SourceSettings &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SourceSettings) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_SERIAL_PARALLEL_SOURCE_SETTINGS_DTG_H diff --git a/lib/utils/src/utils/graph/serial_parallel/source_settings.enum.toml b/lib/utils/src/utils/graph/serial_parallel/source_settings.enum.toml new file mode 100644 index 0000000000..8d17dc4d77 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/source_settings.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "SourceSettings" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "INCLUDE_SOURCE_NODES" + +[[values]] +name = "EXCLUDE_SOURCE_NODES" diff --git a/lib/utils/src/utils/graph/serial_parallel/split_type.dtg.cc b/lib/utils/src/utils/graph/serial_parallel/split_type.dtg.cc new file mode 100644 index 0000000000..668823f2e5 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/split_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml +/* proj-data +{ + "generated_from": "61d75c03b0273d05eb9707f75132974e" +} +*/ + +#include "utils/graph/serial_parallel/split_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::SplitType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(SplitType x) { + switch (x) { + case SplitType::SERIAL: + return "SERIAL"; + case SplitType::PARALLEL: + return "PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown SplitType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, SplitType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, SplitType x) { + switch (x) { + case SplitType::SERIAL: + j = "SERIAL"; + break; + case SplitType::PARALLEL: + j = "PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown SplitType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, SplitType &x) { + std::string as_str = j.get(); + if (as_str == "SERIAL") { + x = SplitType::SERIAL; + } else if (as_str == "PARALLEL") { + x = SplitType::PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown SplitType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::SplitType::SERIAL, + FlexFlow::SplitType::PARALLEL); +} +} // namespace rc diff --git a/lib/utils/src/utils/graph/serialparallel.cc b/lib/utils/src/utils/graph/serialparallel.cc deleted file mode 100644 index f1c9e41005..0000000000 --- a/lib/utils/src/utils/graph/serialparallel.cc +++ /dev/null @@ -1,325 +0,0 @@ -#include "utils/graph/serialparallel.h" -#include "serialparallel_internal.h" -#include "utils/containers.h" -#include "utils/graph/adjacency_multidigraph.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph.h" - -namespace FlexFlow { - -Node find_source_node(DiGraphView const &g) { - std::unordered_set srcs = get_sources(g); - return get_only(srcs); -} - -Node find_sink_node(DiGraphView const &g) { - std::unordered_set sinks = get_sinks(g); - return get_only(sinks); -} - -std::optional find_bottleneck_node(DiGraphView const &g) { - std::unordered_set sources = get_sources(g); - std::unordered_set sinks = get_sinks(g); - - std::optional maybe_bottleneck = get_imm_post_dominator(g, sources); - if (maybe_bottleneck.has_value()) { - assert(contains(get_dominators(g, sinks), maybe_bottleneck.value())); - } - return maybe_bottleneck; -} - -enum class SourceSettings { INCLUDE_SOURCE_NODES, EXCLUDE_SOURCE_NODES }; - -enum class SinkSettings { INCLUDE_SINK_NODES, EXCLUDE_SINK_NODES }; - -std::unordered_set from_source_to_sink(DiGraphView const &g, - Node const &src, - Node const &sink) { - assert(contains(get_dominators(g, sink), src)); - - std::vector bfs = get_bfs_ordering(g, {src}); - auto end = find(bfs, sink); - assert(end != bfs.end()); - - std::unordered_set result(bfs.cbegin(), ++end); - return result; -} - -std::unordered_set - from_source_to_sink(DiGraphView const &g, - std::unordered_set const &srcs, - std::unordered_set const &sinks, - SourceSettings include_src, - SinkSettings include_sink) { - assert(is_acyclic(g)); - - Node contracted_src = get_first(srcs); - Node contracted_sink = get_first(sinks); - std::unordered_map contraction; - for (Node const &src : srcs) { - contraction.insert({src, contracted_src}); - } - for (Node const &sink : sinks) { - contraction.insert({sink, contracted_sink}); - } - auto contracted_view = apply_contraction(g, contraction); - - std::unordered_set result = - from_source_to_sink(contracted_view, contracted_src, contracted_sink); - result.erase(contracted_src); - result.erase(contracted_sink); - - if (include_src == SourceSettings::INCLUDE_SOURCE_NODES) { - result = set_union(result, srcs); - } - if (include_sink == SinkSettings::INCLUDE_SINK_NODES) { - result = set_union(result, sinks); - } - return result; -} - -DiGraphView source_to_sink_subgraph(DiGraphView const &g, - std::unordered_set const &srcs, - std::unordered_set const &sinks, - SourceSettings include_src, - SinkSettings include_sink) { - return get_subgraph( - g, from_source_to_sink(g, srcs, sinks, include_src, include_sink)); -} - -SplitAST sp_decomposition(DiGraphView const &g) { - if (num_nodes(g) == 1) { - return get_only(get_nodes(g)); - } - - std::unordered_set sources = get_sources(g); - std::unordered_set sinks = get_sinks(g); - - std::optional bottleneck = find_bottleneck_node(g); - if (bottleneck.has_value()) { - return SplitASTNode(SplitType::SERIAL, - sp_decomposition(source_to_sink_subgraph( - g, - sources, - {bottleneck.value()}, - SourceSettings::INCLUDE_SOURCE_NODES, - SinkSettings::EXCLUDE_SINK_NODES)), - sp_decomposition(source_to_sink_subgraph( - g, - {bottleneck.value()}, - sinks, - SourceSettings::INCLUDE_SOURCE_NODES, - SinkSettings::INCLUDE_SINK_NODES))); - } else { - return parallel_decomposition(g); - } -} - -SplitAST parallel_decomposition(DiGraphView const &g) { - std::unordered_set> weakly_connected_components = - get_weakly_connected_components(g); - assert(weakly_connected_components.size() > 1); - - SplitASTNode split(SplitType::PARALLEL); - for (auto const &component : weakly_connected_components) { - split.children.push_back(sp_decomposition(get_subgraph(g, component))); - } - - return split; -} - -SplitASTNode::SplitASTNode(SplitType type) : SplitASTNode(type, {}) {} - -SplitASTNode::SplitASTNode(SplitType type, - SplitAST const &lhs, - SplitAST const &rhs) - : SplitASTNode(type, {lhs, rhs}) {} - -SplitASTNode::SplitASTNode(SplitType type, - std::vector const &children) - : type(type), children(children) {} - -struct FlattenAST { - void add_flattened_child_to_parent(SplitASTNode &parent, - SplitAST const &child) { - if (std::holds_alternative(child)) { - parent.children.push_back(child); - return; - } - - SplitASTNode child_node = get(child); - - if (parent.type == child_node.type) { - extend(parent.children, child_node.children); - } else { - parent.children.push_back(child); - } - } - - SplitAST operator()(SplitASTNode const &ast_node) { - SplitASTNode result(ast_node.type); - for (SplitAST const &child : ast_node.children) { - SplitAST flattened_child = flatten_ast(child); - add_flattened_child_to_parent(result, flattened_child); - } - return result; - } - - SplitAST operator()(Node const &ast_node) { - return ast_node; - } -}; - -SplitAST flatten_ast(SplitAST const &ast) { - return visit(FlattenAST{}, ast); -} - -struct ToFinalAST { - std::variant operator()(SplitASTNode const &node) { - if (node.type == SplitType::SERIAL) { - return Serial{transform(node.children, [](SplitAST const &s) { - return narrow>(to_final_ast(s)).value(); - })}; - } else { - return Parallel{transform(node.children, [](SplitAST const &s) { - return narrow>(to_final_ast(s)).value(); - })}; - } - } - - std::variant operator()(Node const &node) { - return node; - } -}; - -std::variant to_final_ast(SplitAST const &ast) { - return visit(ToFinalAST{}, ast); -} - -SerialParallelDecomposition - get_serial_parallel_decomposition(DiGraphView const &g) { - SplitAST ast = sp_decomposition(g); - return to_final_ast(ast); -} - -struct GetNodes { - template - std::unordered_set operator()(T const &t) { - return get_nodes(t); - } -}; - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { - return visit(GetNodes{}, sp); -} - -std::unordered_set get_nodes(Serial const &serial) { - return set_union(transform( - serial.children, - [](std::variant const child) -> std::unordered_set { - return visit(GetNodes{}, child); - })); -} - -std::unordered_set get_nodes(Parallel const ¶llel) { - return set_union( - transform(parallel.children, [](std::variant const &child) { - return visit(GetNodes{}, child); - })); -} - -std::unordered_set get_nodes(Node const &node) { - return {node}; -} - -std::unordered_map parallel_extend(MultiDiGraph &g, - MultiDiGraph const &ext) { - std::unordered_map node_map; - std::unordered_map node_port_map; - for (Node const &node : get_nodes(MultiDiGraphView(ext))) { - node_map.emplace(node, g.add_node()); - } - for (NodePort const &node_port : get_present_node_ports(ext)) { - node_port_map.emplace(node_port, g.add_node_port()); - } - for (MultiDiEdge const &edge : get_edges(ext)) { - g.add_edge(MultiDiEdge{node_map.at(edge.dst), - node_port_map.at(edge.dst_idx), - node_map.at(edge.src), - node_port_map.at(edge.src_idx)}); - } - return node_map; -} - -std::unordered_map serial_extend(MultiDiGraph &g, - MultiDiGraph const &ext) { - std::unordered_set original_sinks = get_sinks(g); - std::unordered_map node_map = parallel_extend(g, ext); - for (Node const &node1 : original_sinks) { - for (Node const &node2 : get_sources(ext)) { - g.add_edge(MultiDiEdge{ - node_map.at(node2), g.add_node_port(), node1, g.add_node_port()}); - } - } - return node_map; -} - -MultiDiGraph serial_composition(MultiDiGraph const &g1, - MultiDiGraph const &g2) { - MultiDiGraph g = g1; - serial_extend(g, g2); - return g; -} - -MultiDiGraph parallel_composition(MultiDiGraph const &g1, - MultiDiGraph const &g2) { - MultiDiGraph g = g1; - parallel_extend(g, g2); - return g; -} - -struct MultiDiGraphFromSPDecompositionFunctor { - template - MultiDiGraph operator()(T const &t) { - return multidigraph_from_sp_decomposition(t); - } -}; - -MultiDiGraph multidigraph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); -} - -MultiDiGraph multidigraph_from_sp_decomposition( - std::variant const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); -} - -MultiDiGraph multidigraph_from_sp_decomposition( - std::variant const &sp_decomposition) { - return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); -} - -MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { - MultiDiGraph g = MultiDiGraph::create(); - for (std::variant const &child : serial.children) { - serial_extend(g, multidigraph_from_sp_decomposition(child)); - } - return g; -} - -MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { - MultiDiGraph g = MultiDiGraph::create(); - for (std::variant const &child : parallel.children) { - parallel_extend(g, multidigraph_from_sp_decomposition(child)); - } - return g; -} - -MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { - MultiDiGraph g = MultiDiGraph::create(); - g.add_node(); - return g; -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serialparallel_internal.h b/lib/utils/src/utils/graph/serialparallel_internal.h deleted file mode 100644 index 3d3e17fecb..0000000000 --- a/lib/utils/src/utils/graph/serialparallel_internal.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_INTERNAL_H -#define _FLEXFLOW_UTILS_GRAPH_SERIALPARALLEL_INTERNAL_H - -#include "utils/graph/digraph.h" -#include "utils/graph/node.h" -#include "utils/graph/serialparallel.h" -#include "utils/visitable.h" -#include -#include - -namespace FlexFlow { - -struct ParallelInternal; - -enum class SplitType { SERIAL, PARALLEL }; - -struct SplitASTNode; - -using SplitAST = std::variant; - -struct SplitASTNode { - SplitASTNode(SplitType type); - SplitASTNode(SplitType type, SplitAST const &lhs, SplitAST const &rhs); - SplitASTNode(SplitType type, std::vector const &children); - - std::vector children; - SplitType type; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(SplitASTNode, children, type); - -SplitAST sp_decomposition(DiGraphView const &g); -SplitAST parallel_decomposition(DiGraphView const &g); - -std::unordered_set - from_source_to_sink(DiGraphView const &, Node const &src, Node const &sink); - -std::variant to_final_ast(SplitAST const &); -SplitAST flatten_ast(SplitAST const &ast); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc index f67e39519c..167ae8fe20 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.dtg.cc @@ -9,8 +9,6 @@ #include "utils/graph/undirected/undirected_edge_query.dtg.h" -#include "utils/graph/node/node.dtg.h" -#include "utils/graph/query_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc.old similarity index 100% rename from lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc rename to lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph.cc.old diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc.old similarity index 100% rename from lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc rename to lib/utils/src/utils/graph/upward_open_multidigraph/i_upward_open_multidigraph_view.cc.old diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc.old similarity index 100% rename from lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc rename to lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.cc.old diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc rename to lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge.dtg.cc.old diff --git a/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc b/lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc.old similarity index 100% rename from lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc rename to lib/utils/src/utils/graph/upward_open_multidigraph/upward_open_multi_di_edge_query.dtg.cc.old diff --git a/lib/utils/src/utils/graph/views/join_node_key.dtg.cc b/lib/utils/src/utils/graph/views/join_node_key.dtg.cc index 0139d3974f..05a4a3ee88 100644 --- a/lib/utils/src/utils/graph/views/join_node_key.dtg.cc +++ b/lib/utils/src/utils/graph/views/join_node_key.dtg.cc @@ -9,8 +9,6 @@ #include "utils/graph/views/join_node_key.dtg.h" -#include "utils/graph/node/node.dtg.h" -#include "utils/graph/views/lr_direction.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/views.cc b/lib/utils/src/utils/graph/views/views.cc similarity index 52% rename from lib/utils/src/utils/graph/views.cc rename to lib/utils/src/utils/graph/views/views.cc index 567914a249..518e9784e9 100644 --- a/lib/utils/src/utils/graph/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -3,13 +3,9 @@ #include "utils/disjoint_set.h" #include "utils/exception.h" #include "utils/graph/algorithms.h" -#include "utils/graph/open_multidigraph/input_multi_di_edge.h" -#include "utils/graph/open_multidigraph/input_multi_di_edge_query.h" -#include "utils/graph/open_multidigraph/output_multi_di_edge.h" #include "utils/graph/undirected/undirected_edge_query.h" #include "utils/graph/node/node_query.h" #include "utils/graph/digraph/directed_edge_query.h" -#include "utils/graph/multidigraph/multi_di_edge_query.h" namespace FlexFlow { @@ -89,20 +85,20 @@ DiSubgraphView *DiSubgraphView::clone() const { return new DiSubgraphView(g, subgraph_nodes); } -MultiDiSubgraphView::MultiDiSubgraphView( - MultiDiGraphView const &g, std::unordered_set const &subgraph_nodes) - : g(g), subgraph_nodes(subgraph_nodes) {} +// MultiDiSubgraphView::MultiDiSubgraphView( +// MultiDiGraphView const &g, std::unordered_set const &subgraph_nodes) +// : g(g), subgraph_nodes(subgraph_nodes) {} -std::unordered_set - MultiDiSubgraphView::query_edges(MultiDiEdgeQuery const &query) const { - MultiDiEdgeQuery subgraph_query = MultiDiEdgeQuery{this->subgraph_nodes, this->subgraph_nodes}; - return this->g.query_edges(query_intersection(query, subgraph_query)); -} +// std::unordered_set +// MultiDiSubgraphView::query_edges(MultiDiEdgeQuery const &query) const { +// MultiDiEdgeQuery subgraph_query = MultiDiEdgeQuery{this->subgraph_nodes, this->subgraph_nodes}; +// return this->g.query_edges(query_intersection(query, subgraph_query)); +// } -std::unordered_set - MultiDiSubgraphView::query_nodes(NodeQuery const &query) const { - return this->g.query_nodes(query_intersection(query, NodeQuery{this->subgraph_nodes})); -} +// std::unordered_set +// MultiDiSubgraphView::query_nodes(NodeQuery const &query) const { +// return this->g.query_nodes(query_intersection(query, NodeQuery{this->subgraph_nodes})); +// } UndirectedGraphView view_subgraph(UndirectedGraphView const &g, @@ -115,10 +111,10 @@ DiGraphView view_subgraph(DiGraphView const &g, return DiGraphView::create(g, subgraph_nodes); } -MultiDiGraphView view_subgraph(MultiDiGraphView const &g, - std::unordered_set const &subgraph_nodes) { - return MultiDiGraphView::create(g, subgraph_nodes); -} +// MultiDiGraphView view_subgraph(MultiDiGraphView const &g, +// std::unordered_set const &subgraph_nodes) { +// return MultiDiGraphView::create(g, subgraph_nodes); +// } Node NodeSource::fresh_node() { Node result(this->next_node_idx); @@ -259,51 +255,51 @@ DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT})}; } -JoinedMultiDigraphView::JoinedMultiDigraphView(MultiDiGraphView const &lhs, - MultiDiGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -std::unordered_set - JoinedMultiDigraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set - JoinedMultiDigraphView::query_edges(MultiDiEdgeQuery const &query) const { - std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); - std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); - - auto traced_srcs = this->joined_nodes.trace_nodes(srcs); - auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - MultiDiEdgeQuery left_query = MultiDiEdgeQuery{ - traced_srcs.first, traced_dsts.first}; - MultiDiEdgeQuery right_query = MultiDiEdgeQuery{ - traced_srcs.second, traced_dsts.second}; - - return set_union( - transform(this->lhs.query_edges(left_query), - [&](MultiDiEdge const &e) { return this->fix_lhs_edge(e); }), - transform(this->rhs.query_edges(right_query), - [&](MultiDiEdge const &e) { return this->fix_rhs_edge(e); })); -} - -JoinedMultiDigraphView *JoinedMultiDigraphView::clone() const { - return new JoinedMultiDigraphView(lhs, rhs); -} - -MultiDiEdge JoinedMultiDigraphView::fix_lhs_edge(MultiDiEdge const &e) const { - return MultiDiEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}) - }; -} - -MultiDiEdge JoinedMultiDigraphView::fix_rhs_edge(MultiDiEdge const &e) const { - return MultiDiEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}) - }; -} +// JoinedMultiDigraphView::JoinedMultiDigraphView(MultiDiGraphView const &lhs, +// MultiDiGraphView const &rhs) +// : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} + +// std::unordered_set +// JoinedMultiDigraphView::query_nodes(NodeQuery const &query) const { +// return this->joined_nodes.query_nodes(query); +// } + +// std::unordered_set +// JoinedMultiDigraphView::query_edges(MultiDiEdgeQuery const &query) const { +// std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); +// std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); +// +// auto traced_srcs = this->joined_nodes.trace_nodes(srcs); +// auto traced_dsts = this->joined_nodes.trace_nodes(dsts); +// MultiDiEdgeQuery left_query = MultiDiEdgeQuery{ +// traced_srcs.first, traced_dsts.first}; +// MultiDiEdgeQuery right_query = MultiDiEdgeQuery{ +// traced_srcs.second, traced_dsts.second}; +// +// return set_union( +// transform(this->lhs.query_edges(left_query), +// [&](MultiDiEdge const &e) { return this->fix_lhs_edge(e); }), +// transform(this->rhs.query_edges(right_query), +// [&](MultiDiEdge const &e) { return this->fix_rhs_edge(e); })); +// } + +// JoinedMultiDigraphView *JoinedMultiDigraphView::clone() const { +// return new JoinedMultiDigraphView(lhs, rhs); +// } + +// MultiDiEdge JoinedMultiDigraphView::fix_lhs_edge(MultiDiEdge const &e) const { +// return MultiDiEdge{ +// this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT}), +// this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}) +// }; +// } + +// MultiDiEdge JoinedMultiDigraphView::fix_rhs_edge(MultiDiEdge const &e) const { +// return MultiDiEdge{ +// this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT}), +// this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}) +// }; +// } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { return {e.src, e.dst}; @@ -315,14 +311,14 @@ std::unordered_set to_undirected_edges( [](DirectedEdge const &e) { return to_undirected_edge(e); }); } -UndirectedEdge to_undirected_edge(MultiDiEdge const &e) { - return to_undirected_edge(to_directed_edge(e)); -} +// UndirectedEdge to_undirected_edge(MultiDiEdge const &e) { +// return to_undirected_edge(to_directed_edge(e)); +// } -std::unordered_set - to_undirected_edges(std::unordered_set const &multidi_edges) { - return to_undirected_edges(to_directed_edges(multidi_edges)); -} +// std::unordered_set +// to_undirected_edges(std::unordered_set const &multidi_edges) { +// return to_undirected_edges(to_directed_edges(multidi_edges)); +// } std::unordered_set to_directed_edges(UndirectedEdge const &e) { return std::unordered_set{ @@ -336,14 +332,14 @@ std::unordered_set to_directed_edges( return flatmap_v2(undirected_edges, to_directed_edges); } -DirectedEdge to_directed_edge(MultiDiEdge const &e) { - return DirectedEdge{e.src, e.dst}; -} +// DirectedEdge to_directed_edge(MultiDiEdge const &e) { +// return DirectedEdge{e.src, e.dst}; +// } -std::unordered_set - to_directed_edges(std::unordered_set const &multidi_edges) { - return transform(multidi_edges, to_directed_edge); -} +// std::unordered_set +// to_directed_edges(std::unordered_set const &multidi_edges) { +// return transform(multidi_edges, to_directed_edge); +// } ViewDiGraphAsUndirectedGraph::ViewDiGraphAsUndirectedGraph(DiGraphView const &g) : g(g) {} @@ -389,177 +385,177 @@ std::unordered_set return g.query_nodes(q); } -ViewDiGraphAsMultiDiGraph::ViewDiGraphAsMultiDiGraph(DiGraphView const &g) - : g(g) {} - -std::unordered_set ViewDiGraphAsMultiDiGraph::query_edges( - MultiDiEdgeQuery const &multidi_query) const { - DirectedEdgeQuery directed_query{multidi_query.srcs, multidi_query.dsts}; - - std::unordered_set const directed_edges = - this->g.query_edges(directed_query); - - return transform(directed_edges, [](DirectedEdge const &e) { - return MultiDiEdge{e.dst, e.src}; - }); -} - -std::unordered_set - ViewDiGraphAsMultiDiGraph::query_nodes(NodeQuery const &node_query) const { - return this->g.query_nodes(node_query); -} - -ViewMultiDiGraphAsOpenMultiDiGraph::ViewMultiDiGraphAsOpenMultiDiGraph( - MultiDiGraphView const &g) - : g(g) {} - -std::unordered_set - ViewMultiDiGraphAsOpenMultiDiGraph::query_edges( - OpenMultiDiEdgeQuery const &q) const { - return transform(g.query_edges(q.standard_edge_query), - [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); -} - -std::unordered_set - ViewMultiDiGraphAsOpenMultiDiGraph::query_nodes(NodeQuery const &q) const { - return g.query_nodes(q); -} - -ViewMultiDiGraphAsOpenMultiDiGraph * - ViewMultiDiGraphAsOpenMultiDiGraph::clone() const { - return new ViewMultiDiGraphAsOpenMultiDiGraph(g); -} - -std::unordered_set - query_edge(std::unordered_set const &edges, - InputMultiDiEdgeQuery const &q) { - return filter(edges, [&](InputMultiDiEdge const &e) { - return includes(q.dsts, e.dst) && includes(q.dstIdxs, e.dst_idx); - }); -} - -std::unordered_set - query_edge(std::unordered_set const &edges, - OutputMultiDiEdgeQuery const &q) { - return filter(edges, [&](OutputMultiDiEdge const &e) { - return includes(q.srcs, e.src) && includes(q.srcIdxs, e.src_idx); - }); -} - -OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( - OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) { - this->inputs = transform(get_cut_set(g, nodes), input_multidiedge_from_multidiedge); - this->outputs = transform(get_cut_set(g, nodes), output_multidiedge_from_multidiedge); -} - -std::unordered_set - OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { - OpenMultiDiEdgeQuery subgraph_query( - q.input_edge_query.with_dst_nodes(nodes), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - q.output_edge_query.with_src_nodes(nodes)); - std::unordered_set result = g.query_edges(subgraph_query); - extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); - extend(result, - query_edge(outputs, q.output_edge_query.with_src_nodes(nodes))); - return result; -} - -std::unordered_set - OpenMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { - return g.query_nodes(query_intersection(q, NodeQuery(nodes))); -} - -UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( - OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) { - this->inputs = transform(get_cut_set(g, nodes), input_multidiedge_from_multidiedge); -} - -UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { - return new UpwardOpenMultiDiSubgraphView(g, nodes); -} - -std::unordered_set UpwardOpenMultiDiSubgraphView::query_edges( - OpenMultiDiEdgeQuery const &q) const { - OpenMultiDiEdgeQuery subgraph_query( - q.input_edge_query.with_dst_nodes(nodes), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - OutputMultiDiEdgeQuery::none()); - std::unordered_set result = g.query_edges(subgraph_query); - extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); - return result; -} - -std::unordered_set - UpwardOpenMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { - return g.query_nodes(query_intersection(q, NodeQuery(nodes))); -} - -DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( - OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) { - this->outputs = transform(get_cut_set(g, nodes), output_multidiedge_from_multidiedge); -} - -std::unordered_set - DownwardOpenMultiDiSubgraphView::query_edges( - OpenMultiDiEdgeQuery const &q) const { - OpenMultiDiEdgeQuery subgraph_query{ - input_multidiedge_query_none(), - MultiDiEdgeQuery{nodes, nodes}, - OutputMultiDiEdgeQuery{nodes}, - }; - std::unordered_set result = g.query_edges(subgraph_query); - extend(result, - query_edge(outputs, OutputMultiDiEdgeQuery{nodes})); - return result; -} - -std::unordered_set - DownwardOpenMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { - return g.query_nodes(query_intersection(q, NodeQuery(nodes))); -} - -ClosedMultiDiSubgraphView::ClosedMultiDiSubgraphView( - OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) {} - -std::unordered_set ClosedMultiDiSubgraphView::query_edges( - OpenMultiDiEdgeQuery const &q) const { - return g.query_edges( - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes)); -} - -std::unordered_set - ClosedMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { - return g.query_nodes(query_intersection(q, NodeQuery(nodes))); -} - -ClosedMultiDiSubgraphView *ClosedMultiDiSubgraphView::clone() const { - return new ClosedMultiDiSubgraphView(g, nodes); -} +// ViewDiGraphAsMultiDiGraph::ViewDiGraphAsMultiDiGraph(DiGraphView const &g) +// : g(g) {} + +// std::unordered_set ViewDiGraphAsMultiDiGraph::query_edges( +// MultiDiEdgeQuery const &multidi_query) const { +// DirectedEdgeQuery directed_query{multidi_query.srcs, multidi_query.dsts}; +// +// std::unordered_set const directed_edges = +// this->g.query_edges(directed_query); +// +// return transform(directed_edges, [](DirectedEdge const &e) { +// return MultiDiEdge{e.dst, e.src}; +// }); +// } + +// std::unordered_set +// ViewDiGraphAsMultiDiGraph::query_nodes(NodeQuery const &node_query) const { +// return this->g.query_nodes(node_query); +// } + +// ViewMultiDiGraphAsOpenMultiDiGraph::ViewMultiDiGraphAsOpenMultiDiGraph( +// MultiDiGraphView const &g) +// : g(g) {} + +// std::unordered_set +// ViewMultiDiGraphAsOpenMultiDiGraph::query_edges( +// OpenMultiDiEdgeQuery const &q) const { +// return transform(g.query_edges(q.standard_edge_query), +// [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); +// } + +// std::unordered_set +// ViewMultiDiGraphAsOpenMultiDiGraph::query_nodes(NodeQuery const &q) const { +// return g.query_nodes(q); +// } + +// ViewMultiDiGraphAsOpenMultiDiGraph * +// ViewMultiDiGraphAsOpenMultiDiGraph::clone() const { +// return new ViewMultiDiGraphAsOpenMultiDiGraph(g); +// } + +// std::unordered_set +// query_edge(std::unordered_set const &edges, +// InputMultiDiEdgeQuery const &q) { +// return filter(edges, [&](InputMultiDiEdge const &e) { +// return includes(q.dsts, e.dst) && includes(q.dstIdxs, e.dst_idx); +// }); +// } + +// std::unordered_set +// query_edge(std::unordered_set const &edges, +// OutputMultiDiEdgeQuery const &q) { +// return filter(edges, [&](OutputMultiDiEdge const &e) { +// return includes(q.srcs, e.src) && includes(q.srcIdxs, e.src_idx); +// }); +// } + +// OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( +// OpenMultiDiGraphView const &g, std::unordered_set const &nodes) +// : g(g), nodes(nodes) { +// this->inputs = transform(get_cut_set(g, nodes), input_multidiedge_from_multidiedge); +// this->outputs = transform(get_cut_set(g, nodes), output_multidiedge_from_multidiedge); +// } + +// std::unordered_set +// OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { +// OpenMultiDiEdgeQuery subgraph_query( +// q.input_edge_query.with_dst_nodes(nodes), +// q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), +// q.output_edge_query.with_src_nodes(nodes)); +// std::unordered_set result = g.query_edges(subgraph_query); +// extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); +// extend(result, +// query_edge(outputs, q.output_edge_query.with_src_nodes(nodes))); +// return result; +// } + +// std::unordered_set +// OpenMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { +// return g.query_nodes(query_intersection(q, NodeQuery(nodes))); +// } + +// UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( +// OpenMultiDiGraphView const &g, std::unordered_set const &nodes) +// : g(g), nodes(nodes) { +// this->inputs = transform(get_cut_set(g, nodes), input_multidiedge_from_multidiedge); +// } + +// UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { +// return new UpwardOpenMultiDiSubgraphView(g, nodes); +// } + +// std::unordered_set UpwardOpenMultiDiSubgraphView::query_edges( +// OpenMultiDiEdgeQuery const &q) const { +// OpenMultiDiEdgeQuery subgraph_query( +// q.input_edge_query.with_dst_nodes(nodes), +// q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), +// OutputMultiDiEdgeQuery::none()); +// std::unordered_set result = g.query_edges(subgraph_query); +// extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); +// return result; +// } + +// std::unordered_set +// UpwardOpenMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { +// return g.query_nodes(query_intersection(q, NodeQuery(nodes))); +// } + +// DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( +// OpenMultiDiGraphView const &g, std::unordered_set const &nodes) +// : g(g), nodes(nodes) { +// this->outputs = transform(get_cut_set(g, nodes), output_multidiedge_from_multidiedge); +// } + +// std::unordered_set +// DownwardOpenMultiDiSubgraphView::query_edges( +// OpenMultiDiEdgeQuery const &q) const { +// OpenMultiDiEdgeQuery subgraph_query{ +// input_multidiedge_query_none(), +// MultiDiEdgeQuery{nodes, nodes}, +// OutputMultiDiEdgeQuery{nodes}, +// }; +// std::unordered_set result = g.query_edges(subgraph_query); +// extend(result, +// query_edge(outputs, OutputMultiDiEdgeQuery{nodes})); +// return result; +// } + +// std::unordered_set +// DownwardOpenMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { +// return g.query_nodes(query_intersection(q, NodeQuery(nodes))); +// } + +// ClosedMultiDiSubgraphView::ClosedMultiDiSubgraphView( +// OpenMultiDiGraphView const &g, std::unordered_set const &nodes) +// : g(g), nodes(nodes) {} + +// std::unordered_set ClosedMultiDiSubgraphView::query_edges( +// OpenMultiDiEdgeQuery const &q) const { +// return g.query_edges( +// q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes)); +// } + +// std::unordered_set +// ClosedMultiDiSubgraphView::query_nodes(NodeQuery const &q) const { +// return g.query_nodes(query_intersection(q, NodeQuery(nodes))); +// } + +// ClosedMultiDiSubgraphView *ClosedMultiDiSubgraphView::clone() const { +// return new ClosedMultiDiSubgraphView(g, nodes); +// } JoinedUndirectedGraphView *JoinedUndirectedGraphView::clone() const { return new JoinedUndirectedGraphView(lhs, rhs); } -DownwardOpenMultiDiSubgraphView * - DownwardOpenMultiDiSubgraphView::clone() const { - return new DownwardOpenMultiDiSubgraphView(g, nodes); -} +// DownwardOpenMultiDiSubgraphView * +// DownwardOpenMultiDiSubgraphView::clone() const { +// return new DownwardOpenMultiDiSubgraphView(g, nodes); +// } -ViewDiGraphAsMultiDiGraph *ViewDiGraphAsMultiDiGraph::clone() const { - return new ViewDiGraphAsMultiDiGraph(g); -} +// ViewDiGraphAsMultiDiGraph *ViewDiGraphAsMultiDiGraph::clone() const { +// return new ViewDiGraphAsMultiDiGraph(g); +// } -OpenMultiDiSubgraphView *OpenMultiDiSubgraphView::clone() const { - return new OpenMultiDiSubgraphView(g, nodes); -} +// OpenMultiDiSubgraphView *OpenMultiDiSubgraphView::clone() const { +// return new OpenMultiDiSubgraphView(g, nodes); +// } -MultiDiSubgraphView *MultiDiSubgraphView::clone() const { - return new MultiDiSubgraphView(g, subgraph_nodes); -} +// MultiDiSubgraphView *MultiDiSubgraphView::clone() const { +// return new MultiDiSubgraphView(g, subgraph_nodes); +// } } // namespace FlexFlow From 6d68324b081e57cd8889b703420c87310f41c16a Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 23 Jun 2024 00:07:04 -0700 Subject: [PATCH 10/71] Add open dataflow graph, start to replace pcg dataflow graph --- .proj.toml | 2 +- lib/local-execution/src/ops/conv_2d.cc | 1 - .../include/op-attrs/ops/cast_attrs.dtg.h | 8 +- .../op-attrs/ops/cast_attrs.struct.toml | 4 +- .../src/op-attrs/ops/cast_attrs.dtg.cc | 14 +- lib/op-attrs/test/src/ops/conv_2d.cc | 2 +- lib/op-attrs/test/src/test_operator_attrs.cc | 4 +- lib/pcg/include/pcg/computation_graph.dtg.h | 12 +- .../include/pcg/computation_graph.struct.toml | 5 +- .../include/pcg/dataflow_graph/algorithms.h | 36 ---- .../pcg/dataflow_graph/dataflow_graph.h | 200 +++++++++--------- .../operator_added_result.dtg.h | 43 ---- .../operator_added_result.struct.toml | 22 -- lib/pcg/include/pcg/dataflow_input.dtg.h | 101 --------- .../include/pcg/dataflow_input.variant.toml | 21 -- ...idigraph.dtg.h => v1_dataflow_graph.dtg.h} | 28 ++- .../file_format/v1/graphs/v1_dataflow_graph.h | 15 ++ ...uct.toml => v1_dataflow_graph.struct.toml} | 6 +- .../v1/graphs/v1_jsonable_graph.dtg.h | 23 +- .../graphs/v1_labelled_dataflow_graph.dtg.h | 103 +++++++++ .../v1/graphs/v1_labelled_dataflow_graph.h | 33 +++ ...=> v1_labelled_dataflow_graph.struct.toml} | 12 +- .../file_format/v1/graphs/v1_multidigraph.h | 16 -- .../parallel_computation_graph.dtg.h | 12 +- .../parallel_computation_graph.struct.toml | 4 +- lib/pcg/include/pcg/tensor_guid_t.dtg.h | 8 +- lib/pcg/include/pcg/tensor_guid_t.struct.toml | 4 +- lib/pcg/src/pcg/computation_graph.cc | 25 +-- lib/pcg/src/pcg/computation_graph.dtg.cc | 6 +- lib/pcg/src/pcg/dataflow_graph/algorithms.cc | 1 - .../operator_added_result.dtg.cc | 57 ----- lib/pcg/src/pcg/dataflow_input.dtg.cc | 41 ---- .../v1/graphs/v1_dataflow_graph.cc | 28 +++ .../v1/graphs/v1_dataflow_graph.dtg.cc | 49 +++++ .../v1/graphs/v1_jsonable_graph.dtg.cc | 2 +- .../graphs/v1_labelled_dataflow_graph.dtg.cc | 10 + .../v1/graphs/v1_multidigraph.dtg.cc | 53 ----- .../parallel_computation_graph.dtg.cc | 7 +- lib/pcg/src/pcg/tensor_guid_t.dtg.cc | 6 +- .../tensor_attribute_value.dtg.h | 6 +- .../tensor_attribute_value.variant.toml | 4 +- .../tensor_attribute_value.dtg.cc | 2 +- lib/utils/include/utils/check_fmtable.h | 2 + lib/utils/include/utils/containers.decl.h | 3 + lib/utils/include/utils/containers.h | 7 + lib/utils/include/utils/fmt/map.h | 45 ++++ lib/utils/include/utils/fmt/set.h | 43 ++++ lib/utils/include/utils/graph.h | 18 +- lib/utils/include/utils/graph/algorithms.h | 5 +- lib/utils/include/utils/graph/construction.h | 30 --- .../utils/graph/dataflow_graph/algorithms.h | 16 ++ .../dataflow_graph/dataflow_edge_query.dtg.h | 3 +- .../dataflow_edge_query.struct.toml | 1 - .../dataflow_graph/dataflow_graph_view.h | 3 + .../dataflow_graph/dataflow_output_query.h | 13 ++ .../unordered_set_dataflow_graph.h | 37 ++++ .../dataflow_graph_output.dtg.h | 46 ++++ .../dataflow_graph_output.struct.toml | 16 ++ .../dataflow_output_edge.dtg.h | 49 +++++ .../dataflow_output_edge.struct.toml | 21 ++ .../downward_open_dataflow_edge.dtg.h | 116 ++++++++++ .../downward_open_dataflow_edge.variant.toml | 19 ++ .../downward_open_dataflow_graph.h | 27 +++ .../i_downward_open_dataflow_graph.h | 16 ++ .../i_downward_open_dataflow_graph_view.h | 12 ++ .../include/utils/graph/graph_split.dtg.h | 47 ++++ .../utils/graph/graph_split.struct.toml | 22 ++ .../i_labelled_dataflow_graph.h | 22 ++ .../i_labelled_dataflow_graph_view.h | 20 ++ .../labelled_dataflow_graph.h | 31 +++ .../labelled_dataflow_graph_view.h | 38 ++++ .../include/utils/graph/node/algorithms.h | 12 ++ .../graph/open_dataflow_graph/algorithms.h | 14 ++ .../dataflow_graph_input.dtg.h | 45 ++++ .../dataflow_graph_input.struct.toml | 12 ++ .../dataflow_input_edge.dtg.h | 49 +++++ .../dataflow_input_edge.struct.toml | 21 ++ .../dataflow_input_edge_query.dtg.h | 53 +++++ .../dataflow_input_edge_query.h | 13 ++ .../dataflow_input_edge_query.struct.toml | 26 +++ .../i_open_dataflow_graph.h | 22 ++ .../i_open_dataflow_graph_view.h | 23 ++ .../open_dataflow_edge.dtg.h | 113 ++++++++++ .../open_dataflow_graph/open_dataflow_edge.h | 12 ++ .../open_dataflow_edge.variant.toml | 19 ++ .../open_dataflow_edge_query.dtg.h | 50 +++++ .../open_dataflow_edge_query.h | 13 ++ .../open_dataflow_edge_query.struct.toml | 21 ++ .../open_dataflow_graph/open_dataflow_graph.h | 33 +++ .../open_dataflow_graph_view.h | 33 +++ .../open_dataflow_value.dtg.h | 113 ++++++++++ .../open_dataflow_value.variant.toml | 19 ++ .../unordered_set_open_dataflow_graph.h | 43 ++++ lib/utils/include/utils/graph/query_set.h | 1 + .../graph/serial_parallel/serialparallel.h | 3 + lib/utils/include/utils/hash-utils-core.h | 52 ----- lib/utils/include/utils/hash-utils.h | 89 ++++---- lib/utils/include/utils/hash/map.h | 20 ++ lib/utils/include/utils/hash/pair.h | 6 +- lib/utils/include/utils/hash/set.h | 20 ++ lib/utils/include/utils/hash/tuple.h | 4 +- lib/utils/include/utils/hash/unordered_map.h | 20 ++ lib/utils/include/utils/hash/unordered_set.h | 20 ++ lib/utils/include/utils/hash/vector.h | 4 +- lib/utils/include/utils/required_core.h | 2 +- lib/utils/src/fp16.cc | 3 +- lib/utils/src/utils/check_fmtable.cc | 1 + lib/utils/src/utils/fmt/expected.cc | 1 + lib/utils/src/utils/fmt/map.cc | 1 + lib/utils/src/utils/fmt/pair.cc | 1 + lib/utils/src/utils/fmt/set.cc | 1 + lib/utils/src/utils/fmt/unordered_map.cc | 1 + lib/utils/src/utils/fmt/unordered_set.cc | 1 + lib/utils/src/utils/fmt/vector.cc | 1 + lib/utils/src/utils/graph.cc | 1 + lib/utils/src/utils/graph/algorithms.cc | 5 +- .../utils/graph/dataflow_graph/algorithms.cc | 35 +++ .../dataflow_graph/dataflow_edge_query.dtg.cc | 2 +- .../graph/dataflow_graph/dataflow_graph.cc | 2 +- .../dataflow_graph/dataflow_output_query.cc | 19 ++ .../graph/dataflow_graph/i_dataflow_graph.cc | 1 + .../unordered_set_dataflow_graph.cc | 60 ++++++ .../dataflow_graph_output.dtg.cc | 57 +++++ .../dataflow_output_edge.dtg.cc | 63 ++++++ .../downward_open_dataflow_edge.dtg.cc | 80 +++++++ lib/utils/src/utils/graph/graph_split.dtg.cc | 52 +++++ .../i_labelled_dataflow_graph.cc | 1 + .../i_labelled_dataflow_graph_view.cc | 1 + lib/utils/src/utils/graph/node/algorithms.cc | 11 + .../graph/open_dataflow_graph/algorithms.cc | 31 +++ .../dataflow_graph_input.dtg.cc | 57 +++++ .../dataflow_input_edge.dtg.cc | 62 ++++++ .../dataflow_input_edge_query.cc | 20 ++ .../dataflow_input_edge_query.dtg.cc | 80 +++++++ .../i_open_dataflow_graph.cc | 1 + .../i_open_dataflow_graph_view.cc | 19 ++ .../open_dataflow_graph/open_dataflow_edge.cc | 13 ++ .../open_dataflow_edge.dtg.cc | 72 +++++++ .../open_dataflow_edge_query.cc | 21 ++ .../open_dataflow_edge_query.dtg.cc | 77 +++++++ .../open_dataflow_graph.cc | 22 ++ .../open_dataflow_graph_view.cc | 18 ++ .../open_dataflow_value.dtg.cc | 72 +++++++ .../unordered_set_open_dataflow_graph.cc | 79 +++++++ lib/utils/src/utils/graph/query_set.cc | 1 + .../graph/serial_parallel/serialparallel.cc | 37 +--- .../serialparallel_internal.cc | 49 ++++- .../serial_parallel/serialparallel_internal.h | 10 +- lib/utils/src/utils/graph/traversal.cc | 1 + .../utils/graph/undirected/undirected_edge.cc | 2 +- lib/utils/src/utils/graph/views/views.cc | 1 + lib/utils/src/utils/hash/pair.cc | 1 + lib/utils/src/utils/hash/set.cc | 1 + lib/utils/src/utils/hash/tuple.cc | 1 + lib/utils/src/utils/hash/unordered_map.cc | 1 + lib/utils/src/utils/hash/unordered_set.cc | 1 + lib/utils/src/utils/hash/vector.cc | 1 + lib/utils/test/CMakeLists.txt | 1 + .../unordered_set_dataflow_graph.cc | 73 +++++++ 159 files changed, 3178 insertions(+), 812 deletions(-) delete mode 100644 lib/pcg/include/pcg/dataflow_graph/algorithms.h delete mode 100644 lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h delete mode 100644 lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml delete mode 100644 lib/pcg/include/pcg/dataflow_input.dtg.h delete mode 100644 lib/pcg/include/pcg/dataflow_input.variant.toml rename lib/pcg/include/pcg/file_format/v1/graphs/{v1_multidigraph.dtg.h => v1_dataflow_graph.dtg.h} (51%) create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h rename lib/pcg/include/pcg/file_format/v1/graphs/{v1_multidigraph.struct.toml => v1_dataflow_graph.struct.toml} (83%) create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h rename lib/pcg/include/pcg/file_format/v1/graphs/{v1_jsonable_graph.struct.toml => v1_labelled_dataflow_graph.struct.toml} (60%) delete mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h delete mode 100644 lib/pcg/src/pcg/dataflow_graph/algorithms.cc delete mode 100644 lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc delete mode 100644 lib/pcg/src/pcg/dataflow_input.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc delete mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc create mode 100644 lib/utils/include/utils/fmt/map.h create mode 100644 lib/utils/include/utils/fmt/set.h delete mode 100644 lib/utils/include/utils/graph/construction.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/unordered_set_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.struct.toml create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.variant.toml create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/graph_split.dtg.h create mode 100644 lib/utils/include/utils/graph/graph_split.struct.toml create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/node/algorithms.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h delete mode 100644 lib/utils/include/utils/hash-utils-core.h create mode 100644 lib/utils/include/utils/hash/map.h create mode 100644 lib/utils/include/utils/hash/set.h create mode 100644 lib/utils/include/utils/hash/unordered_map.h create mode 100644 lib/utils/include/utils/hash/unordered_set.h create mode 100644 lib/utils/src/utils/check_fmtable.cc create mode 100644 lib/utils/src/utils/fmt/expected.cc create mode 100644 lib/utils/src/utils/fmt/map.cc create mode 100644 lib/utils/src/utils/fmt/pair.cc create mode 100644 lib/utils/src/utils/fmt/set.cc create mode 100644 lib/utils/src/utils/fmt/unordered_map.cc create mode 100644 lib/utils/src/utils/fmt/unordered_set.cc create mode 100644 lib/utils/src/utils/fmt/vector.cc create mode 100644 lib/utils/src/utils/graph.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.cc create mode 100644 lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/graph_split.dtg.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/node/algorithms.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/query_set.cc create mode 100644 lib/utils/src/utils/hash/pair.cc create mode 100644 lib/utils/src/utils/hash/set.cc create mode 100644 lib/utils/src/utils/hash/tuple.cc create mode 100644 lib/utils/src/utils/hash/unordered_map.cc create mode 100644 lib/utils/src/utils/hash/unordered_set.cc create mode 100644 lib/utils/src/utils/hash/vector.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc diff --git a/.proj.toml b/.proj.toml index 01ae36eddd..a31561632d 100644 --- a/.proj.toml +++ b/.proj.toml @@ -19,7 +19,7 @@ test_targets = [ "pcg-tests", # "substitutions-tests", # "compiler-tests", - "substitution-generator-tests", + # "substitution-generator-tests", ] [cmake_flags_extra] diff --git a/lib/local-execution/src/ops/conv_2d.cc b/lib/local-execution/src/ops/conv_2d.cc index bc3e66f60f..213a5e2173 100644 --- a/lib/local-execution/src/ops/conv_2d.cc +++ b/lib/local-execution/src/ops/conv_2d.cc @@ -1,7 +1,6 @@ #include "conv_2d.h" #include "kernels/conv_2d_kernels.h" #include "op-attrs/get_output_shapes.h" -#include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h index 0cfb1c2161..28bd8258a0 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml /* proj-data { - "generated_from": "c171c87db89b9ec9ea7d52a50c153054" + "generated_from": "902985a57f18e36925e35d90701329fa" } */ @@ -12,7 +12,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/datatype.h" +#include "op-attrs/datatype.dtg.h" #include "rapidcheck.h" #include #include @@ -21,7 +21,7 @@ namespace FlexFlow { struct CastAttrs { CastAttrs() = delete; - explicit CastAttrs(DataType const &dtype); + explicit CastAttrs(::FlexFlow::DataType const &dtype); bool operator==(CastAttrs const &) const; bool operator!=(CastAttrs const &) const; @@ -29,7 +29,7 @@ struct CastAttrs { bool operator>(CastAttrs const &) const; bool operator<=(CastAttrs const &) const; bool operator>=(CastAttrs const &) const; - DataType dtype; + ::FlexFlow::DataType dtype; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml index 6c12680ea1..287861888c 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml @@ -10,9 +10,9 @@ features = [ ] includes = [ - "op-attrs/datatype.h" + "op-attrs/datatype.dtg.h" ] [[fields]] name = "dtype" -type = "DataType" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc index 661aca32a9..8bcf704ba0 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml /* proj-data { - "generated_from": "c171c87db89b9ec9ea7d52a50c153054" + "generated_from": "902985a57f18e36925e35d90701329fa" } */ @@ -12,7 +12,7 @@ #include namespace FlexFlow { -CastAttrs::CastAttrs(DataType const &dtype) : dtype(dtype) {} +CastAttrs::CastAttrs(::FlexFlow::DataType const &dtype) : dtype(dtype) {} bool CastAttrs::operator==(CastAttrs const &other) const { return std::tie(this->dtype) == std::tie(other.dtype); } @@ -37,8 +37,8 @@ namespace std { size_t hash::operator()( ::FlexFlow::CastAttrs const &x) const { size_t result = 0; - result ^= std::hash{}(x.dtype) + 0x9e3779b9 + (result << 6) + - (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.dtype) + 0x9e3779b9 + + (result << 6) + (result >> 2); return result; } } // namespace std @@ -46,7 +46,8 @@ size_t hash::operator()( namespace nlohmann { ::FlexFlow::CastAttrs adl_serializer<::FlexFlow::CastAttrs>::from_json(json const &j) { - return ::FlexFlow::CastAttrs{j.at("dtype").template get()}; + return ::FlexFlow::CastAttrs{ + j.at("dtype").template get<::FlexFlow::DataType>()}; } void adl_serializer<::FlexFlow::CastAttrs>::to_json( json &j, ::FlexFlow::CastAttrs const &v) { @@ -57,7 +58,8 @@ void adl_serializer<::FlexFlow::CastAttrs>::to_json( namespace rc { Gen<::FlexFlow::CastAttrs> Arbitrary<::FlexFlow::CastAttrs>::arbitrary() { - return gen::construct<::FlexFlow::CastAttrs>(gen::arbitrary()); + return gen::construct<::FlexFlow::CastAttrs>( + gen::arbitrary<::FlexFlow::DataType>()); } } // namespace rc diff --git a/lib/op-attrs/test/src/ops/conv_2d.cc b/lib/op-attrs/test/src/ops/conv_2d.cc index 6f5028cfeb..de44918826 100644 --- a/lib/op-attrs/test/src/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/ops/conv_2d.cc @@ -1,5 +1,5 @@ #include "op-attrs/ops/conv_2d.h" -#include "doctest/doctest.h" +#include "test/utils/doctest.h" #include "utils/integer_conversions.h" TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc index a7724dba69..4906a7b59c 100644 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -1,10 +1,12 @@ -#include "doctest/doctest.h" +#include #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/json.h" #include #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { BatchNormAttrs correct = BatchNormAttrs{true}; diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h index 028d9ecfab..ad762e0fec 100644 --- a/lib/pcg/include/pcg/computation_graph.dtg.h +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -3,25 +3,27 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "bf8996bea2e022265a372d692c2db8ed" + "generated_from": "79ce58a361b164cad98643f961b9e266" } */ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H -#include "pcg/dataflow_graph/dataflow_graph.h" #include "pcg/layer_attrs.dtg.h" #include "pcg/tensor_attrs.dtg.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" namespace FlexFlow { struct ComputationGraph { ComputationGraph() = delete; explicit ComputationGraph( - ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> const &raw_graph); + ::FlexFlow::LabelledDataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const + &raw_graph); - ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs> + ::FlexFlow::LabelledDataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> raw_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml index 39c68b8e4f..a40935e235 100644 --- a/lib/pcg/include/pcg/computation_graph.struct.toml +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -5,9 +5,10 @@ features = [ ] includes = [ "pcg/layer_attrs.dtg.h", "pcg/tensor_attrs.dtg.h", - "pcg/dataflow_graph/dataflow_graph.h", + "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h", + ] [[fields]] name = "raw_graph" -type = "::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" +type = "::FlexFlow::LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/dataflow_graph/algorithms.h b/lib/pcg/include/pcg/dataflow_graph/algorithms.h deleted file mode 100644 index 413fecd92a..0000000000 --- a/lib/pcg/include/pcg/dataflow_graph/algorithms.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_ALGORITHMS_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_ALGORITHMS_H - -#include "pcg/dataflow_graph/dataflow_graph.h" - -namespace FlexFlow { - -template -std::vector - get_inputs(DataflowGraph const &g, Node const &n) { - std::vector> input_edges = - transform(as_vector(get_incoming_edges(g.get_raw_graph(), - std::unordered_set{n})), - [&](MultiDiEdge const &e) { - int idx = g.idx_for_port(e.dst_idx); - MultiDiOutput val = static_cast(e); - return std::make_pair(idx, val); - }); - - return vector_from_indexed_set(input_edges); -} - -template -std::vector - get_outputs(DataflowGraph const &g, Node const &n) { - return g.get_output_map().at(n); -} - -template -std::vector topological_ordering(DataflowGraph const &g) { - return get_topological_ordering(g.get_raw_graph()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h index c0650bc9b4..2ea5814f59 100644 --- a/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h +++ b/lib/pcg/include/pcg/dataflow_graph/dataflow_graph.h @@ -1,105 +1,105 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H -#include "pcg/dataflow_graph/operator_added_result.dtg.h" -#include "utils/containers/enumerate_vector.h" -#include "utils/graph.h" - -namespace FlexFlow { - -template -struct DataflowGraph { -public: - DataflowGraph() - : g(OutputLabelledMultiDiGraph::template create< - UnorderedOutputLabelledMultiDiGraph>()) {} - - OperatorAddedResult - add_operator(NodeLabel const &func, - std::vector const &inputs, - std::vector const &output_labels) { - Node node = this->g.add_node(func); - for (auto const &[idx, input] : enumerate_vector(inputs)) { - this->g.add_edge(MultiDiEdge{ - node, this->make_port_for_idx(idx), input.src, input.src_idx}); - } - - std::vector outputs; - for (auto const &[idx, label] : enumerate_vector(output_labels)) { - MultiDiOutput output = MultiDiOutput{node, this->make_port_for_idx(idx)}; - this->g.add_output(output, label); - outputs.push_back(output); - } - this->output_map[node] = outputs; - - return OperatorAddedResult{ - node, - outputs, - }; - } - - NodePort make_port_for_idx(int idx) { - if (!this->port_mapping.contains_l(idx)) { - this->port_mapping.equate(idx, this->g.add_node_port()); - } - return this->port_mapping.at_l(idx); - } - - NodePort port_for_idx(int idx) const { - return this->port_mapping.at_l(idx); - } - - int idx_for_port(NodePort const &p) const { - return this->port_mapping.at_r(p); - } - - OutputLabelledMultiDiGraphView const & - get_raw_graph() const { - return this->g; - } - - NodeLabel const &at(Node const &n) const { - return this->g.at(n); - } - - OutputLabel const &at(MultiDiOutput const &o) const { - return this->g.at(o); - } - - std::unordered_map> const & - get_output_map() const { - return this->output_map; - } - -private: - OutputLabelledMultiDiGraph g; - bidict port_mapping; - std::unordered_map> - output_map; // NOTE(@lockshaw): temporary workaround until not tracking - // outputs independent of edges in multidigraph is resolved -}; - -template -std::unordered_set - get_nodes(DataflowGraph const &g) { - return get_nodes(g.get_raw_graph()); -} - -template -std::vector - vector_from_indexed_set(std::vector> const &s) { - std::vector> result{s.size(), std::nullopt}; - for (auto const &[idx, value] : s) { - assert(idx < s.size() && idx >= 0); - assert(!result.at(idx).has_value()); - result.at(idx) = value; - } - return transform(result, [](std::optional const &v) { - assert(v.has_value()); - return v.value(); - }); -} - -} // namespace FlexFlow +// #include "pcg/dataflow_graph/operator_added_result.dtg.h" +// #include "utils/containers/enumerate_vector.h" +// #include "utils/graph.h" +// +// namespace FlexFlow { +// +// template +// struct DataflowGraph { +// public: +// DataflowGraph() +// : g(OutputLabelledMultiDiGraph::template create< +// UnorderedOutputLabelledMultiDiGraph>()) {} +// +// OperatorAddedResult +// add_operator(NodeLabel const &func, +// std::vector const &inputs, +// std::vector const &output_labels) { +// Node node = this->g.add_node(func); +// for (auto const &[idx, input] : enumerate_vector(inputs)) { +// this->g.add_edge(MultiDiEdge{ +// node, this->make_port_for_idx(idx), input.src, input.src_idx}); +// } +// +// std::vector outputs; +// for (auto const &[idx, label] : enumerate_vector(output_labels)) { +// MultiDiOutput output = MultiDiOutput{node, this->make_port_for_idx(idx)}; +// this->g.add_output(output, label); +// outputs.push_back(output); +// } +// this->output_map[node] = outputs; +// +// return OperatorAddedResult{ +// node, +// outputs, +// }; +// } +// +// NodePort make_port_for_idx(int idx) { +// if (!this->port_mapping.contains_l(idx)) { +// this->port_mapping.equate(idx, this->g.add_node_port()); +// } +// return this->port_mapping.at_l(idx); +// } +// +// NodePort port_for_idx(int idx) const { +// return this->port_mapping.at_l(idx); +// } +// +// int idx_for_port(NodePort const &p) const { +// return this->port_mapping.at_r(p); +// } +// +// OutputLabelledMultiDiGraphView const & +// get_raw_graph() const { +// return this->g; +// } +// +// NodeLabel const &at(Node const &n) const { +// return this->g.at(n); +// } +// +// OutputLabel const &at(MultiDiOutput const &o) const { +// return this->g.at(o); +// } +// +// std::unordered_map> const & +// get_output_map() const { +// return this->output_map; +// } +// +// private: +// OutputLabelledMultiDiGraph g; +// bidict port_mapping; +// std::unordered_map> +// output_map; // NOTE(@lockshaw): temporary workaround until not tracking +// // outputs independent of edges in multidigraph is resolved +// }; +// +// template +// std::unordered_set +// get_nodes(DataflowGraph const &g) { +// return get_nodes(g.get_raw_graph()); +// } +// +// template +// std::vector +// vector_from_indexed_set(std::vector> const &s) { +// std::vector> result{s.size(), std::nullopt}; +// for (auto const &[idx, value] : s) { +// assert(idx < s.size() && idx >= 0); +// assert(!result.at(idx).has_value()); +// result.at(idx) = value; +// } +// return transform(result, [](std::optional const &v) { +// assert(v.has_value()); +// return v.value(); +// }); +// } +// +// } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h deleted file mode 100644 index 9e9803b8a0..0000000000 --- a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.dtg.h +++ /dev/null @@ -1,43 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml -/* proj-data -{ - "generated_from": "62224733c501773b41f1fc63a8677949" -} -*/ - -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_OPERATOR_ADDED_RESULT_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_OPERATOR_ADDED_RESULT_DTG_H - -#include "fmt/format.h" -#include "utils/fmt/vector.h" -#include "utils/graph.h" -#include -#include -#include - -namespace FlexFlow { -struct OperatorAddedResult { - OperatorAddedResult() = delete; - explicit OperatorAddedResult( - ::FlexFlow::Node const &node, - std::vector<::FlexFlow::MultiDiOutput> const &outputs); - - bool operator==(OperatorAddedResult const &) const; - bool operator!=(OperatorAddedResult const &) const; - bool operator<(OperatorAddedResult const &) const; - bool operator>(OperatorAddedResult const &) const; - bool operator<=(OperatorAddedResult const &) const; - bool operator>=(OperatorAddedResult const &) const; - ::FlexFlow::Node node; - std::vector<::FlexFlow::MultiDiOutput> outputs; -}; -} // namespace FlexFlow - -namespace FlexFlow { -std::string format_as(OperatorAddedResult const &); -std::ostream &operator<<(std::ostream &, OperatorAddedResult const &); -} // namespace FlexFlow - -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_OPERATOR_ADDED_RESULT_DTG_H diff --git a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml deleted file mode 100644 index 3c9cb87e85..0000000000 --- a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAddedResult" - -features = [ - "eq", - "ord", - "fmt", -] - -includes = [ - "", - "utils/graph.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "outputs" -type = "std::vector<::FlexFlow::MultiDiOutput>" diff --git a/lib/pcg/include/pcg/dataflow_input.dtg.h b/lib/pcg/include/pcg/dataflow_input.dtg.h deleted file mode 100644 index c698c75c25..0000000000 --- a/lib/pcg/include/pcg/dataflow_input.dtg.h +++ /dev/null @@ -1,101 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/dataflow_input.variant.toml -/* proj-data -{ - "generated_from": "d6a7f4570e36e257383529e9bf9390ec" -} -*/ - -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H - -#include "utils/graph/multidiedge.h" -#include -#include -#include -#include - -namespace FlexFlow { -struct DataflowInput { - DataflowInput() = delete; - explicit DataflowInput(::FlexFlow::MultiDiOutput const &); - explicit DataflowInput(int const &); - template - static constexpr bool IsPartOfDataflowInput_v = - std::is_same_v || std::is_same_v; - template - ReturnType visit(Visitor &&v) const { - switch (this->index()) { - case 0: { - ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); - return result; - } - case 1: { - ReturnType result = v(this->get()); - return result; - } - default: { - throw std::runtime_error(fmt::format( - "Unknown index {} for type DataflowInput", this->index())); - } - } - } - template - ReturnType visit(Visitor &&v) { - switch (this->index()) { - case 0: { - ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); - return result; - } - case 1: { - ReturnType result = v(this->get()); - return result; - } - default: { - throw std::runtime_error(fmt::format( - "Unknown index {} for type DataflowInput", this->index())); - } - } - } - template - bool has() const { - static_assert(IsPartOfDataflowInput_v, - "DataflowInput::has() expected one of " - "[::FlexFlow::MultiDiOutput, int], received T"); - return std::holds_alternative(this->raw_variant); - } - template - T const &get() const { - static_assert(IsPartOfDataflowInput_v, - "DataflowInput::get() expected one of " - "[::FlexFlow::MultiDiOutput, int], received T"); - return std::get(this->raw_variant); - } - template - T &get() { - static_assert(IsPartOfDataflowInput_v, - "DataflowInput::get() expected one of " - "[::FlexFlow::MultiDiOutput, int], received T"); - return std::get(this->raw_variant); - } - size_t index() const { - return this->raw_variant.index(); - } - bool operator==(DataflowInput const &) const; - bool operator!=(DataflowInput const &) const; - bool operator<(DataflowInput const &) const; - bool operator>(DataflowInput const &) const; - bool operator<=(DataflowInput const &) const; - bool operator>=(DataflowInput const &) const; - std::variant<::FlexFlow::MultiDiOutput, int> raw_variant; -}; -} // namespace FlexFlow -namespace std { -template <> -struct hash<::FlexFlow::DataflowInput> { - size_t operator()(::FlexFlow::DataflowInput const &) const; -}; -} // namespace std - -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/dataflow_input.variant.toml b/lib/pcg/include/pcg/dataflow_input.variant.toml deleted file mode 100644 index ac7c3ae5d7..0000000000 --- a/lib/pcg/include/pcg/dataflow_input.variant.toml +++ /dev/null @@ -1,21 +0,0 @@ -namespace = "FlexFlow" -name = "DataflowInput" -features = [ - "eq", - "ord", - "hash", - # "json", - # "fmt", -] - -includes = [ - "utils/graph/multidiedge.h" , -] - -[[values]] -type = "::FlexFlow::MultiDiOutput" -key = "internal" - -[[values]] -type = "int" -key = "external" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h similarity index 51% rename from lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h rename to lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h index 5b214d2b58..7b52bea2fd 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +// lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml /* proj-data { - "generated_from": "582054edb983c3cc31d9273ce29421eb" + "generated_from": "6d7fce9dbb6976f4365ecb0be547955c" } */ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_DATAFLOW_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_DATAFLOW_GRAPH_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -20,30 +20,28 @@ #include namespace FlexFlow { -struct V1MultiDiGraph { - V1MultiDiGraph() = delete; - explicit V1MultiDiGraph( +struct V1DataflowGraph { + V1DataflowGraph() = delete; + explicit V1DataflowGraph( std::vector const &nodes, - std::vector const &ports, std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); std::vector nodes; - std::vector ports; std::unordered_set<::FlexFlow::V1GraphEdge> edges; }; } // namespace FlexFlow namespace nlohmann { template <> -struct adl_serializer<::FlexFlow::V1MultiDiGraph> { - static ::FlexFlow::V1MultiDiGraph from_json(json const &); - static void to_json(json &, ::FlexFlow::V1MultiDiGraph const &); +struct adl_serializer<::FlexFlow::V1DataflowGraph> { + static ::FlexFlow::V1DataflowGraph from_json(json const &); + static void to_json(json &, ::FlexFlow::V1DataflowGraph const &); }; } // namespace nlohmann namespace FlexFlow { -std::string format_as(V1MultiDiGraph const &); -std::ostream &operator<<(std::ostream &, V1MultiDiGraph const &); +std::string format_as(V1DataflowGraph const &); +std::ostream &operator<<(std::ostream &, V1DataflowGraph const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_DATAFLOW_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h new file mode 100644 index 0000000000..0e547e7688 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H + +#include "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +V1DataflowGraph to_v1(DataflowGraphView const &); +V1DataflowGraph to_v1(DataflowGraphView const &, + std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml similarity index 83% rename from lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml rename to lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index 20ca69eed4..dc9dc96f29 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "V1MultiDiGraph" +name = "V1DataflowGraph" features = [ # "eq", # "ord", @@ -21,10 +21,6 @@ includes = [ name = "nodes" type = "std::vector" -[[fields]] -name = "ports" -type = "std::vector" - [[fields]] name = "edges" type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h index c6ffb55e3b..839741e86f 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml /* proj-data { - "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" + "generated_from": "ac98d063410ebe1c14f58ea8e17c272e" } */ @@ -12,8 +12,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h" #include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" -#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" #include #include #include @@ -24,14 +24,12 @@ struct V1JsonableGraph { V1JsonableGraph() = delete; explicit V1JsonableGraph( std::unordered_map const &node_labels, - std::unordered_map const &outputs, std::unordered_map const &output_labels, - ::FlexFlow::V1MultiDiGraph const &graph); + ::FlexFlow::V1DataflowGraph const &graph); std::unordered_map node_labels; - std::unordered_map outputs; std::unordered_map output_labels; - ::FlexFlow::V1MultiDiGraph graph; + ::FlexFlow::V1DataflowGraph graph; }; } // namespace FlexFlow @@ -56,11 +54,9 @@ namespace FlexFlow { template V1JsonableGraph::V1JsonableGraph( std::unordered_map const &node_labels, - std::unordered_map const &outputs, std::unordered_map const &output_labels, - ::FlexFlow::V1MultiDiGraph const &graph) - : node_labels(node_labels), outputs(outputs), output_labels(output_labels), - graph(graph) {} + ::FlexFlow::V1DataflowGraph const &graph) + : node_labels(node_labels), output_labels(output_labels), graph(graph) {} } // namespace FlexFlow namespace nlohmann { @@ -70,18 +66,14 @@ ::FlexFlow::V1JsonableGraph json const &j) { return ::FlexFlow::V1JsonableGraph{ j.at("node_labels").template get>(), - j.at("outputs") - .template get< - std::unordered_map>(), j.at("output_labels").template get>(), - j.at("graph").template get<::FlexFlow::V1MultiDiGraph>()}; + j.at("graph").template get<::FlexFlow::V1DataflowGraph>()}; } template void adl_serializer<::FlexFlow::V1JsonableGraph>::to_json( json &j, ::FlexFlow::V1JsonableGraph const &v) { j["__type"] = "V1JsonableGraph"; j["node_labels"] = v.node_labels; - j["outputs"] = v.outputs; j["output_labels"] = v.output_labels; j["graph"] = v.graph; } @@ -93,7 +85,6 @@ std::string format_as(V1JsonableGraph const &x) { std::ostringstream oss; oss << ""; diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h new file mode 100644 index 0000000000..4f24e1f3b4 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h @@ -0,0 +1,103 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +/* proj-data +{ + "generated_from": "5b6ac94ce5ca0fe62b2309c7a87b583a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h" +#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" +#include +#include +#include + +namespace FlexFlow { +template +struct V1LabelledDataflowGraph { + V1LabelledDataflowGraph() = delete; + explicit V1LabelledDataflowGraph( + std::unordered_map const &node_labels, + std::unordered_map const &output_labels, + ::FlexFlow::V1DataflowGraph const &graph); + + std::unordered_map node_labels; + std::unordered_map output_labels; + ::FlexFlow::V1DataflowGraph graph; +}; +} // namespace FlexFlow + +namespace nlohmann { +template +struct adl_serializer<::FlexFlow::V1LabelledDataflowGraph> { + static ::FlexFlow::V1LabelledDataflowGraph + from_json(json const &); + static void + to_json(json &, + ::FlexFlow::V1LabelledDataflowGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +template +std::string format_as(V1LabelledDataflowGraph const &); +template +std::ostream &operator<<(std::ostream &, + V1LabelledDataflowGraph const &); +} // namespace FlexFlow + +namespace FlexFlow { +template +V1LabelledDataflowGraph::V1LabelledDataflowGraph( + std::unordered_map const &node_labels, + std::unordered_map const &output_labels, + ::FlexFlow::V1DataflowGraph const &graph) + : node_labels(node_labels), output_labels(output_labels), graph(graph) {} +} // namespace FlexFlow + +namespace nlohmann { +template +::FlexFlow::V1LabelledDataflowGraph + adl_serializer<::FlexFlow::V1LabelledDataflowGraph>:: + from_json(json const &j) { + return ::FlexFlow::V1LabelledDataflowGraph{ + j.at("node_labels").template get>(), + j.at("output_labels").template get>(), + j.at("graph").template get<::FlexFlow::V1DataflowGraph>()}; +} +template +void adl_serializer<::FlexFlow::V1LabelledDataflowGraph>:: + to_json(json &j, + ::FlexFlow::V1LabelledDataflowGraph const &v) { + j["__type"] = "V1LabelledDataflowGraph"; + j["node_labels"] = v.node_labels; + j["output_labels"] = v.output_labels; + j["graph"] = v.graph; +} +} // namespace nlohmann + +namespace FlexFlow { +template +std::string format_as(V1LabelledDataflowGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +template +std::ostream &operator<<(std::ostream &s, + V1LabelledDataflowGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h new file mode 100644 index 0000000000..823989a89a --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_LABELLED_DATAFLOW_GRAPH_H + +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" + +namespace FlexFlow { + +template +V1LabelledDataflowGraph + to_v1(LabelledDataflowGraphView const &g) { + + bidict nodes = enumerate(get_nodes(g)); + + V1DataflowGraph unlabelled = to_v1(g, nodes.reversed()); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + + std::unordered_map outputs = + map_values(nodes, [&](MultiDiOutput const &o) { + return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; + }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); + + return V1JsonableGraph{ + node_labels, outputs, output_labels, unlabelled}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml similarity index 60% rename from lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml rename to lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index ad9ba21c60..47263a80bc 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "V1JsonableGraph" +name = "V1LabelledDataflowGraph" features = [ # "eq", # "ord", @@ -16,7 +16,7 @@ template_params = [ includes = [ "", - "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h", + "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h", "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", ] @@ -24,15 +24,11 @@ includes = [ name = "node_labels" type = "std::unordered_map" -[[fields]] -name = "outputs" -type = "std::unordered_map" - [[fields]] name = "output_labels" -type = "std::unordered_map" +type = "std::unordered_map>" [[fields]] name = "graph" -type = "::FlexFlow::V1MultiDiGraph" +type = "::FlexFlow::V1DataflowGraph" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h deleted file mode 100644 index 49ff850a29..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H - -#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" -#include "utils/graph.h" - -namespace FlexFlow { - -V1MultiDiGraph to_v1(MultiDiGraphView const &); -V1MultiDiGraph to_v1(MultiDiGraphView const &, - std::unordered_map const &, - std::unordered_map const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h index a6f9f9455e..df66722664 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.dtg.h @@ -3,27 +3,27 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "1339be6e86e9818c36d6ecf5475e2d4b" + "generated_from": "c9b193f4f31976528951a507119483a3" } */ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_DTG_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_DTG_H -#include "pcg/dataflow_graph/dataflow_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" namespace FlexFlow { struct ParallelComputationGraph { ParallelComputationGraph() = delete; explicit ParallelComputationGraph( - ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> const + ::FlexFlow::LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph); - ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> + ::FlexFlow::LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> raw_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml index 759a8424d5..c97333701c 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml @@ -3,11 +3,11 @@ name = "ParallelComputationGraph" features = [ ] includes = [ - "pcg/dataflow_graph/dataflow_graph.h", + "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h", "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" +type = "::FlexFlow::LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.h b/lib/pcg/include/pcg/tensor_guid_t.dtg.h index 3026c2169e..f85987cdda 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.dtg.h +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/tensor_guid_t.struct.toml /* proj-data { - "generated_from": "1e3914b97a465f1752ce510614145b37" + "generated_from": "1a659fe73845127890c449f505aef094" } */ @@ -11,7 +11,7 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H #include "fmt/format.h" -#include "utils/graph.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" #include #include #include @@ -19,7 +19,7 @@ namespace FlexFlow { struct tensor_guid_t { tensor_guid_t() = delete; - explicit tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + explicit tensor_guid_t(::FlexFlow::DataflowOutput const &raw_graph_output); bool operator==(tensor_guid_t const &) const; bool operator!=(tensor_guid_t const &) const; @@ -27,7 +27,7 @@ struct tensor_guid_t { bool operator>(tensor_guid_t const &) const; bool operator<=(tensor_guid_t const &) const; bool operator>=(tensor_guid_t const &) const; - ::FlexFlow::MultiDiOutput raw_graph_output; + ::FlexFlow::DataflowOutput raw_graph_output; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/tensor_guid_t.struct.toml b/lib/pcg/include/pcg/tensor_guid_t.struct.toml index 795c0166eb..0f710c81e6 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.struct.toml +++ b/lib/pcg/include/pcg/tensor_guid_t.struct.toml @@ -8,9 +8,9 @@ features = [ ] includes = [ - "utils/graph.h" + "utils/graph/dataflow_graph/dataflow_output.dtg.h" ] [[fields]] name = "raw_graph_output" -type = "::FlexFlow::MultiDiOutput" +type = "::FlexFlow::DataflowOutput" diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 21037bbe45..4e0ce7d0a0 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,10 +1,12 @@ #include "pcg/computation_graph.h" #include "utils/containers.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { ComputationGraph make_empty_computation_graph() { - return ComputationGraph{DataflowGraph{}}; + return ComputationGraph{LabelledDataflowGraph{}}; } std::unordered_set get_layers(ComputationGraph const &cg) { @@ -19,7 +21,7 @@ TensorAttrs get_tensor_attrs(ComputationGraph const &cg, std::vector topological_ordering(ComputationGraph const &cg) { std::vector layers = - get_topological_ordering(cg.raw_graph.get_raw_graph()); + get_topological_ordering(cg.raw_graph); return transform( layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } @@ -27,30 +29,21 @@ std::vector topological_ordering(ComputationGraph const &cg) { std::vector reverse_topological_ordering(ComputationGraph const &cg) { std::vector layers = reversed>( - get_topological_ordering(cg.raw_graph.get_raw_graph())); + get_topological_ordering(cg.raw_graph)); return transform( layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } -static std::vector - sort_edge_set(std::unordered_set const &edges) { - return transform( - sorted_by(edges, compare_by([](MultiDiEdge const &e) { - return e.src_idx; - })), - [&](MultiDiEdge const &e) -> tensor_guid_t { return tensor_guid_t{e}; }); -} - std::vector get_outgoing_tensors(ComputationGraph const &cg, layer_guid_t n) { - return sort_edge_set( - get_outgoing_edges(cg.raw_graph.get_raw_graph(), n.raw_node)); + return transform(get_outputs(cg.raw_graph, n.raw_node), + [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n) { - return sort_edge_set( - get_incoming_edges(cg.raw_graph.get_raw_graph(), n.raw_node)); + return transform(get_inputs(cg.raw_graph, n.raw_node), + [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) { diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc index 327cdc964a..3ba8ff890d 100644 --- a/lib/pcg/src/pcg/computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "bf8996bea2e022265a372d692c2db8ed" + "generated_from": "79ce58a361b164cad98643f961b9e266" } */ @@ -11,7 +11,7 @@ namespace FlexFlow { ComputationGraph::ComputationGraph( - ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> const &raw_graph) + ::FlexFlow::LabelledDataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/dataflow_graph/algorithms.cc b/lib/pcg/src/pcg/dataflow_graph/algorithms.cc deleted file mode 100644 index 3ef04c95a3..0000000000 --- a/lib/pcg/src/pcg/dataflow_graph/algorithms.cc +++ /dev/null @@ -1 +0,0 @@ -#include "pcg/dataflow_graph/algorithms.h" diff --git a/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc b/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc deleted file mode 100644 index 6cb8f8fa83..0000000000 --- a/lib/pcg/src/pcg/dataflow_graph/operator_added_result.dtg.cc +++ /dev/null @@ -1,57 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml -/* proj-data -{ - "generated_from": "62224733c501773b41f1fc63a8677949" -} -*/ - -#include "pcg/dataflow_graph/operator_added_result.dtg.h" - -#include - -namespace FlexFlow { -OperatorAddedResult::OperatorAddedResult( - ::FlexFlow::Node const &node, - std::vector<::FlexFlow::MultiDiOutput> const &outputs) - : node(node), outputs(outputs) {} -bool OperatorAddedResult::operator==(OperatorAddedResult const &other) const { - return std::tie(this->node, this->outputs) == - std::tie(other.node, other.outputs); -} -bool OperatorAddedResult::operator!=(OperatorAddedResult const &other) const { - return std::tie(this->node, this->outputs) != - std::tie(other.node, other.outputs); -} -bool OperatorAddedResult::operator<(OperatorAddedResult const &other) const { - return std::tie(this->node, this->outputs) < - std::tie(other.node, other.outputs); -} -bool OperatorAddedResult::operator>(OperatorAddedResult const &other) const { - return std::tie(this->node, this->outputs) > - std::tie(other.node, other.outputs); -} -bool OperatorAddedResult::operator<=(OperatorAddedResult const &other) const { - return std::tie(this->node, this->outputs) <= - std::tie(other.node, other.outputs); -} -bool OperatorAddedResult::operator>=(OperatorAddedResult const &other) const { - return std::tie(this->node, this->outputs) >= - std::tie(other.node, other.outputs); -} -} // namespace FlexFlow - -namespace FlexFlow { -std::string format_as(OperatorAddedResult const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, OperatorAddedResult const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/dataflow_input.dtg.cc b/lib/pcg/src/pcg/dataflow_input.dtg.cc deleted file mode 100644 index bd5a43dfa9..0000000000 --- a/lib/pcg/src/pcg/dataflow_input.dtg.cc +++ /dev/null @@ -1,41 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/dataflow_input.variant.toml -/* proj-data -{ - "generated_from": "d6a7f4570e36e257383529e9bf9390ec" -} -*/ - -#include "pcg/dataflow_input.dtg.h" - -namespace FlexFlow { -DataflowInput::DataflowInput(::FlexFlow::MultiDiOutput const &v) - : raw_variant(v) {} -DataflowInput::DataflowInput(int const &v) : raw_variant(v) {} -bool DataflowInput::operator==(DataflowInput const &other) const { - return this->raw_variant == other.raw_variant; -} -bool DataflowInput::operator!=(DataflowInput const &other) const { - return this->raw_variant != other.raw_variant; -} -bool DataflowInput::operator<(DataflowInput const &other) const { - return this->raw_variant < other.raw_variant; -} -bool DataflowInput::operator>(DataflowInput const &other) const { - return this->raw_variant > other.raw_variant; -} -bool DataflowInput::operator<=(DataflowInput const &other) const { - return this->raw_variant <= other.raw_variant; -} -bool DataflowInput::operator>=(DataflowInput const &other) const { - return this->raw_variant >= other.raw_variant; -} -} // namespace FlexFlow -namespace std { -size_t hash<::FlexFlow::DataflowInput>::operator()( - ::FlexFlow::DataflowInput const &x) const { - return std::hash>{}( - x.raw_variant); -} -} // namespace std diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc new file mode 100644 index 0000000000..211392cbaa --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc @@ -0,0 +1,28 @@ +#include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms.h" + +namespace FlexFlow { + +V1DataflowGraph to_v1(DataflowGraphView const &g) { + return to_v1(g, enumerate(get_nodes(g)).reversed()); +} + +V1DataflowGraph to_v1(DataflowGraphView const &g, + bidict const &nodes) { + std::unordered_set edges; + for (DataflowEdge const &e : get_edges(g)) { + edges.insert(V1GraphEdge{nodes.at_l(e.src.node), + e.src.idx, + nodes.at_l(e.dst.node), + e.dst.idx}); + } + + return V1DataflowGraph{ + sorted(values(nodes)), + edges, + }; +} + + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.cc new file mode 100644 index 0000000000..6d5134a2fe --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.cc @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +/* proj-data +{ + "generated_from": "6d7fce9dbb6976f4365ecb0be547955c" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h" + +#include + +namespace FlexFlow { +V1DataflowGraph::V1DataflowGraph( + std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +::FlexFlow::V1DataflowGraph + adl_serializer<::FlexFlow::V1DataflowGraph>::from_json(json const &j) { + return ::FlexFlow::V1DataflowGraph{ + j.at("nodes").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer<::FlexFlow::V1DataflowGraph>::to_json( + json &j, ::FlexFlow::V1DataflowGraph const &v) { + j["__type"] = "V1DataflowGraph"; + j["nodes"] = v.nodes; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1DataflowGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1DataflowGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc index 7f7e670782..6098235269 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml /* proj-data { - "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" + "generated_from": "ac98d063410ebe1c14f58ea8e17c272e" } */ diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc new file mode 100644 index 0000000000..89b69e024e --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc @@ -0,0 +1,10 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +/* proj-data +{ + "generated_from": "5b6ac94ce5ca0fe62b2309c7a87b583a" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc deleted file mode 100644 index 626cca4f95..0000000000 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc +++ /dev/null @@ -1,53 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml -/* proj-data -{ - "generated_from": "582054edb983c3cc31d9273ce29421eb" -} -*/ - -#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" - -#include - -namespace FlexFlow { -V1MultiDiGraph::V1MultiDiGraph( - std::vector const &nodes, - std::vector const &ports, - std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) - : nodes(nodes), ports(ports), edges(edges) {} -} // namespace FlexFlow - -namespace nlohmann { -::FlexFlow::V1MultiDiGraph - adl_serializer<::FlexFlow::V1MultiDiGraph>::from_json(json const &j) { - return ::FlexFlow::V1MultiDiGraph{ - j.at("nodes").template get>(), - j.at("ports").template get>(), - j.at("edges") - .template get>()}; -} -void adl_serializer<::FlexFlow::V1MultiDiGraph>::to_json( - json &j, ::FlexFlow::V1MultiDiGraph const &v) { - j["__type"] = "V1MultiDiGraph"; - j["nodes"] = v.nodes; - j["ports"] = v.ports; - j["edges"] = v.edges; -} -} // namespace nlohmann - -namespace FlexFlow { -std::string format_as(V1MultiDiGraph const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, V1MultiDiGraph const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc index 6a1fb33193..11e9a2006f 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "1339be6e86e9818c36d6ecf5475e2d4b" + "generated_from": "c9b193f4f31976528951a507119483a3" } */ @@ -11,7 +11,8 @@ namespace FlexFlow { ParallelComputationGraph::ParallelComputationGraph( - ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + ::FlexFlow::LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const + &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc index 096c9b4374..bb55b8af11 100644 --- a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/tensor_guid_t.struct.toml /* proj-data { - "generated_from": "1e3914b97a465f1752ce510614145b37" + "generated_from": "1a659fe73845127890c449f505aef094" } */ @@ -12,7 +12,7 @@ #include namespace FlexFlow { -tensor_guid_t::tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output) +tensor_guid_t::tensor_guid_t(::FlexFlow::DataflowOutput const &raw_graph_output) : raw_graph_output(raw_graph_output) {} bool tensor_guid_t::operator==(tensor_guid_t const &other) const { return std::tie(this->raw_graph_output) == std::tie(other.raw_graph_output); @@ -38,7 +38,7 @@ namespace std { size_t hash::operator()( ::FlexFlow::tensor_guid_t const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + + result ^= std::hash<::FlexFlow::DataflowOutput>{}(x.raw_graph_output) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h index 948a7abae6..3661c5e6b2 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml /* proj-data { - "generated_from": "d80cf2e618d64df284c2647430a12a86" + "generated_from": "c220bfd8b5a57e4941e4739c84d20054" } */ @@ -12,8 +12,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "utils/fmt.h" -#include "utils/hash-utils-core.h" +#include "utils/fmt/vector.h" +#include "utils/hash/vector.h" #include #include #include diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml index 91313f159b..46b703a7fc 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml @@ -10,8 +10,8 @@ features = [ includes = [ "", - "utils/hash-utils-core.h", - "utils/fmt.h", + "utils/hash/vector.h", + "utils/fmt/vector.h", ] [[values]] diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc index 27a82c4ffe..928de8201d 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml /* proj-data { - "generated_from": "d80cf2e618d64df284c2647430a12a86" + "generated_from": "c220bfd8b5a57e4941e4739c84d20054" } */ diff --git a/lib/utils/include/utils/check_fmtable.h b/lib/utils/include/utils/check_fmtable.h index 3b4e55c459..3c0d1368b1 100644 --- a/lib/utils/include/utils/check_fmtable.h +++ b/lib/utils/include/utils/check_fmtable.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H +#include + #define CHECK_FMTABLE(...) \ static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ #__VA_ARGS__ " must be fmtable"); diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index b02c95bf77..46979f4945 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -253,6 +253,9 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, template void inplace_sorted_by(C &c, F const &f); +template +std::vector sorted(C const &c); + template std::vector sorted_by(C const &c, F const &f); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index fbaf572df1..60df0caca3 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -595,6 +595,13 @@ void inplace_sorted_by(C &c, F const &f) { std::sort(c.begin(), c.end(), custom_comparator); } +template +std::vector sorted(C const &c) { + std::vector result(c.begin(), c.end()); + inplace_sorted_by(result, [](Elem const &l, Elem const &r) { return l < r; }); + return result; +} + template std::vector sorted_by(C const &c, F const &f) { std::vector result(c.begin(), c.end()); diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h new file mode 100644 index 0000000000..1744130134 --- /dev/null +++ b/lib/utils/include/utils/fmt/map.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MAP_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include + +namespace fmt { + +template +struct formatter< + ::std::map, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::map const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + /* CHECK_FMTABLE(K); */ + /* CHECK_FMTABLE(V); */ + + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return + * fmt::to_string(p); }); */ + std::string result = ""; + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::map const &m) { + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h new file mode 100644 index 0000000000..bc50757400 --- /dev/null +++ b/lib/utils/include/utils/fmt/set.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_SET_H + +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::set, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::set const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = + ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + return formatter::format("{" + result + "}", ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::set const &x) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph.h b/lib/utils/include/utils/graph.h index 80ef621c88..91f0ea6eb5 100644 --- a/lib/utils/include/utils/graph.h +++ b/lib/utils/include/utils/graph.h @@ -1,18 +1,14 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_H #define _FLEXFLOW_UTILS_GRAPH_H -#include "graph/adjacency_digraph.h" -#include "graph/adjacency_multidigraph.h" +#include "graph/digraph/adjacency_digraph.h" #include "graph/algorithms.h" -#include "graph/construction.h" -#include "graph/digraph.h" -#include "graph/labelled_graphs.h" -#include "graph/multidigraph.h" -#include "graph/node.h" -#include "graph/open_graphs.h" -#include "graph/serialparallel.h" +#include "graph/digraph/digraph.h" +// #include "graph/labelled_graphs.h" +#include "graph/node/node.dtg.h" +// #include "graph/open_graphs.h" +#include "graph/serial_parallel/serialparallel.h" #include "graph/traversal.h" -#include "graph/undirected.h" -#include "graph/views.h" +#include "graph/undirected/undirected_graph.h" #endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 6f64c3459d..15c10f68e3 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -8,6 +8,7 @@ // #include "utils/graph/upward_open_multidigraph/upward_open_multidigraph_view.h" // #include "utils/graph/downward_open_multidigraph/downward_open_multidigraph_view.h" #include "utils/dot_file.h" +#include "utils/graph/graph_split.dtg.h" namespace FlexFlow { @@ -17,7 +18,6 @@ std::vector add_nodes(DiGraph &, int); // std::vector add_nodes(MultiDiGraph &, int); // std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); -std::unordered_set get_nodes(GraphView const &); // std::unordered_set get_nodes(OpenMultiDiEdge const &); std::unordered_set query_nodes(GraphView const &, @@ -209,9 +209,6 @@ std::unordered_set> std::unordered_set get_transitive_reduction_delta(DiGraphView const &); -using GraphSplit = - std::pair, std::unordered_set>; - // std::pair split_edge(MultiDiEdge const &e); // MultiDiEdge unsplit_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &); diff --git a/lib/utils/include/utils/graph/construction.h b/lib/utils/include/utils/graph/construction.h deleted file mode 100644 index 655afe9c2c..0000000000 --- a/lib/utils/include/utils/graph/construction.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_CONSTRUCTION_H -#define _FLEXFLOW_UTILS_GRAPH_CONSTRUCTION_H - -#include "multidigraph.h" -#include "node.h" -#include -#include -#include - -namespace FlexFlow { - -template -G make_multidigraph(std::size_t num_nodes, - std::function( - std::vector const &)> const &edges) { - G g; - std::vector nodes; - for (std::size_t i = 0; i < num_nodes; i++) { - nodes.push_back(g.add_node()); - } - - for (MultiDiEdge const &e : edges(nodes)) { - g.add_edge(e); - } - return g; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h new file mode 100644 index 0000000000..247460f4df --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_H + +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include + +namespace FlexFlow { + +std::unordered_set get_edges(DataflowGraphView const &); +std::vector get_incoming_edges(DataflowGraphView const &, Node const &); +std::vector get_inputs(DataflowGraphView const &, Node const &); +std::vector get_outputs(DataflowGraphView const &, Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h index aa4f20e575..5057cd8dd8 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.dtg.h @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml /* proj-data { - "generated_from": "111e640382a80b659bc33dd86a416ded" + "generated_from": "e88f46c93e5d1c025271ad70a3bcd105" } */ @@ -11,7 +11,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_DTG_H #include "fmt/format.h" -#include "utils/fmt/unordered_set.h" #include "utils/graph/node/node.dtg.h" #include "utils/graph/query_set.h" #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml index 6957a87863..0b0c5a41d8 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml @@ -10,7 +10,6 @@ features = [ includes = [ "utils/graph/query_set.h", "utils/graph/node/node.dtg.h", - "utils/fmt/unordered_set.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h index dd07355e48..54708cff31 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h @@ -22,6 +22,9 @@ struct DataflowGraphView : virtual DiGraphView { create(Args &&...args) { return DataflowGraphView(make_cow_ptr(std::forward(args)...)); } +protected: + using DiGraphView::DiGraphView; + private: IDataflowGraphView const &get_interface() const; }; diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h new file mode 100644 index 0000000000..f373a06dae --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_H + +#include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" + +namespace FlexFlow { + +DataflowOutputQuery dataflow_output_query_all(); +DataflowOutputQuery dataflow_output_query_none(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/unordered_set_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/unordered_set_dataflow_graph.h new file mode 100644 index 0000000000..4e9e508e39 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/unordered_set_dataflow_graph.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_UNORDERED_SET_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_UNORDERED_SET_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/i_dataflow_graph.h" +#include "utils/graph/node/node_source.h" + +namespace FlexFlow { + +struct UnorderedSetDataflowGraph : public IDataflowGraph { +public: + UnorderedSetDataflowGraph(); + + NodeAddedResult add_node(std::vector const &inputs, + int num_outputs) override; + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set query_edges(DataflowEdgeQuery const &) const override; + std::unordered_set query_outputs(DataflowOutputQuery const &) const override; + + UnorderedSetDataflowGraph *clone() const override; +private: + UnorderedSetDataflowGraph(NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &edges, + std::unordered_set const &outputs); + +private: + NodeSource node_source; + std::unordered_set nodes; + std::unordered_set edges; + std::unordered_set outputs; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(UnorderedSetDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h new file mode 100644 index 0000000000..a3575cec6f --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.struct.toml +/* proj-data +{ + "generated_from": "817156d78fd6385f97978fdd02d1b925" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_OUTPUT_DTG_H + +#include "fmt/format.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct DataflowGraphOutput { + DataflowGraphOutput() = delete; + explicit DataflowGraphOutput(int const &index); + + bool operator==(DataflowGraphOutput const &) const; + bool operator!=(DataflowGraphOutput const &) const; + bool operator<(DataflowGraphOutput const &) const; + bool operator>(DataflowGraphOutput const &) const; + bool operator<=(DataflowGraphOutput const &) const; + bool operator>=(DataflowGraphOutput const &) const; + int index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowGraphOutput> { + size_t operator()(::FlexFlow::DataflowGraphOutput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowGraphOutput const &); +std::ostream &operator<<(std::ostream &, DataflowGraphOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_OUTPUT_DTG_H diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.struct.toml b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.struct.toml new file mode 100644 index 0000000000..fa5f0c6e37 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DataflowGraphOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "", +] + +[[fields]] +name = "index" +type = "int" diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h new file mode 100644 index 0000000000..4ec0443901 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.struct.toml +/* proj-data +{ + "generated_from": "2488765bd934738ef9111699bb8b71e3" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include "utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowOutputEdge { + DataflowOutputEdge() = delete; + explicit DataflowOutputEdge(::FlexFlow::DataflowOutput const &src, + ::FlexFlow::DataflowGraphOutput const &dst); + + bool operator==(DataflowOutputEdge const &) const; + bool operator!=(DataflowOutputEdge const &) const; + bool operator<(DataflowOutputEdge const &) const; + bool operator>(DataflowOutputEdge const &) const; + bool operator<=(DataflowOutputEdge const &) const; + bool operator>=(DataflowOutputEdge const &) const; + ::FlexFlow::DataflowOutput src; + ::FlexFlow::DataflowGraphOutput dst; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowOutputEdge> { + size_t operator()(::FlexFlow::DataflowOutputEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowOutputEdge const &); +std::ostream &operator<<(std::ostream &, DataflowOutputEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.struct.toml b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.struct.toml new file mode 100644 index 0000000000..eaba1b47b9 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowOutputEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::DataflowOutput" + +[[fields]] +name = "dst" +type = "::FlexFlow::DataflowGraphOutput" diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.h b/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.h new file mode 100644 index 0000000000..aa7caa83ec --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.h @@ -0,0 +1,116 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.variant.toml +/* proj-data +{ + "generated_from": "0c40ab695b9c1dca5465aea45190a3fa" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DOWNWARD_OPEN_DATAFLOW_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DOWNWARD_OPEN_DATAFLOW_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" +#include "utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct DownwardOpenDataflowEdge { + DownwardOpenDataflowEdge() = delete; + explicit DownwardOpenDataflowEdge(::FlexFlow::DataflowOutputEdge const &); + explicit DownwardOpenDataflowEdge(::FlexFlow::DataflowEdge const &); + template + static constexpr bool IsPartOfDownwardOpenDataflowEdge_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::DataflowOutputEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::DataflowEdge>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type DownwardOpenDataflowEdge", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::DataflowOutputEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::DataflowEdge>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type DownwardOpenDataflowEdge", + this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfDownwardOpenDataflowEdge_v, + "DownwardOpenDataflowEdge::has() expected one of " + "[::FlexFlow::DataflowOutputEdge, ::FlexFlow::DataflowEdge], " + "received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfDownwardOpenDataflowEdge_v, + "DownwardOpenDataflowEdge::get() expected one of " + "[::FlexFlow::DataflowOutputEdge, ::FlexFlow::DataflowEdge], " + "received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfDownwardOpenDataflowEdge_v, + "DownwardOpenDataflowEdge::get() expected one of " + "[::FlexFlow::DataflowOutputEdge, ::FlexFlow::DataflowEdge], " + "received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(DownwardOpenDataflowEdge const &) const; + bool operator!=(DownwardOpenDataflowEdge const &) const; + bool operator<(DownwardOpenDataflowEdge const &) const; + bool operator>(DownwardOpenDataflowEdge const &) const; + bool operator<=(DownwardOpenDataflowEdge const &) const; + bool operator>=(DownwardOpenDataflowEdge const &) const; + std::variant<::FlexFlow::DataflowOutputEdge, ::FlexFlow::DataflowEdge> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::DownwardOpenDataflowEdge> { + size_t operator()(::FlexFlow::DownwardOpenDataflowEdge const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::DownwardOpenDataflowEdge const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::DownwardOpenDataflowEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DOWNWARD_OPEN_DATAFLOW_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.variant.toml b/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.variant.toml new file mode 100644 index 0000000000..7bcd87c7e9 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "DownwardOpenDataflowEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", + "utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h" +] + +[[values]] +type = "::FlexFlow::DataflowOutputEdge" + +[[values]] +type = "::FlexFlow::DataflowEdge" diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_graph.h b/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_graph.h new file mode 100644 index 0000000000..d9a3fd6053 --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_graph.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h" +#include "utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph.h" + +namespace FlexFlow { + +struct DownwardOpenDataflowGraph : virtual DataflowGraph { +public: + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(DataflowEdgeQuery const &) const; + std::unordered_set query_outputs(DataflowOutputQuery const &) const; + std::vector get_graph_outputs() const; + +protected: + using DataflowGraph::DataflowGraph; + +private: + IDownwardOpenDataflowGraph &get_interface(); + IDownwardOpenDataflowGraph const &get_interface() const; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph.h b/lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph.h new file mode 100644 index 0000000000..7e163098fc --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_I_DOWNWARD_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_I_DOWNWARD_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" + +namespace FlexFlow { + +struct IDownwardOpenDataflowGraph : virtual public IDownwardOpenDataflowGraphView { + virtual NodeAddedResult add_node(std::vector const &inputs, + int num_outputs) = 0; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph_view.h new file mode 100644 index 0000000000..c6711c755c --- /dev/null +++ b/lib/utils/include/utils/graph/downward_open_dataflow_graph/i_downward_open_dataflow_graph_view.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_I_DOWNWARD_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DOWNWARD_OPEN_DATAFLOW_GRAPH_I_DOWNWARD_OPEN_DATAFLOW_GRAPH_VIEW_H + +namespace FlexFlow { + +struct IDownwardOpenDataflowGraphView : virtual public IDownwardOpenDataflowGraphView { + virtual std::unordered_set query_graph_outputs() const; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/graph_split.dtg.h b/lib/utils/include/utils/graph/graph_split.dtg.h new file mode 100644 index 0000000000..9d9d04cc50 --- /dev/null +++ b/lib/utils/include/utils/graph/graph_split.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/graph_split.struct.toml +/* proj-data +{ + "generated_from": "bf08a68806136ac698f1206f84cb907f" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_GRAPH_SPLIT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_GRAPH_SPLIT_DTG_H + +#include "fmt/format.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/hash/unordered_set.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct GraphSplit { + GraphSplit() = delete; + explicit GraphSplit(std::unordered_set<::FlexFlow::Node> const &first, + std::unordered_set<::FlexFlow::Node> const &second); + + bool operator==(GraphSplit const &) const; + bool operator!=(GraphSplit const &) const; + std::unordered_set<::FlexFlow::Node> first; + std::unordered_set<::FlexFlow::Node> second; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::GraphSplit> { + size_t operator()(::FlexFlow::GraphSplit const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(GraphSplit const &); +std::ostream &operator<<(std::ostream &, GraphSplit const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_GRAPH_SPLIT_DTG_H diff --git a/lib/utils/include/utils/graph/graph_split.struct.toml b/lib/utils/include/utils/graph/graph_split.struct.toml new file mode 100644 index 0000000000..1f393a9318 --- /dev/null +++ b/lib/utils/include/utils/graph/graph_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "GraphSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/graph/node/node.dtg.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "first" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "second" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h new file mode 100644 index 0000000000..34ae475ab4 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_I_LABELLED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_I_LABELLED_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ILabelledDataflowGraph : virtual public ILabelledDataflowGraphView { +public: + virtual NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) = 0; + + virtual ~ILabelledDataflowGraph() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h new file mode 100644 index 0000000000..9f0fc0f30d --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_I_LABELLED_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_I_LABELLED_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ILabelledDataflowGraphView : virtual public IDataflowGraphView { +public: + virtual NodeLabel const &at(Node const &) const = 0; + virtual OutputLabel const &at(DataflowOutput const &) const = 0; + + virtual ~ILabelledDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h new file mode 100644 index 0000000000..00eb0250b5 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_LABELLED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_LABELLED_DATAFLOW_GRAPH_H + +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledDataflowGraph : virtual LabelledDataflowGraphView { +private: + using Interface = ILabelledDataflowGraph; +public: + NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) { + return this->get_interface().add_node(node_label, inputs, output_labels); + } + +private: + Interface &get_interface() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + } + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h new file mode 100644 index 0000000000..0b372a0f70 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_LABELLED_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_LABELLED_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledDataflowGraphView : virtual public DataflowGraphView { +private: + using Interface = ILabelledDataflowGraphView; +public: + NodeLabel const &at(Node const &n) const { + return this->get_interface().at(n); + } + OutputLabel const &at(DataflowOutput const &o) const { + return this->get_interface().at(o); + } + + template + static typename std::enable_if::value, + LabelledDataflowGraphView>::type + create() { + return LabelledDataflowGraphView(make_cow_ptr()); + } +protected: + using DataflowGraphView::DataflowGraphView; + +private: + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/algorithms.h b/lib/utils/include/utils/graph/node/algorithms.h new file mode 100644 index 0000000000..fbc18d562f --- /dev/null +++ b/lib/utils/include/utils/graph/node/algorithms.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_ALGORITHMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_ALGORITHMS_H + +#include "utils/graph/node/graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(GraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h new file mode 100644 index 0000000000..e456420094 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(OpenDataflowGraphView const &); +std::vector get_inputs(OpenDataflowGraphView const &); +std::vector get_incoming_edges(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h new file mode 100644 index 0000000000..685e1e26b3 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml +/* proj-data +{ + "generated_from": "7d6fe1350bb6f70771a7481a8e36aa2e" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_INPUT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_INPUT_DTG_H + +#include "fmt/format.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowGraphInput { + DataflowGraphInput() = delete; + explicit DataflowGraphInput(int const &idx); + + bool operator==(DataflowGraphInput const &) const; + bool operator!=(DataflowGraphInput const &) const; + bool operator<(DataflowGraphInput const &) const; + bool operator>(DataflowGraphInput const &) const; + bool operator<=(DataflowGraphInput const &) const; + bool operator>=(DataflowGraphInput const &) const; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowGraphInput> { + size_t operator()(::FlexFlow::DataflowGraphInput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowGraphInput const &); +std::ostream &operator<<(std::ostream &, DataflowGraphInput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_INPUT_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml new file mode 100644 index 0000000000..6d047ed878 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "DataflowGraphInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h new file mode 100644 index 0000000000..a717cc2cb2 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml +/* proj-data +{ + "generated_from": "adf27ca64d88e17594764cefbcb7934f" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_input.dtg.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowInputEdge { + DataflowInputEdge() = delete; + explicit DataflowInputEdge(::FlexFlow::DataflowGraphInput const &src, + ::FlexFlow::DataflowInput const &dst); + + bool operator==(DataflowInputEdge const &) const; + bool operator!=(DataflowInputEdge const &) const; + bool operator<(DataflowInputEdge const &) const; + bool operator>(DataflowInputEdge const &) const; + bool operator<=(DataflowInputEdge const &) const; + bool operator>=(DataflowInputEdge const &) const; + ::FlexFlow::DataflowGraphInput src; + ::FlexFlow::DataflowInput dst; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowInputEdge> { + size_t operator()(::FlexFlow::DataflowInputEdge const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowInputEdge const &); +std::ostream &operator<<(std::ostream &, DataflowInputEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml new file mode 100644 index 0000000000..fdfcfcf511 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowInputEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "src" +type = "::FlexFlow::DataflowGraphInput" + +[[fields]] +name = "dst" +type = "::FlexFlow::DataflowInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h new file mode 100644 index 0000000000..9add99a920 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml +/* proj-data +{ + "generated_from": "7bc0c4aa108438c9f24536f8b669b532" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +namespace FlexFlow { +struct DataflowInputEdgeQuery { + DataflowInputEdgeQuery() = delete; + explicit DataflowInputEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::DataflowGraphInput> const &srcs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dst_nodes, + ::FlexFlow::query_set const &dst_idxs); + + bool operator==(DataflowInputEdgeQuery const &) const; + bool operator!=(DataflowInputEdgeQuery const &) const; + bool operator<(DataflowInputEdgeQuery const &) const; + bool operator>(DataflowInputEdgeQuery const &) const; + bool operator<=(DataflowInputEdgeQuery const &) const; + bool operator>=(DataflowInputEdgeQuery const &) const; + ::FlexFlow::query_set<::FlexFlow::DataflowGraphInput> srcs; + ::FlexFlow::query_set<::FlexFlow::Node> dst_nodes; + ::FlexFlow::query_set dst_idxs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::DataflowInputEdgeQuery> { + size_t operator()(::FlexFlow::DataflowInputEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowInputEdgeQuery const &); +std::ostream &operator<<(std::ostream &, DataflowInputEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h new file mode 100644 index 0000000000..0a5c45013d --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_H + +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h" + +namespace FlexFlow { + +DataflowInputEdgeQuery dataflow_input_edge_query_all(); +DataflowInputEdgeQuery dataflow_input_edge_query_none(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml new file mode 100644 index 0000000000..544a05af85 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DataflowInputEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/query_set.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "srcs" +type = "::FlexFlow::query_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "dst_nodes" +type = "::FlexFlow::query_set<::FlexFlow::Node>" + +[[fields]] +name = "dst_idxs" +type = "::FlexFlow::query_set" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph.h b/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph.h new file mode 100644 index 0000000000..66d551c69e --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_I_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_I_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" + +namespace FlexFlow { + +struct IOpenDataflowGraph : virtual public IOpenDataflowGraphView { + virtual NodeAddedResult add_node(std::vector const &inputs, + int num_outputs) = 0; + virtual DataflowGraphInput add_input() = 0; + virtual IOpenDataflowGraph *clone() const = 0; + + virtual ~IOpenDataflowGraph() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h new file mode 100644 index 0000000000..c485e32c68 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_I_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_I_OPEN_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h" + +namespace FlexFlow { + +struct IOpenDataflowGraphView : virtual public IDataflowGraphView { + virtual std::vector get_inputs() const = 0; + virtual std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const = 0; + + std::unordered_set query_edges(DataflowEdgeQuery const &) const override final; + + virtual ~IOpenDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOpenDataflowGraphView); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h new file mode 100644 index 0000000000..4fbbb3431a --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h @@ -0,0 +1,113 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml +/* proj-data +{ + "generated_from": "33e3c8ad4602c3e20c29b6c0dfa104ca" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OpenDataflowEdge { + OpenDataflowEdge() = delete; + explicit OpenDataflowEdge(::FlexFlow::DataflowInputEdge const &); + explicit OpenDataflowEdge(::FlexFlow::DataflowEdge const &); + template + static constexpr bool IsPartOfOpenDataflowEdge_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::DataflowInputEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::DataflowEdge>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenDataflowEdge", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::DataflowInputEdge>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::DataflowEdge>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenDataflowEdge", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfOpenDataflowEdge_v, + "OpenDataflowEdge::has() expected one of " + "[::FlexFlow::DataflowInputEdge, ::FlexFlow::DataflowEdge], " + "received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfOpenDataflowEdge_v, + "OpenDataflowEdge::get() expected one of " + "[::FlexFlow::DataflowInputEdge, ::FlexFlow::DataflowEdge], " + "received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfOpenDataflowEdge_v, + "OpenDataflowEdge::get() expected one of " + "[::FlexFlow::DataflowInputEdge, ::FlexFlow::DataflowEdge], " + "received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OpenDataflowEdge const &) const; + bool operator!=(OpenDataflowEdge const &) const; + bool operator<(OpenDataflowEdge const &) const; + bool operator>(OpenDataflowEdge const &) const; + bool operator<=(OpenDataflowEdge const &) const; + bool operator>=(OpenDataflowEdge const &) const; + std::variant<::FlexFlow::DataflowInputEdge, ::FlexFlow::DataflowEdge> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OpenDataflowEdge> { + size_t operator()(::FlexFlow::OpenDataflowEdge const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OpenDataflowEdge const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::OpenDataflowEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h new file mode 100644 index 0000000000..e4f9704ce4 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" + +namespace FlexFlow { + +int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml new file mode 100644 index 0000000000..29f14fcf0d --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "OpenDataflowEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h", + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::DataflowInputEdge" + +[[values]] +type = "::FlexFlow::DataflowEdge" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h new file mode 100644 index 0000000000..512684bdbe --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h @@ -0,0 +1,50 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml +/* proj-data +{ + "generated_from": "661c106abdb03bf6cc434d87cfafefb5" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OpenDataflowEdgeQuery { + OpenDataflowEdgeQuery() = delete; + explicit OpenDataflowEdgeQuery( + ::FlexFlow::DataflowInputEdgeQuery const &input_edge_query, + ::FlexFlow::DataflowEdgeQuery const &standard_edge_query); + + bool operator==(OpenDataflowEdgeQuery const &) const; + bool operator!=(OpenDataflowEdgeQuery const &) const; + bool operator<(OpenDataflowEdgeQuery const &) const; + bool operator>(OpenDataflowEdgeQuery const &) const; + bool operator<=(OpenDataflowEdgeQuery const &) const; + bool operator>=(OpenDataflowEdgeQuery const &) const; + ::FlexFlow::DataflowInputEdgeQuery input_edge_query; + ::FlexFlow::DataflowEdgeQuery standard_edge_query; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::OpenDataflowEdgeQuery> { + size_t operator()(::FlexFlow::OpenDataflowEdgeQuery const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OpenDataflowEdgeQuery const &); +std::ostream &operator<<(std::ostream &, OpenDataflowEdgeQuery const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h new file mode 100644 index 0000000000..d6b13ab2a0 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h" + +namespace FlexFlow { + +OpenDataflowEdgeQuery open_dataflow_edge_query_all(); +OpenDataflowEdgeQuery open_dataflow_edge_query_none(); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml new file mode 100644 index 0000000000..1e2bb9221e --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OpenDataflowEdgeQuery" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h", +] + +[[fields]] +name = "input_edge_query" +type = "::FlexFlow::DataflowInputEdgeQuery" + +[[fields]] +name = "standard_edge_query" +type = "::FlexFlow::DataflowEdgeQuery" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph.h new file mode 100644 index 0000000000..1a3108a43a --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" + +namespace FlexFlow { + +struct OpenDataflowGraph : virtual public OpenDataflowGraphView { +public: + NodeAddedResult add_node(std::vector const &inputs, + int num_outputs); + DataflowGraphInput add_input(); + + template + static typename std::enable_if::value, + OpenDataflowGraph>::type + create(Args &&...args) { + return OpenDataflowGraph(make_cow_ptr(std::forward(args)...)); + } +protected: + using OpenDataflowGraphView::OpenDataflowGraphView; + +private: + IOpenDataflowGraph &get_interface(); + IOpenDataflowGraph const &get_interface() const; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h new file mode 100644 index 0000000000..b3875eb10d --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +struct OpenDataflowGraphView : virtual DataflowGraphView { +public: + OpenDataflowGraphView(OpenDataflowGraphView const &) = default; + OpenDataflowGraphView &operator=(OpenDataflowGraphView const &) = default; + + std::vector get_inputs() const; + std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const; + + template + static typename std::enable_if::value, + OpenDataflowGraphView>::type + create(Args &&...args) { + return OpenDataflowGraphView(make_cow_ptr(std::forward(args)...)); + } + +protected: + using DataflowGraphView::DataflowGraphView; + +private: + IOpenDataflowGraphView const &get_interface() const; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h new file mode 100644 index 0000000000..34b56dff56 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h @@ -0,0 +1,113 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml +/* proj-data +{ + "generated_from": "a212e5a39ee0d8c9ef39bc4892e15416" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OpenDataflowValue { + OpenDataflowValue() = delete; + explicit OpenDataflowValue(::FlexFlow::DataflowOutput const &); + explicit OpenDataflowValue(::FlexFlow::DataflowGraphInput const &); + template + static constexpr bool IsPartOfOpenDataflowValue_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::DataflowOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::DataflowGraphInput>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenDataflowValue", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::DataflowOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::DataflowGraphInput>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenDataflowValue", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfOpenDataflowValue_v, + "OpenDataflowValue::has() expected one of [::FlexFlow::DataflowOutput, " + "::FlexFlow::DataflowGraphInput], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfOpenDataflowValue_v, + "OpenDataflowValue::get() expected one of [::FlexFlow::DataflowOutput, " + "::FlexFlow::DataflowGraphInput], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfOpenDataflowValue_v, + "OpenDataflowValue::get() expected one of [::FlexFlow::DataflowOutput, " + "::FlexFlow::DataflowGraphInput], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OpenDataflowValue const &) const; + bool operator!=(OpenDataflowValue const &) const; + bool operator<(OpenDataflowValue const &) const; + bool operator>(OpenDataflowValue const &) const; + bool operator<=(OpenDataflowValue const &) const; + bool operator>=(OpenDataflowValue const &) const; + std::variant<::FlexFlow::DataflowOutput, ::FlexFlow::DataflowGraphInput> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OpenDataflowValue> { + size_t operator()(::FlexFlow::OpenDataflowValue const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OpenDataflowValue const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::OpenDataflowValue const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml new file mode 100644 index 0000000000..ba28a8772a --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "OpenDataflowValue" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[values]] +type = "::FlexFlow::DataflowOutput" + +[[values]] +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h b/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h new file mode 100644 index 0000000000..f251136eb0 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_UNORDERED_SET_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_UNORDERED_SET_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph.h" +#include "utils/graph/node/node_source.h" + +namespace FlexFlow { + +struct UnorderedSetOpenDataflowGraph : public IOpenDataflowGraph { +public: + UnorderedSetOpenDataflowGraph(); + + NodeAddedResult add_node(std::vector const &inputs, + int num_outputs) override; + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const override; + std::unordered_set query_outputs(DataflowOutputQuery const &) const override; + std::vector get_inputs() const override; + + DataflowGraphInput add_input() override; + UnorderedSetOpenDataflowGraph *clone() const override; +private: + UnorderedSetOpenDataflowGraph(NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &standard_edges, + std::unordered_set const &input_edges, + std::unordered_set const &outputs, + std::vector const &graph_inputs); + +private: + NodeSource node_source; + std::unordered_set nodes; + std::unordered_set standard_edges; + std::unordered_set input_edges; + std::unordered_set outputs; + std::vector graph_inputs; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(UnorderedSetOpenDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 1f2ab6757f..cbbd0a092d 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -10,6 +10,7 @@ #include "utils/optional.h" #include "utils/hash-utils.h" #include "utils/fmt/unordered_set.h" +#include "utils/hash/set.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/serialparallel.h b/lib/utils/include/utils/graph/serial_parallel/serialparallel.h index f1cf977eb3..d032707efc 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serialparallel.h +++ b/lib/utils/include/utils/graph/serial_parallel/serialparallel.h @@ -20,6 +20,9 @@ SerialParallelDecomposition get_serial_parallel_decomposition(DiGraphView const &); std::unordered_set get_nodes(SerialParallelDecomposition const &sp); +std::unordered_set get_nodes(Serial const &); +std::unordered_set get_nodes(Parallel const &); +std::unordered_set get_nodes(Node const &); // std::unordered_map parallel_extend(MultiDiGraph &g, // MultiDiGraph const &ext); diff --git a/lib/utils/include/utils/hash-utils-core.h b/lib/utils/include/utils/hash-utils-core.h deleted file mode 100644 index ea333563d0..0000000000 --- a/lib/utils/include/utils/hash-utils-core.h +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_HASH_UTILS_CORE_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_HASH_UTILS_CORE_H - -#include -#include -#include -#include -#include - -namespace FlexFlow { - -template -std::size_t get_std_hash(T const &v) { - std::hash hasher; - return hasher(v); -} - -// tuple hashing pulled from -// https://www.variadic.xyz/2018/01/15/hashing-stdpair-and-stdtuple/ -template -inline void hash_combine(std::size_t &seed, T const &v) { - std::hash hasher; - seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - -template -inline void hash_combine(std::size_t &seed, T const &v, Ts... rest) { - hash_combine(seed, v); - hash_combine(seed, rest...); -} - -template -void unordered_container_hash(std::size_t &seed, T const &t) { - hash_combine(seed, t.size()); - size_t total = 0; - for (auto const &v : t) { - total += get_std_hash(v); - } - hash_combine(seed, total); -} - -template -void iter_hash(std::size_t &seed, It start, It end) { - hash_combine(seed, std::distance(start, end)); - for (; start < end; start++) { - hash_combine(seed, *start); - } -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/hash-utils.h b/lib/utils/include/utils/hash-utils.h index 831a12e554..ea333563d0 100644 --- a/lib/utils/include/utils/hash-utils.h +++ b/lib/utils/include/utils/hash-utils.h @@ -1,51 +1,52 @@ -#ifndef _FLEXFLOW_HASH_UTILS_H -#define _FLEXFLOW_HASH_UTILS_H +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_HASH_UTILS_CORE_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_HASH_UTILS_CORE_H + +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +std::size_t get_std_hash(T const &v) { + std::hash hasher; + return hasher(v); +} + +// tuple hashing pulled from +// https://www.variadic.xyz/2018/01/15/hashing-stdpair-and-stdtuple/ +template +inline void hash_combine(std::size_t &seed, T const &v) { + std::hash hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +template +inline void hash_combine(std::size_t &seed, T const &v, Ts... rest) { + hash_combine(seed, v); + hash_combine(seed, rest...); +} -#include "containers.h" -#include "hash-utils-core.h" -#include -#include - -using namespace FlexFlow; - -namespace std { template -struct hash> { - size_t operator()(std::unordered_set const &s) const { - size_t result = 0; - unordered_container_hash(result, s); - return result; +void unordered_container_hash(std::size_t &seed, T const &t) { + hash_combine(seed, t.size()); + size_t total = 0; + for (auto const &v : t) { + total += get_std_hash(v); } -}; - -template -struct hash> { - size_t operator()(std::set const &s) const { - size_t result = 0; - unordered_container_hash(result, s); - return result; + hash_combine(seed, total); +} + +template +void iter_hash(std::size_t &seed, It start, It end) { + hash_combine(seed, std::distance(start, end)); + for (; start < end; start++) { + hash_combine(seed, *start); } -}; - -template -struct hash> { - size_t operator()(std::unordered_map const &m) const { - size_t result = 0; - unordered_container_hash(result, m); - return result; - } -}; - -template -struct hash> { - size_t operator()(std::map const &m) const { - size_t result = 0; - unordered_container_hash(result, m); - return result; - } -}; - +} -} // namespace std +} // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/hash/map.h b/lib/utils/include/utils/hash/map.h new file mode 100644 index 0000000000..48e9cdeac0 --- /dev/null +++ b/lib/utils/include/utils/hash/map.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MAP_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::map const &m) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, m); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/pair.h b/lib/utils/include/utils/hash/pair.h index 5d0af39848..0a8fb61564 100644 --- a/lib/utils/include/utils/hash/pair.h +++ b/lib/utils/include/utils/hash/pair.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_PAIR_H #include -#include "utils/hash-utils-core.h" +#include "utils/hash-utils.h" namespace std { @@ -11,8 +11,8 @@ struct hash> { size_t operator()(std::pair const &p) const { size_t seed = 283746; - hash_combine(seed, p.first); - hash_combine(seed, p.second); + ::FlexFlow::hash_combine(seed, p.first); + ::FlexFlow::hash_combine(seed, p.second); return seed; } diff --git a/lib/utils/include/utils/hash/set.h b/lib/utils/include/utils/hash/set.h new file mode 100644 index 0000000000..1f565382a9 --- /dev/null +++ b/lib/utils/include/utils/hash/set.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_SET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::set const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/tuple.h b/lib/utils/include/utils/hash/tuple.h index de64264064..76d228c642 100644 --- a/lib/utils/include/utils/hash/tuple.h +++ b/lib/utils/include/utils/hash/tuple.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_TUPLE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_TUPLE_H -#include "utils/hash-utils-core.h" +#include "utils/hash-utils.h" #include namespace std { @@ -23,7 +23,7 @@ struct hash> { inline typename std::enable_if < Idx::type hash_combine_tup(size_t &seed, std::tuple const &tup) const { - hash_combine(seed, std::get(tup)); + ::FlexFlow::hash_combine(seed, std::get(tup)); // on to next element hash_combine_tup(seed, tup); diff --git a/lib/utils/include/utils/hash/unordered_map.h b/lib/utils/include/utils/hash/unordered_map.h new file mode 100644 index 0000000000..1435784249 --- /dev/null +++ b/lib/utils/include/utils/hash/unordered_map.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MAP_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::unordered_map const &m) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, m); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/unordered_set.h b/lib/utils/include/utils/hash/unordered_set.h new file mode 100644 index 0000000000..acf10bd491 --- /dev/null +++ b/lib/utils/include/utils/hash/unordered_set.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_SET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::unordered_set const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/vector.h b/lib/utils/include/utils/hash/vector.h index 3785076288..3c9ec5cbe7 100644 --- a/lib/utils/include/utils/hash/vector.h +++ b/lib/utils/include/utils/hash/vector.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_VECTOR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_VECTOR_H -#include "utils/hash-utils-core.h" +#include "utils/hash-utils.h" #include namespace std { @@ -10,7 +10,7 @@ template struct hash> { size_t operator()(std::vector const &vec) const { size_t seed = 0; - iter_hash(seed, vec.cbegin(), vec.cend()); + ::FlexFlow::iter_hash(seed, vec.cbegin(), vec.cend()); return seed; } }; diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 643315ff64..76f03549a4 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H #include "fmt.decl.h" -#include "hash-utils-core.h" +#include "hash-utils.h" #include "test_types.h" #include "type_traits_core.h" #include diff --git a/lib/utils/src/fp16.cc b/lib/utils/src/fp16.cc index c42f8ff179..f9dbf486ab 100644 --- a/lib/utils/src/fp16.cc +++ b/lib/utils/src/fp16.cc @@ -1,9 +1,10 @@ #include "utils/fp16.h" +#include "utils/hash-utils.h" namespace std { size_t hash::operator()(half h) const { - return get_std_hash(static_cast(h)); + return ::FlexFlow::get_std_hash(static_cast(h)); } } // namespace std diff --git a/lib/utils/src/utils/check_fmtable.cc b/lib/utils/src/utils/check_fmtable.cc new file mode 100644 index 0000000000..c466f46acf --- /dev/null +++ b/lib/utils/src/utils/check_fmtable.cc @@ -0,0 +1 @@ +#include "utils/check_fmtable.h" diff --git a/lib/utils/src/utils/fmt/expected.cc b/lib/utils/src/utils/fmt/expected.cc new file mode 100644 index 0000000000..e5d6a8dc32 --- /dev/null +++ b/lib/utils/src/utils/fmt/expected.cc @@ -0,0 +1 @@ +#include "utils/fmt/expected.h" diff --git a/lib/utils/src/utils/fmt/map.cc b/lib/utils/src/utils/fmt/map.cc new file mode 100644 index 0000000000..21db320044 --- /dev/null +++ b/lib/utils/src/utils/fmt/map.cc @@ -0,0 +1 @@ +#include "utils/fmt/map.h" diff --git a/lib/utils/src/utils/fmt/pair.cc b/lib/utils/src/utils/fmt/pair.cc new file mode 100644 index 0000000000..16e5f82a3c --- /dev/null +++ b/lib/utils/src/utils/fmt/pair.cc @@ -0,0 +1 @@ +#include "utils/fmt/pair.h" diff --git a/lib/utils/src/utils/fmt/set.cc b/lib/utils/src/utils/fmt/set.cc new file mode 100644 index 0000000000..857367af48 --- /dev/null +++ b/lib/utils/src/utils/fmt/set.cc @@ -0,0 +1 @@ +#include "utils/fmt/set.h" diff --git a/lib/utils/src/utils/fmt/unordered_map.cc b/lib/utils/src/utils/fmt/unordered_map.cc new file mode 100644 index 0000000000..f8746e85a0 --- /dev/null +++ b/lib/utils/src/utils/fmt/unordered_map.cc @@ -0,0 +1 @@ +#include "utils/fmt/unordered_map.h" diff --git a/lib/utils/src/utils/fmt/unordered_set.cc b/lib/utils/src/utils/fmt/unordered_set.cc new file mode 100644 index 0000000000..354eb2f9e7 --- /dev/null +++ b/lib/utils/src/utils/fmt/unordered_set.cc @@ -0,0 +1 @@ +#include "utils/fmt/unordered_set.h" diff --git a/lib/utils/src/utils/fmt/vector.cc b/lib/utils/src/utils/fmt/vector.cc new file mode 100644 index 0000000000..507778c8e6 --- /dev/null +++ b/lib/utils/src/utils/fmt/vector.cc @@ -0,0 +1 @@ +#include "utils/fmt/vector.h" diff --git a/lib/utils/src/utils/graph.cc b/lib/utils/src/utils/graph.cc new file mode 100644 index 0000000000..a8aceb6403 --- /dev/null +++ b/lib/utils/src/utils/graph.cc @@ -0,0 +1 @@ +#include "utils/graph.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 5c86ff1086..0e45056ce3 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -1,6 +1,7 @@ #include "utils/graph/algorithms.h" #include "utils/containers.h" #include "utils/exception.h" +#include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" #include "utils/graph/traversal.h" #include "utils/graph/views/views.h" @@ -53,10 +54,6 @@ std::vector add_nodes(DiGraph &g, int num_nodes) { // return node_ports; // } -std::unordered_set get_nodes(GraphView const &g) { - return g.query_nodes(node_query_all()); -} - // std::unordered_set get_nodes(InputMultiDiEdge const &edge) { // return {edge.dst}; // } diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc new file mode 100644 index 0000000000..f7db16dfe2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc @@ -0,0 +1,35 @@ +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(DataflowGraphView const &g) { + return g.query_edges(dataflow_edge_query_all()); +} + +std::vector get_incoming_edges(DataflowGraphView const &g, Node const &n) { + return sorted_by(g.query_edges(DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + {n}, + query_set::matchall(), + }), [](DataflowEdge const &l, DataflowEdge const &r) { return l.dst.idx < r.dst.idx; }); +} + +std::vector get_inputs(DataflowGraphView const &g, Node const &n) { + return transform(get_incoming_edges(g, n), + [](DataflowEdge const &e) { return e.src; }); +} + +std::vector get_outputs(DataflowGraphView const &g, Node const &n) { + return sorted_by(g.query_outputs(DataflowOutputQuery{ + query_set{n}, + query_set::matchall(), + }), + [](DataflowOutput const &l, DataflowOutput const &r) { + return l.idx < r.idx; + }); + +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc index d80b5b7afd..f66b4a89a1 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.dtg.cc @@ -3,7 +3,7 @@ // lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.struct.toml /* proj-data { - "generated_from": "111e640382a80b659bc33dd86a416ded" + "generated_from": "e88f46c93e5d1c025271ad70a3bcd105" } */ diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc index 5fc3b177f2..18dc7516e8 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc @@ -24,7 +24,7 @@ IDataflowGraph &DataflowGraph::get_interface() { } IDataflowGraph const &DataflowGraph::get_interface() const { - return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc new file mode 100644 index 0000000000..b739c7da68 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc @@ -0,0 +1,19 @@ +#include "utils/graph/dataflow_graph/dataflow_output_query.h" + +namespace FlexFlow { + +DataflowOutputQuery dataflow_output_query_all() { + return DataflowOutputQuery{ + query_set::matchall(), + query_set::matchall(), + }; +} + +DataflowOutputQuery dataflow_output_query_none() { + return DataflowOutputQuery{ + query_set::match_none(), + query_set::match_none(), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph.cc new file mode 100644 index 0000000000..4e74f7e711 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/i_dataflow_graph.cc @@ -0,0 +1 @@ +#include "utils/graph/dataflow_graph/i_dataflow_graph.h" diff --git a/lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc new file mode 100644 index 0000000000..6fd7177f4d --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc @@ -0,0 +1,60 @@ +#include "utils/graph/dataflow_graph/unordered_set_dataflow_graph.h" +#include "utils/containers/enumerate_vector.h" + +namespace FlexFlow { + +UnorderedSetDataflowGraph::UnorderedSetDataflowGraph() {} +UnorderedSetDataflowGraph::UnorderedSetDataflowGraph(NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &edges, + std::unordered_set const &outputs) + : node_source(node_source), nodes(nodes), edges(edges), outputs(outputs) +{} + +NodeAddedResult UnorderedSetDataflowGraph::add_node(std::vector const &inputs, + int num_outputs) { + Node new_node = this->node_source.new_node(); + this->nodes.insert(new_node); + + for (auto const &[input_idx, input_src] : enumerate_vector(inputs)) { + this->edges.insert(DataflowEdge{input_src, DataflowInput{new_node, input_idx}}); + } + + std::vector new_outputs = transform(count(num_outputs), + [&](int output_idx) { return DataflowOutput{new_node, output_idx}; }); + extend(this->outputs, new_outputs); + + return NodeAddedResult{new_node, new_outputs}; +} + +std::unordered_set UnorderedSetDataflowGraph::query_nodes(NodeQuery const &q) const { + return apply_query(q.nodes, this->nodes); +} + +std::unordered_set UnorderedSetDataflowGraph::query_edges(DataflowEdgeQuery const &q) const { + return filter(this->edges, [&](DataflowEdge const &e) { + return includes(q.src_nodes, e.src.node) + && includes(q.dst_nodes, e.dst.node) + && includes(q.src_idxs, e.src.idx) + && includes(q.dst_idxs, e.dst.idx); + }); +} + +std::unordered_set UnorderedSetDataflowGraph::query_outputs(DataflowOutputQuery const &q) const { + return filter(this->outputs, [&](DataflowOutput const &o) { + return includes(q.nodes, o.node) + && includes(q.output_idxs, o.idx); + }); +} + +UnorderedSetDataflowGraph *UnorderedSetDataflowGraph::clone() const { + return new UnorderedSetDataflowGraph{ + this->node_source, + this->nodes, + this->edges, + this->outputs, + }; +} + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.cc b/lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.cc new file mode 100644 index 0000000000..579556addd --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.cc @@ -0,0 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_graph_output.struct.toml +/* proj-data +{ + "generated_from": "817156d78fd6385f97978fdd02d1b925" +} +*/ + +#include "utils/graph/downward_open_dataflow_graph/dataflow_graph_output.dtg.h" + +#include + +namespace FlexFlow { +DataflowGraphOutput::DataflowGraphOutput(int const &index) : index(index) {} +bool DataflowGraphOutput::operator==(DataflowGraphOutput const &other) const { + return std::tie(this->index) == std::tie(other.index); +} +bool DataflowGraphOutput::operator!=(DataflowGraphOutput const &other) const { + return std::tie(this->index) != std::tie(other.index); +} +bool DataflowGraphOutput::operator<(DataflowGraphOutput const &other) const { + return std::tie(this->index) < std::tie(other.index); +} +bool DataflowGraphOutput::operator>(DataflowGraphOutput const &other) const { + return std::tie(this->index) > std::tie(other.index); +} +bool DataflowGraphOutput::operator<=(DataflowGraphOutput const &other) const { + return std::tie(this->index) <= std::tie(other.index); +} +bool DataflowGraphOutput::operator>=(DataflowGraphOutput const &other) const { + return std::tie(this->index) >= std::tie(other.index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowGraphOutput const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowGraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowGraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.cc b/lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.cc new file mode 100644 index 0000000000..243118a558 --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_dataflow_graph/dataflow_output_edge.struct.toml +/* proj-data +{ + "generated_from": "2488765bd934738ef9111699bb8b71e3" +} +*/ + +#include "utils/graph/downward_open_dataflow_graph/dataflow_output_edge.dtg.h" + +#include + +namespace FlexFlow { +DataflowOutputEdge::DataflowOutputEdge( + ::FlexFlow::DataflowOutput const &src, + ::FlexFlow::DataflowGraphOutput const &dst) + : src(src), dst(dst) {} +bool DataflowOutputEdge::operator==(DataflowOutputEdge const &other) const { + return std::tie(this->src, this->dst) == std::tie(other.src, other.dst); +} +bool DataflowOutputEdge::operator!=(DataflowOutputEdge const &other) const { + return std::tie(this->src, this->dst) != std::tie(other.src, other.dst); +} +bool DataflowOutputEdge::operator<(DataflowOutputEdge const &other) const { + return std::tie(this->src, this->dst) < std::tie(other.src, other.dst); +} +bool DataflowOutputEdge::operator>(DataflowOutputEdge const &other) const { + return std::tie(this->src, this->dst) > std::tie(other.src, other.dst); +} +bool DataflowOutputEdge::operator<=(DataflowOutputEdge const &other) const { + return std::tie(this->src, this->dst) <= std::tie(other.src, other.dst); +} +bool DataflowOutputEdge::operator>=(DataflowOutputEdge const &other) const { + return std::tie(this->src, this->dst) >= std::tie(other.src, other.dst); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowOutputEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataflowOutput>{}(x.src) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataflowGraphOutput>{}(x.dst) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowOutputEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowOutputEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.cc b/lib/utils/src/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.cc new file mode 100644 index 0000000000..23e1ec2d03 --- /dev/null +++ b/lib/utils/src/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.variant.toml +/* proj-data +{ + "generated_from": "0c40ab695b9c1dca5465aea45190a3fa" +} +*/ + +#include "utils/graph/downward_open_dataflow_graph/downward_open_dataflow_edge.dtg.h" + +#include + +namespace FlexFlow { +DownwardOpenDataflowEdge::DownwardOpenDataflowEdge( + ::FlexFlow::DataflowOutputEdge const &v) + : raw_variant(v) {} +DownwardOpenDataflowEdge::DownwardOpenDataflowEdge( + ::FlexFlow::DataflowEdge const &v) + : raw_variant(v) {} +bool DownwardOpenDataflowEdge::operator==( + DownwardOpenDataflowEdge const &other) const { + return this->raw_variant == other.raw_variant; +} +bool DownwardOpenDataflowEdge::operator!=( + DownwardOpenDataflowEdge const &other) const { + return this->raw_variant != other.raw_variant; +} +bool DownwardOpenDataflowEdge::operator<( + DownwardOpenDataflowEdge const &other) const { + return this->raw_variant < other.raw_variant; +} +bool DownwardOpenDataflowEdge::operator>( + DownwardOpenDataflowEdge const &other) const { + return this->raw_variant > other.raw_variant; +} +bool DownwardOpenDataflowEdge::operator<=( + DownwardOpenDataflowEdge const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool DownwardOpenDataflowEdge::operator>=( + DownwardOpenDataflowEdge const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::DownwardOpenDataflowEdge>::operator()( + ::FlexFlow::DownwardOpenDataflowEdge const &x) const { + return std::hash< + std::variant<::FlexFlow::DataflowOutputEdge, ::FlexFlow::DataflowEdge>>{}( + x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::DownwardOpenDataflowEdge const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DownwardOpenDataflowEdge", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::DownwardOpenDataflowEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/graph_split.dtg.cc b/lib/utils/src/utils/graph/graph_split.dtg.cc new file mode 100644 index 0000000000..8ea5650873 --- /dev/null +++ b/lib/utils/src/utils/graph/graph_split.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/graph_split.struct.toml +/* proj-data +{ + "generated_from": "bf08a68806136ac698f1206f84cb907f" +} +*/ + +#include "utils/graph/graph_split.dtg.h" + +#include + +namespace FlexFlow { +GraphSplit::GraphSplit(std::unordered_set<::FlexFlow::Node> const &first, + std::unordered_set<::FlexFlow::Node> const &second) + : first(first), second(second) {} +bool GraphSplit::operator==(GraphSplit const &other) const { + return std::tie(this->first, this->second) == + std::tie(other.first, other.second); +} +bool GraphSplit::operator!=(GraphSplit const &other) const { + return std::tie(this->first, this->second) != + std::tie(other.first, other.second); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::GraphSplit const &x) const { + size_t result = 0; + result ^= std::hash>{}(x.first) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.second) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(GraphSplit const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GraphSplit const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.cc new file mode 100644 index 0000000000..e826ff7564 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.cc new file mode 100644 index 0000000000..ed74183c61 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/node/algorithms.cc b/lib/utils/src/utils/graph/node/algorithms.cc new file mode 100644 index 0000000000..69fcdfa067 --- /dev/null +++ b/lib/utils/src/utils/graph/node/algorithms.cc @@ -0,0 +1,11 @@ +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_query.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(GraphView const &g) { + return g.query_nodes(node_query_all()); +} + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc new file mode 100644 index 0000000000..38dff4510e --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc @@ -0,0 +1,31 @@ +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(OpenDataflowGraphView const &g) { + return g.query_edges(open_dataflow_edge_query_all()); +} + +std::vector get_inputs(OpenDataflowGraphView const &g) { + return g.get_inputs(); +} + +std::vector get_incoming_edges(OpenDataflowGraphView const &g, Node const &n) { + return sorted_by(g.query_edges(OpenDataflowEdgeQuery{ + DataflowInputEdgeQuery{ + query_set::matchall(), + {n}, + query_set::matchall(), + }, + DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + {n}, + query_set::matchall(), + }, + }), [](OpenDataflowEdge const &l, OpenDataflowEdge const &r) { return get_open_dataflow_edge_dst_idx(l) < get_open_dataflow_edge_dst_idx(r); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.cc new file mode 100644 index 0000000000..65a1efd4e9 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.cc @@ -0,0 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml +/* proj-data +{ + "generated_from": "7d6fe1350bb6f70771a7481a8e36aa2e" +} +*/ + +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" + +#include + +namespace FlexFlow { +DataflowGraphInput::DataflowGraphInput(int const &idx) : idx(idx) {} +bool DataflowGraphInput::operator==(DataflowGraphInput const &other) const { + return std::tie(this->idx) == std::tie(other.idx); +} +bool DataflowGraphInput::operator!=(DataflowGraphInput const &other) const { + return std::tie(this->idx) != std::tie(other.idx); +} +bool DataflowGraphInput::operator<(DataflowGraphInput const &other) const { + return std::tie(this->idx) < std::tie(other.idx); +} +bool DataflowGraphInput::operator>(DataflowGraphInput const &other) const { + return std::tie(this->idx) > std::tie(other.idx); +} +bool DataflowGraphInput::operator<=(DataflowGraphInput const &other) const { + return std::tie(this->idx) <= std::tie(other.idx); +} +bool DataflowGraphInput::operator>=(DataflowGraphInput const &other) const { + return std::tie(this->idx) >= std::tie(other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowGraphInput const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowGraphInput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowGraphInput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.cc new file mode 100644 index 0000000000..e91f7c7f9c --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge.struct.toml +/* proj-data +{ + "generated_from": "adf27ca64d88e17594764cefbcb7934f" +} +*/ + +#include "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h" + +#include + +namespace FlexFlow { +DataflowInputEdge::DataflowInputEdge(::FlexFlow::DataflowGraphInput const &src, + ::FlexFlow::DataflowInput const &dst) + : src(src), dst(dst) {} +bool DataflowInputEdge::operator==(DataflowInputEdge const &other) const { + return std::tie(this->src, this->dst) == std::tie(other.src, other.dst); +} +bool DataflowInputEdge::operator!=(DataflowInputEdge const &other) const { + return std::tie(this->src, this->dst) != std::tie(other.src, other.dst); +} +bool DataflowInputEdge::operator<(DataflowInputEdge const &other) const { + return std::tie(this->src, this->dst) < std::tie(other.src, other.dst); +} +bool DataflowInputEdge::operator>(DataflowInputEdge const &other) const { + return std::tie(this->src, this->dst) > std::tie(other.src, other.dst); +} +bool DataflowInputEdge::operator<=(DataflowInputEdge const &other) const { + return std::tie(this->src, this->dst) <= std::tie(other.src, other.dst); +} +bool DataflowInputEdge::operator>=(DataflowInputEdge const &other) const { + return std::tie(this->src, this->dst) >= std::tie(other.src, other.dst); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowInputEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataflowGraphInput>{}(x.src) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataflowInput>{}(x.dst) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowInputEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowInputEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc new file mode 100644 index 0000000000..c3c6711304 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc @@ -0,0 +1,20 @@ +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" + +namespace FlexFlow { + +DataflowInputEdgeQuery dataflow_input_edge_query_all() { + return DataflowInputEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + query_set::matchall(), + }; +} +DataflowInputEdgeQuery dataflow_input_edge_query_none() { + return DataflowInputEdgeQuery{ + query_set::match_none(), + query_set::match_none(), + query_set::match_none(), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.cc new file mode 100644 index 0000000000..a2a6b9b129 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.struct.toml +/* proj-data +{ + "generated_from": "7bc0c4aa108438c9f24536f8b669b532" +} +*/ + +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h" + +#include + +namespace FlexFlow { +DataflowInputEdgeQuery::DataflowInputEdgeQuery( + ::FlexFlow::query_set<::FlexFlow::DataflowGraphInput> const &srcs, + ::FlexFlow::query_set<::FlexFlow::Node> const &dst_nodes, + ::FlexFlow::query_set const &dst_idxs) + : srcs(srcs), dst_nodes(dst_nodes), dst_idxs(dst_idxs) {} +bool DataflowInputEdgeQuery::operator==( + DataflowInputEdgeQuery const &other) const { + return std::tie(this->srcs, this->dst_nodes, this->dst_idxs) == + std::tie(other.srcs, other.dst_nodes, other.dst_idxs); +} +bool DataflowInputEdgeQuery::operator!=( + DataflowInputEdgeQuery const &other) const { + return std::tie(this->srcs, this->dst_nodes, this->dst_idxs) != + std::tie(other.srcs, other.dst_nodes, other.dst_idxs); +} +bool DataflowInputEdgeQuery::operator<( + DataflowInputEdgeQuery const &other) const { + return std::tie(this->srcs, this->dst_nodes, this->dst_idxs) < + std::tie(other.srcs, other.dst_nodes, other.dst_idxs); +} +bool DataflowInputEdgeQuery::operator>( + DataflowInputEdgeQuery const &other) const { + return std::tie(this->srcs, this->dst_nodes, this->dst_idxs) > + std::tie(other.srcs, other.dst_nodes, other.dst_idxs); +} +bool DataflowInputEdgeQuery::operator<=( + DataflowInputEdgeQuery const &other) const { + return std::tie(this->srcs, this->dst_nodes, this->dst_idxs) <= + std::tie(other.srcs, other.dst_nodes, other.dst_idxs); +} +bool DataflowInputEdgeQuery::operator>=( + DataflowInputEdgeQuery const &other) const { + return std::tie(this->srcs, this->dst_nodes, this->dst_idxs) >= + std::tie(other.srcs, other.dst_nodes, other.dst_idxs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::DataflowInputEdgeQuery const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::DataflowGraphInput>>{}( + x.srcs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set<::FlexFlow::Node>>{}(x.dst_nodes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::query_set>{}(x.dst_idxs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(DataflowInputEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DataflowInputEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph.cc new file mode 100644 index 0000000000..4a50f0bf0f --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph.cc @@ -0,0 +1 @@ +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph.h" diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc b/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc new file mode 100644 index 0000000000..0fba80d612 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc @@ -0,0 +1,19 @@ +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" + +namespace FlexFlow { + +std::unordered_set IOpenDataflowGraphView::query_edges(DataflowEdgeQuery const &q) const { + OpenDataflowEdgeQuery open_query = OpenDataflowEdgeQuery{ + dataflow_input_edge_query_none(), + q, + }; + + std::unordered_set open_edges = this->query_edges(open_query); + + return transform(open_edges, [](OpenDataflowEdge const &e) { + return e.get(); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc new file mode 100644 index 0000000000..632c77df2c --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc @@ -0,0 +1,13 @@ +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/overload.h" + +namespace FlexFlow { + +int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &e) { + return e.visit(overload { + [](DataflowEdge const &e) { return e.dst.idx; }, + [](DataflowInputEdge const &e) { return e.dst.idx; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.cc new file mode 100644 index 0000000000..0e1e8425e6 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.cc @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.variant.toml +/* proj-data +{ + "generated_from": "33e3c8ad4602c3e20c29b6c0dfa104ca" +} +*/ + +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" + +#include + +namespace FlexFlow { +OpenDataflowEdge::OpenDataflowEdge(::FlexFlow::DataflowInputEdge const &v) + : raw_variant(v) {} +OpenDataflowEdge::OpenDataflowEdge(::FlexFlow::DataflowEdge const &v) + : raw_variant(v) {} +bool OpenDataflowEdge::operator==(OpenDataflowEdge const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OpenDataflowEdge::operator!=(OpenDataflowEdge const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OpenDataflowEdge::operator<(OpenDataflowEdge const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OpenDataflowEdge::operator>(OpenDataflowEdge const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OpenDataflowEdge::operator<=(OpenDataflowEdge const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OpenDataflowEdge::operator>=(OpenDataflowEdge const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OpenDataflowEdge>::operator()( + ::FlexFlow::OpenDataflowEdge const &x) const { + return std::hash< + std::variant<::FlexFlow::DataflowInputEdge, ::FlexFlow::DataflowEdge>>{}( + x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OpenDataflowEdge const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type OpenDataflowEdge", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OpenDataflowEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc new file mode 100644 index 0000000000..9d72c8a009 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc @@ -0,0 +1,21 @@ +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" + +namespace FlexFlow { + +OpenDataflowEdgeQuery open_dataflow_edge_query_all() { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all(), + dataflow_edge_query_all(), + }; +} + +OpenDataflowEdgeQuery open_dataflow_edge_query_none() { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_none(), + dataflow_edge_query_none(), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.cc new file mode 100644 index 0000000000..c2f0c1665d --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.cc @@ -0,0 +1,77 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.struct.toml +/* proj-data +{ + "generated_from": "661c106abdb03bf6cc434d87cfafefb5" +} +*/ + +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h" + +#include + +namespace FlexFlow { +OpenDataflowEdgeQuery::OpenDataflowEdgeQuery( + ::FlexFlow::DataflowInputEdgeQuery const &input_edge_query, + ::FlexFlow::DataflowEdgeQuery const &standard_edge_query) + : input_edge_query(input_edge_query), + standard_edge_query(standard_edge_query) {} +bool OpenDataflowEdgeQuery::operator==( + OpenDataflowEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) == + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool OpenDataflowEdgeQuery::operator!=( + OpenDataflowEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) != + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool OpenDataflowEdgeQuery::operator<( + OpenDataflowEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) < + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool OpenDataflowEdgeQuery::operator>( + OpenDataflowEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) > + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool OpenDataflowEdgeQuery::operator<=( + OpenDataflowEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) <= + std::tie(other.input_edge_query, other.standard_edge_query); +} +bool OpenDataflowEdgeQuery::operator>=( + OpenDataflowEdgeQuery const &other) const { + return std::tie(this->input_edge_query, this->standard_edge_query) >= + std::tie(other.input_edge_query, other.standard_edge_query); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::OpenDataflowEdgeQuery const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::DataflowInputEdgeQuery>{}(x.input_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataflowEdgeQuery>{}(x.standard_edge_query) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OpenDataflowEdgeQuery const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OpenDataflowEdgeQuery const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph.cc new file mode 100644 index 0000000000..527af6e091 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph.cc @@ -0,0 +1,22 @@ +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" + +namespace FlexFlow { + +NodeAddedResult OpenDataflowGraph::add_node(std::vector const &inputs, + int num_outputs) { + return this->get_interface().add_node(inputs, num_outputs); +} + +DataflowGraphInput OpenDataflowGraph::add_input() { + return this->get_interface().add_input(); +} + +IOpenDataflowGraph &OpenDataflowGraph::get_interface() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); +} + +IOpenDataflowGraph const &OpenDataflowGraph::get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc new file mode 100644 index 0000000000..8c031f68ec --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc @@ -0,0 +1,18 @@ +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::vector OpenDataflowGraphView::get_inputs() const { + return this->get_interface().get_inputs(); +} + +std::unordered_set OpenDataflowGraphView::query_edges(OpenDataflowEdgeQuery const &q) const { + return this->get_interface().query_edges(q); +} + +IOpenDataflowGraphView const &OpenDataflowGraphView::get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); +} + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.cc new file mode 100644 index 0000000000..9e769d1e59 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.dtg.cc @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.variant.toml +/* proj-data +{ + "generated_from": "a212e5a39ee0d8c9ef39bc4892e15416" +} +*/ + +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +#include + +namespace FlexFlow { +OpenDataflowValue::OpenDataflowValue(::FlexFlow::DataflowOutput const &v) + : raw_variant(v) {} +OpenDataflowValue::OpenDataflowValue(::FlexFlow::DataflowGraphInput const &v) + : raw_variant(v) {} +bool OpenDataflowValue::operator==(OpenDataflowValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OpenDataflowValue::operator!=(OpenDataflowValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OpenDataflowValue::operator<(OpenDataflowValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OpenDataflowValue::operator>(OpenDataflowValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OpenDataflowValue::operator<=(OpenDataflowValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OpenDataflowValue::operator>=(OpenDataflowValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OpenDataflowValue>::operator()( + ::FlexFlow::OpenDataflowValue const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OpenDataflowValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OpenDataflowValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OpenDataflowValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc new file mode 100644 index 0000000000..66d416bdb8 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc @@ -0,0 +1,79 @@ +#include "utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h" + +namespace FlexFlow { + +UnorderedSetOpenDataflowGraph::UnorderedSetOpenDataflowGraph() {} + +UnorderedSetOpenDataflowGraph::UnorderedSetOpenDataflowGraph(NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &standard_edges, + std::unordered_set const &input_edges, + std::unordered_set const &outputs, + std::vector const &graph_inputs) + : node_source(node_source), + nodes(nodes), + standard_edges(standard_edges), + input_edges(input_edges), + outputs(outputs), + graph_inputs(graph_inputs) +{ } + + +NodeAddedResult UnorderedSetOpenDataflowGraph::add_node(std::vector const &inputs, + int num_outputs) { + NOT_IMPLEMENTED(); +} + +std::unordered_set UnorderedSetOpenDataflowGraph::query_nodes(NodeQuery const &q) const { + return apply_query(q.nodes, this->nodes); +} + +std::unordered_set UnorderedSetOpenDataflowGraph::query_edges(OpenDataflowEdgeQuery const &q) const { + std::unordered_set standard_edges = filter(this->standard_edges, [&](DataflowEdge const &e) { + return includes(q.standard_edge_query.src_nodes, e.src.node) + && includes(q.standard_edge_query.dst_nodes, e.dst.node) + && includes(q.standard_edge_query.src_idxs, e.src.idx) + && includes(q.standard_edge_query.dst_idxs, e.dst.idx); + }); + std::unordered_set input_edges = filter(this->input_edges, [&](DataflowInputEdge const &e) { + return includes(q.input_edge_query.srcs, e.src) + && includes(q.input_edge_query.dst_nodes, e.dst.node) + && includes(q.input_edge_query.dst_idxs, e.dst.idx); + }); + return set_union( + transform(standard_edges, [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }), + transform(input_edges, [](DataflowInputEdge const &e) { return OpenDataflowEdge{e}; }) + ); + +} + +std::unordered_set UnorderedSetOpenDataflowGraph::query_outputs(DataflowOutputQuery const &q) const { + return filter(this->outputs, [&](DataflowOutput const &o) { + return includes(q.nodes, o.node) + && includes(q.output_idxs, o.idx); + }); +} + +std::vector UnorderedSetOpenDataflowGraph::get_inputs() const { + return this->graph_inputs; +} + +DataflowGraphInput UnorderedSetOpenDataflowGraph::add_input() { + int idx = this->graph_inputs.size(); + DataflowGraphInput result = DataflowGraphInput{idx}; + this->graph_inputs.push_back(result); + return result; +} + +UnorderedSetOpenDataflowGraph *UnorderedSetOpenDataflowGraph::clone() const { + return new UnorderedSetOpenDataflowGraph{ + this->node_source, + this->nodes, + this->standard_edges, + this->input_edges, + this->outputs, + this->graph_inputs, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/query_set.cc b/lib/utils/src/utils/graph/query_set.cc new file mode 100644 index 0000000000..6300fac42b --- /dev/null +++ b/lib/utils/src/utils/graph/query_set.cc @@ -0,0 +1 @@ +#include "utils/graph/query_set.h" diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc index 233eb028e2..76aa3b2c00 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc @@ -41,45 +41,14 @@ std::unordered_set from_source_to_sink(DiGraphView const &g, return result; } -struct FlattenAST { - void add_flattened_child_to_parent(IntermediateSpDecompositionTree &parent, - SplitAST const &child) { - if (std::holds_alternative(child)) { - parent.children.push_back(child); - return; - } - - IntermediateSpDecompositionTree child_node = get(child); - - if (parent.type == child_node.type) { - extend(parent.children, child_node.children); - } else { - parent.children.push_back(child); - } - } - - SplitAST operator()(IntermediateSpDecompositionTree const &ast_node) { - IntermediateSpDecompositionTree result(ast_node.type); - for (SplitAST const &child : ast_node.children) { - SplitAST flattened_child = flatten_ast(child); - add_flattened_child_to_parent(result, flattened_child); - } - return result; - } - - SplitAST operator()(Node const &ast_node) { - return ast_node; - } -}; - SerialParallelDecomposition get_serial_parallel_decomposition(DiGraphView const &g) { - SplitAST ast = sp_decomposition(g); + std::variant ast = sp_decomposition(g); return to_final_ast(ast); } std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { - return sp.visit([](auto &&t) { return get_nodes(t); }); + return sp.visit>([](auto &&t) { return get_nodes(t); }); } std::unordered_set get_nodes(Serial const &serial) { @@ -93,7 +62,7 @@ std::unordered_set get_nodes(Serial const &serial) { std::unordered_set get_nodes(Parallel const ¶llel) { return set_union( transform(parallel.children, [](std::variant const &child) { - return visit(GetNodes{}, child); + return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc index 2fd1e81aa0..9fa4492bb6 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc @@ -2,7 +2,7 @@ #include "utils/graph/algorithms.h" #include "utils/graph/serial_parallel/sink_settings.dtg.h" #include "utils/graph/serial_parallel/source_settings.dtg.h" -#include "utils/graph/serial_parallel/split_ast_node.dtg.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -48,7 +48,8 @@ DiGraphView source_to_sink_subgraph(DiGraphView const &g, g, from_source_to_sink(g, srcs, sinks, include_src, include_sink)); } -IntermediateSpDecompositionTree sp_decomposition(DiGraphView const &g) { +std::variant + sp_decomposition(DiGraphView const &g) { if (num_nodes(g) == 1) { return get_only(get_nodes(g)); } @@ -89,18 +90,50 @@ IntermediateSpDecompositionTree parallel_decomposition(DiGraphView const &g) { return split; } -SplitAST flatten_ast(SplitAST const &ast) { - return visit(FlattenAST{}, ast); +struct FlattenAST { + void add_flattened_child_to_parent(IntermediateSpDecompositionTree &parent, + std::variant const &child) { + if (std::holds_alternative(child)) { + parent.children.push_back(child); + return; + } + + IntermediateSpDecompositionTree child_node = get(child); + + if (parent.type == child_node.type) { + extend(parent.children, child_node.children); + } else { + parent.children.push_back(child); + } + } + + std::variant operator()(IntermediateSpDecompositionTree const &ast_node) { + IntermediateSpDecompositionTree result(ast_node.type, {}); + for (std::variant const &child : ast_node.children) { + std::variant flattened_child = flatten_ast(child); + add_flattened_child_to_parent(result, flattened_child); + } + return result; + } + + std::variant operator()(Node const &ast_node) { + return ast_node; + } +}; + + +std::variant flatten_ast(std::variant const &ast) { + return std::visit(FlattenAST{}, ast); } struct ToFinalAST { std::variant operator()(IntermediateSpDecompositionTree const &node) { if (node.type == SplitType::SERIAL) { - return Serial{transform(node.children, [](SplitAST const &s) { + return Serial{transform(node.children, [](std::variant const &s) { return narrow>(internal_to_final_ast(s)).value(); })}; } else { - return Parallel{transform(node.children, [](SplitAST const &s) { + return Parallel{transform(node.children, [](std::variant const &s) { return narrow>(internal_to_final_ast(s)).value(); })}; } @@ -111,11 +144,11 @@ struct ToFinalAST { } }; -std::variant internal_to_final_ast(SplitAST const &ast) { +std::variant internal_to_final_ast(std::variant const &ast) { return visit(ToFinalAST{}, ast); } -SerialParallelDecomposition to_final_ast(SplitAST const &ast) { +SerialParallelDecomposition to_final_ast(std::variant const &ast) { return std::visit([](auto &&x) { return SerialParallelDecomposition{x}; }, internal_to_final_ast(ast)); } diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h index 927f50cff1..6b7671d5ef 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h @@ -12,17 +12,15 @@ namespace FlexFlow { struct ParallelInternal; -using SplitAST = std::variant; - -IntermediateSpDecompositionTree sp_decomposition(DiGraphView const &g); +std::variant sp_decomposition(DiGraphView const &g); IntermediateSpDecompositionTree parallel_decomposition(DiGraphView const &g); std::unordered_set from_source_to_sink(DiGraphView const &, Node const &src, Node const &sink); -std::variant internal_to_final_ast(SplitAST const &); -SerialParallelDecomposition to_final_ast(SplitAST const &); -SplitAST flatten_ast(SplitAST const &ast); +std::variant internal_to_final_ast(std::variant const &); +SerialParallelDecomposition to_final_ast(std::variant const &); +std::variant flatten_ast(std::variant const &ast); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/traversal.cc b/lib/utils/src/utils/graph/traversal.cc index 38758e249e..aed0eb81ef 100644 --- a/lib/utils/src/utils/graph/traversal.cc +++ b/lib/utils/src/utils/graph/traversal.cc @@ -1,6 +1,7 @@ #include "utils/graph/traversal.h" #include "utils/containers.h" #include "utils/graph/algorithms.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 7af1dc8fcc..61f20a7d4d 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,5 +1,5 @@ #include "utils/graph/undirected/undirected_edge.h" -#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 518e9784e9..2226aaad1e 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -6,6 +6,7 @@ #include "utils/graph/undirected/undirected_edge_query.h" #include "utils/graph/node/node_query.h" #include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/hash/pair.cc b/lib/utils/src/utils/hash/pair.cc new file mode 100644 index 0000000000..38740f6af5 --- /dev/null +++ b/lib/utils/src/utils/hash/pair.cc @@ -0,0 +1 @@ +#include "utils/hash/pair.h" diff --git a/lib/utils/src/utils/hash/set.cc b/lib/utils/src/utils/hash/set.cc new file mode 100644 index 0000000000..ac109492a0 --- /dev/null +++ b/lib/utils/src/utils/hash/set.cc @@ -0,0 +1 @@ +#include "utils/hash/set.h" diff --git a/lib/utils/src/utils/hash/tuple.cc b/lib/utils/src/utils/hash/tuple.cc new file mode 100644 index 0000000000..922f98700e --- /dev/null +++ b/lib/utils/src/utils/hash/tuple.cc @@ -0,0 +1 @@ +#include "utils/hash/tuple.h" diff --git a/lib/utils/src/utils/hash/unordered_map.cc b/lib/utils/src/utils/hash/unordered_map.cc new file mode 100644 index 0000000000..52c4140641 --- /dev/null +++ b/lib/utils/src/utils/hash/unordered_map.cc @@ -0,0 +1 @@ +#include "utils/hash/unordered_map.h" diff --git a/lib/utils/src/utils/hash/unordered_set.cc b/lib/utils/src/utils/hash/unordered_set.cc new file mode 100644 index 0000000000..907555d908 --- /dev/null +++ b/lib/utils/src/utils/hash/unordered_set.cc @@ -0,0 +1 @@ +#include "utils/hash/unordered_set.h" diff --git a/lib/utils/src/utils/hash/vector.cc b/lib/utils/src/utils/hash/vector.cc new file mode 100644 index 0000000000..dc1e7edfb3 --- /dev/null +++ b/lib/utils/src/utils/hash/vector.cc @@ -0,0 +1 @@ +#include "utils/hash/vector.h" diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index 40ff07285e..d139e328a9 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -3,6 +3,7 @@ ff_add_test_executable( utils-tests SRC_PATTERNS src/test_cow_ptr.cc + src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc new file mode 100644 index 0000000000..a82e3afbb5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc @@ -0,0 +1,73 @@ +#include "test/utils/doctest.h" +#include "utils/graph/dataflow_graph/unordered_set_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("UnorderedSetDataflowGraph") { + DataflowGraph g = DataflowGraph::create(); + + { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {}; + REQUIRE(result == correct); + } + + { + std::unordered_set result = g.query_edges(dataflow_edge_query_all()); + std::unordered_set correct = {}; + REQUIRE(result == correct); + } + + { + std::unordered_set result = g.query_outputs(dataflow_output_query_all()); + std::unordered_set correct = {}; + REQUIRE(result == correct); + } + + NodeAddedResult added = g.add_node({}, 2); + + { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {added.node}; + REQUIRE(result == correct); + } + + { + std::unordered_set result = g.query_edges(dataflow_edge_query_all()); + std::unordered_set correct = {}; + REQUIRE(result == correct); + } + + { + std::unordered_set result = g.query_outputs(dataflow_output_query_all()); + std::unordered_set correct = without_order(added.outputs); + REQUIRE(result == correct); + } + + NodeAddedResult added2 = g.add_node(added.outputs, 3); + + { + std::unordered_set result = g.query_nodes(node_query_all()); + std::unordered_set correct = {added.node, added2.node}; + REQUIRE(result == correct); + } + + { + std::unordered_set result = g.query_edges(dataflow_edge_query_all()); + std::unordered_set correct = { + DataflowEdge{added.outputs.at(0), DataflowInput{added2.node, 0}}, + DataflowEdge{added.outputs.at(1), DataflowInput{added2.node, 1}}, + }; + REQUIRE(result == correct); + } + + { + std::unordered_set result = g.query_outputs(dataflow_output_query_all()); + std::unordered_set correct = set_union(without_order(added.outputs), without_order(added2.outputs)); + REQUIRE(result == correct); + } + } +} From 64a3403a83d5512e546379b02d3dfa2505e126e0 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 23 Jun 2024 22:08:00 -0700 Subject: [PATCH 11/71] Start refactoring substitutions --- .proj.toml | 8 +- lib/pcg/include/pcg/file_format/v1/graphs.h | 6 +- .../v1/graphs/v1_jsonable_graph.dtg.h | 100 -------------- .../graphs/v1_labelled_dataflow_graph.dtg.h | 81 ++++++----- .../v1/graphs/v1_labelled_dataflow_graph.h | 24 ++-- .../v1_labelled_dataflow_graph.struct.toml | 8 +- .../v1/graphs/v1_operator_graph.dtg.h | 47 ------- .../v1/graphs/v1_operator_graph.struct.toml | 26 ---- lib/pcg/include/pcg/layer_guid_t.dtg.h | 4 +- lib/pcg/include/pcg/layer_guid_t.struct.toml | 2 +- .../parallel_layer_guid_t.dtg.h | 4 +- .../parallel_layer_guid_t.struct.toml | 2 +- .../parallel_tensor_guid_t.dtg.h | 8 +- .../parallel_tensor_guid_t.struct.toml | 4 +- lib/pcg/src/file_format/v1/graphs.cc | 130 +----------------- lib/pcg/src/pcg/computation_graph.cc | 8 +- lib/pcg/src/pcg/computation_graph_builder.cc | 14 +- .../v1/graphs/v1_dataflow_graph.cc | 11 +- .../v1/graphs/v1_jsonable_graph.dtg.cc | 10 -- .../v1/graphs/v1_labelled_dataflow_graph.cc | 1 + .../graphs/v1_labelled_dataflow_graph.dtg.cc | 2 +- .../v1/graphs/v1_operator_graph.dtg.cc | 49 ------- lib/pcg/src/pcg/layer_guid_t.dtg.cc | 2 +- .../parallel_computation_graph.cc | 23 ++-- .../parallel_computation_graph_builder.cc | 15 +- .../parallel_layer_guid_t.dtg.cc | 2 +- .../parallel_tensor_guid_t.dtg.cc | 6 +- .../test/src/pcg/dataflow_graph/algorithms.cc | 76 ---------- .../parallel_computation_graph.cc | 30 ++-- .../parallel_computation_graph_builder.cc | 1 + .../src/test_computation_graph_builder.cc | 2 + .../operator_attribute_expr.h | 2 +- .../operator_attribute_pattern.dtg.h | 5 +- .../operator_attribute_pattern.struct.toml | 3 +- .../operator_attribute_value.dtg.h | 3 +- .../operator_attribute_value.variant.toml | 1 + .../output_graph/output_graph_expr.dtg.h | 15 +- .../output_graph_expr.struct.toml | 4 +- .../include/substitutions/pcg_pattern.dtg.h | 11 +- .../{graph_pattern.h => pcg_pattern.h} | 10 +- .../substitutions/pcg_pattern.struct.toml | 4 +- .../sub_parallel_computation_graph.dtg.h | 14 +- .../sub_parallel_computation_graph.h | 2 +- ...sub_parallel_computation_graph.struct.toml | 8 +- .../include/substitutions/substitution.dtg.h | 16 +-- .../substitutions/substitution.struct.toml | 4 +- .../tensor_pattern/satisfies_pattern.h | 2 +- .../unlabelled/edge_splits.dtg.h | 13 +- .../unlabelled/edge_splits.struct.toml | 7 +- .../match_additional_criterion.dtg.h | 19 +-- .../match_additional_criterion.struct.toml | 9 +- .../unlabelled/match_split.dtg.h | 28 +++- .../substitutions/unlabelled/match_split.h | 4 +- .../unlabelled/match_split.struct.toml | 8 +- .../multidigraph_pattern_match.dtg.h | 36 ----- .../multidigraph_pattern_match.struct.toml | 24 ---- .../unlabelled/pattern_edge.dtg.h | 39 ------ .../unlabelled/pattern_edge.struct.toml | 15 -- .../unlabelled/pattern_matching.h | 11 +- .../unlabelled/pattern_value.dtg.h | 47 +++++++ .../unlabelled/pattern_value.struct.toml | 16 +++ .../unlabelled/pattern_value_use.struct.toml | 16 +++ ...abelled_dataflow_graph_pattern_match.dtg.h | 55 ++++++++ ...d_dataflow_graph_pattern_match.struct.toml | 24 ++++ .../unlabelled/unlabelled_graph_pattern.dtg.h | 8 +- .../unlabelled/unlabelled_graph_pattern.h | 15 +- .../unlabelled_graph_pattern.struct.toml | 4 +- .../operator_pattern/get_attribute.cc | 10 ++ .../operator_attribute_pattern.dtg.cc | 2 +- .../operator_attribute_value.dtg.cc | 2 +- .../output_graph/output_graph_expr.dtg.cc | 8 +- .../{graph_pattern.cc => pcg_pattern.cc} | 2 +- .../src/substitutions/pcg_pattern.dtg.cc | 4 +- .../sub_parallel_computation_graph.cc | 4 +- .../sub_parallel_computation_graph.dtg.cc | 8 +- .../src/substitutions/substitution.dtg.cc | 10 +- .../unlabelled/edge_splits.dtg.cc | 12 +- .../match_additional_criterion.dtg.cc | 10 +- .../unlabelled/match_split.dtg.cc | 37 ++++- .../multidigraph_pattern_match.dtg.cc | 29 ---- .../unlabelled/pattern_edge.dtg.cc | 43 ------ .../unlabelled/pattern_matching.cc | 15 +- .../unlabelled/pattern_value.dtg.cc | 65 +++++++++ ...belled_dataflow_graph_pattern_match.dtg.cc | 63 +++++++++ .../unlabelled/unlabelled_graph_pattern.cc | 49 +++---- .../unlabelled_graph_pattern.dtg.cc | 4 +- lib/utils/include/utils/bidict.h | 5 + lib/utils/include/utils/containers/group_by.h | 21 +++ .../include/utils/containers/set_minus.h | 19 +++ .../utils/containers/without_nullopts.h | 12 ++ lib/utils/include/utils/graph.h | 14 +- lib/utils/include/utils/graph/algorithms.h | 19 --- .../utils/graph/dataflow_graph/algorithms.h | 1 + .../dataflow_graph/dataflow_edge_query.h | 2 + .../dataflow_graph/dataflow_output_query.h | 2 + .../include/utils/graph/digraph/algorithms.h | 21 ++- .../adjacency_digraph.h | 4 +- .../unordered_set_dataflow_graph.h | 0 ...ordered_set_labelled_open_dataflow_graph.h | 114 +++++++++++++++ .../labelled_dataflow_graph.h | 12 ++ .../labelled_dataflow_graph_view.h | 3 + .../i_labelled_open_dataflow_graph.h | 29 ++++ .../i_labelled_open_dataflow_graph_view.h | 27 ++++ .../labelled_open_dataflow_graph.h | 37 +++++ .../labelled_open_dataflow_graph_view.h | 39 ++++++ .../include/utils/graph/node/algorithms.h | 2 + .../graph/open_dataflow_graph/algorithms.h | 5 +- .../dataflow_input_edge_query.h | 2 + .../open_dataflow_graph/open_dataflow_edge.h | 3 + .../open_dataflow_edge_query.h | 2 + lib/utils/include/utils/graph/views/views.h | 2 +- lib/utils/src/utils/containers/group_by.cc | 1 + lib/utils/src/utils/containers/set_minus.cc | 1 + .../src/utils/containers/without_nullopts.cc | 8 ++ lib/utils/src/utils/graph/algorithms.cc | 130 +----------------- .../utils/graph/dataflow_graph/algorithms.cc | 4 + .../dataflow_graph/dataflow_edge_query.cc | 7 + .../dataflow_graph/dataflow_output_query.cc | 5 + .../src/utils/graph/digraph/algorithms.cc | 130 ++++++++++++++++++ .../adjacency_digraph.cc | 2 +- .../hashmap_undirected_graph.cc | 2 +- .../instances}/hashmap_undirected_graph.h | 0 .../unordered_set_dataflow_graph.cc | 2 +- ...rdered_set_labelled_open_dataflow_graph.cc | 7 + .../i_labelled_open_dataflow_graph.cc | 1 + .../i_labelled_open_dataflow_graph_view.cc | 1 + .../labelled_open_dataflow_graph.cc | 1 + .../labelled_open_dataflow_graph_view.cc | 1 + lib/utils/src/utils/graph/node/algorithms.cc | 7 + .../graph/open_dataflow_graph/algorithms.cc | 15 ++ .../dataflow_input_edge_query.cc | 6 + .../open_dataflow_graph/open_dataflow_edge.cc | 14 ++ .../open_dataflow_edge_query.cc | 8 ++ .../serialparallel_internal.cc | 1 + lib/utils/src/utils/graph/traversal.cc | 3 +- lib/utils/test/src/test_algorithms.cc | 3 +- .../utils/graph/dataflow_graph/algorithms.cc | 59 ++++++++ ...ph.cc => unordered_open_dataflow_graph.cc} | 2 +- 138 files changed, 1311 insertions(+), 1092 deletions(-) delete mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml delete mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc delete mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc delete mode 100644 lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc rename lib/substitutions/include/substitutions/{graph_pattern.h => pcg_pattern.h} (67%) delete mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml delete mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml rename lib/substitutions/src/substitutions/{graph_pattern.cc => pcg_pattern.cc} (97%) delete mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc delete mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_value.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc create mode 100644 lib/utils/include/utils/containers/group_by.h create mode 100644 lib/utils/include/utils/containers/set_minus.h rename lib/utils/include/utils/graph/{digraph => instances}/adjacency_digraph.h (88%) rename lib/utils/include/utils/graph/{dataflow_graph => instances}/unordered_set_dataflow_graph.h (100%) create mode 100644 lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h create mode 100644 lib/utils/src/utils/containers/group_by.cc create mode 100644 lib/utils/src/utils/containers/set_minus.cc create mode 100644 lib/utils/src/utils/containers/without_nullopts.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms.cc rename lib/utils/src/utils/graph/{digraph => instances}/adjacency_digraph.cc (97%) rename lib/utils/src/utils/graph/{undirected => instances}/hashmap_undirected_graph.cc (97%) rename lib/utils/{include/utils/graph/undirected => src/utils/graph/instances}/hashmap_undirected_graph.h (100%) rename lib/utils/src/utils/graph/{dataflow_graph => instances}/unordered_set_dataflow_graph.cc (97%) create mode 100644 lib/utils/src/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc rename lib/utils/test/src/utils/graph/dataflow_graph/{unordered_set_dataflow_graph.cc => unordered_open_dataflow_graph.cc} (97%) diff --git a/.proj.toml b/.proj.toml index a31561632d..f6e3cd2308 100644 --- a/.proj.toml +++ b/.proj.toml @@ -6,12 +6,12 @@ header_extension = ".h" build_targets = [ "utils", "op-attrs", - "kernels", + # "kernels", "pcg", - # "substitutions", + "substitutions", # "compiler", - "substitution-generator", - "local-execution", + # "substitution-generator", + # "local-execution", ] test_targets = [ "utils-tests", diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index 6090d60e1a..702c79c2b6 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H #include "pcg/computation_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" #include "pcg/layer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" @@ -12,12 +12,12 @@ namespace FlexFlow { -using V1ComputationGraph = V1JsonableGraph; +using V1ComputationGraph = V1LabelledDataflowGraph; CHECK_IS_JSONABLE(V1ComputationGraph); V1ComputationGraph to_v1(ComputationGraph const &); using V1ParallelComputationGraph = - V1JsonableGraph; + V1LabelledDataflowGraph; CHECK_IS_JSONABLE(V1ParallelComputationGraph); V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h deleted file mode 100644 index 839741e86f..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h +++ /dev/null @@ -1,100 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml -/* proj-data -{ - "generated_from": "ac98d063410ebe1c14f58ea8e17c272e" -} -*/ - -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H - -#include "fmt/format.h" -#include "nlohmann/json.hpp" -#include "pcg/file_format/v1/graphs/v1_dataflow_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" -#include -#include -#include - -namespace FlexFlow { -template -struct V1JsonableGraph { - V1JsonableGraph() = delete; - explicit V1JsonableGraph( - std::unordered_map const &node_labels, - std::unordered_map const &output_labels, - ::FlexFlow::V1DataflowGraph const &graph); - - std::unordered_map node_labels; - std::unordered_map output_labels; - ::FlexFlow::V1DataflowGraph graph; -}; -} // namespace FlexFlow - -namespace nlohmann { -template -struct adl_serializer<::FlexFlow::V1JsonableGraph> { - static ::FlexFlow::V1JsonableGraph from_json(json const &); - static void to_json(json &, - ::FlexFlow::V1JsonableGraph const &); -}; -} // namespace nlohmann - -namespace FlexFlow { -template -std::string format_as(V1JsonableGraph const &); -template -std::ostream &operator<<(std::ostream &, - V1JsonableGraph const &); -} // namespace FlexFlow - -namespace FlexFlow { -template -V1JsonableGraph::V1JsonableGraph( - std::unordered_map const &node_labels, - std::unordered_map const &output_labels, - ::FlexFlow::V1DataflowGraph const &graph) - : node_labels(node_labels), output_labels(output_labels), graph(graph) {} -} // namespace FlexFlow - -namespace nlohmann { -template -::FlexFlow::V1JsonableGraph - adl_serializer<::FlexFlow::V1JsonableGraph>::from_json( - json const &j) { - return ::FlexFlow::V1JsonableGraph{ - j.at("node_labels").template get>(), - j.at("output_labels").template get>(), - j.at("graph").template get<::FlexFlow::V1DataflowGraph>()}; -} -template -void adl_serializer<::FlexFlow::V1JsonableGraph>::to_json( - json &j, ::FlexFlow::V1JsonableGraph const &v) { - j["__type"] = "V1JsonableGraph"; - j["node_labels"] = v.node_labels; - j["output_labels"] = v.output_labels; - j["graph"] = v.graph; -} -} // namespace nlohmann - -namespace FlexFlow { -template -std::string format_as(V1JsonableGraph const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -template -std::ostream &operator<<(std::ostream &s, - V1JsonableGraph const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow - -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h index 4f24e1f3b4..70a8c3396a 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml /* proj-data { - "generated_from": "5b6ac94ce5ca0fe62b2309c7a87b583a" + "generated_from": "89120d1975d2e727327594b4ab8a4952" } */ @@ -19,62 +19,67 @@ #include namespace FlexFlow { -template +template struct V1LabelledDataflowGraph { V1LabelledDataflowGraph() = delete; explicit V1LabelledDataflowGraph( - std::unordered_map const &node_labels, - std::unordered_map const &output_labels, + std::unordered_map const &node_labels, + std::unordered_map> const &output_labels, ::FlexFlow::V1DataflowGraph const &graph); - std::unordered_map node_labels; - std::unordered_map output_labels; + std::unordered_map node_labels; + std::unordered_map> output_labels; ::FlexFlow::V1DataflowGraph graph; }; } // namespace FlexFlow namespace nlohmann { -template -struct adl_serializer<::FlexFlow::V1LabelledDataflowGraph> { - static ::FlexFlow::V1LabelledDataflowGraph +template +struct adl_serializer< + ::FlexFlow::V1LabelledDataflowGraph> { + static ::FlexFlow::V1LabelledDataflowGraph from_json(json const &); - static void - to_json(json &, - ::FlexFlow::V1LabelledDataflowGraph const &); + static void to_json( + json &, + ::FlexFlow::V1LabelledDataflowGraph const &); }; } // namespace nlohmann namespace FlexFlow { -template -std::string format_as(V1LabelledDataflowGraph const &); -template -std::ostream &operator<<(std::ostream &, - V1LabelledDataflowGraph const &); +template +std::string format_as(V1LabelledDataflowGraph const &); +template +std::ostream & + operator<<(std::ostream &, + V1LabelledDataflowGraph const &); } // namespace FlexFlow namespace FlexFlow { -template -V1LabelledDataflowGraph::V1LabelledDataflowGraph( - std::unordered_map const &node_labels, - std::unordered_map const &output_labels, +template +V1LabelledDataflowGraph::V1LabelledDataflowGraph( + std::unordered_map const &node_labels, + std::unordered_map> const &output_labels, ::FlexFlow::V1DataflowGraph const &graph) : node_labels(node_labels), output_labels(output_labels), graph(graph) {} } // namespace FlexFlow namespace nlohmann { -template -::FlexFlow::V1LabelledDataflowGraph - adl_serializer<::FlexFlow::V1LabelledDataflowGraph>:: - from_json(json const &j) { - return ::FlexFlow::V1LabelledDataflowGraph{ - j.at("node_labels").template get>(), - j.at("output_labels").template get>(), +template +::FlexFlow::V1LabelledDataflowGraph adl_serializer< + ::FlexFlow::V1LabelledDataflowGraph>:: + from_json(json const &j) { + return ::FlexFlow::V1LabelledDataflowGraph{ + j.at("node_labels").template get>(), + j.at("output_labels") + .template get>>(), j.at("graph").template get<::FlexFlow::V1DataflowGraph>()}; } -template -void adl_serializer<::FlexFlow::V1LabelledDataflowGraph>:: - to_json(json &j, - ::FlexFlow::V1LabelledDataflowGraph const &v) { +template +void adl_serializer< + ::FlexFlow::V1LabelledDataflowGraph>:: + to_json( + json &j, + ::FlexFlow::V1LabelledDataflowGraph const &v) { j["__type"] = "V1LabelledDataflowGraph"; j["node_labels"] = v.node_labels; j["output_labels"] = v.output_labels; @@ -83,8 +88,9 @@ void adl_serializer<::FlexFlow::V1LabelledDataflowGraph>:: } // namespace nlohmann namespace FlexFlow { -template -std::string format_as(V1LabelledDataflowGraph const &x) { +template +std::string + format_as(V1LabelledDataflowGraph const &x) { std::ostringstream oss; oss << ""; return oss.str(); } -template -std::ostream &operator<<(std::ostream &s, - V1LabelledDataflowGraph const &x) { +template +std::ostream & + operator<<(std::ostream &s, + V1LabelledDataflowGraph const &x) { return s << fmt::to_string(x); } } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index 823989a89a..20e730271f 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -3,6 +3,9 @@ #include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" #include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" +#include "utils/graph/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -13,19 +16,20 @@ V1LabelledDataflowGraph bidict nodes = enumerate(get_nodes(g)); V1DataflowGraph unlabelled = to_v1(g, nodes.reversed()); - std::unordered_map node_labels = - map_values(nodes, [&](Node const &n) { return g.at(n); }); - std::unordered_map outputs = - map_values(nodes, [&](MultiDiOutput const &o) { - return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; + std::unordered_map node_labels = + map_values(nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); }); + + std::unordered_map> output_labels = + map_values(nodes.as_unordered_map(), [&](Node const &n) { + return transform(get_outputs(g, n), + [&](DataflowOutput const &o) { + return g.at(o); + }); }); - std::unordered_map output_labels = map_values( - outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); - - return V1JsonableGraph{ - node_labels, outputs, output_labels, unlabelled}; + return V1LabelledDataflowGraph{ + node_labels, output_labels, unlabelled}; } } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index 47263a80bc..0a6a148159 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -10,8 +10,8 @@ features = [ ] template_params = [ - "NodeT", - "TensorT", + "NodeLabel", + "OutputLabel", ] includes = [ @@ -22,11 +22,11 @@ includes = [ [[fields]] name = "node_labels" -type = "std::unordered_map" +type = "std::unordered_map" [[fields]] name = "output_labels" -type = "std::unordered_map>" +type = "std::unordered_map>" [[fields]] name = "graph" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h deleted file mode 100644 index f1e9cb5a5c..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h +++ /dev/null @@ -1,47 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml -/* proj-data -{ - "generated_from": "fed215ca219af1bd375801eb2e33b473" -} -*/ - -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H - -#include "fmt/format.h" -#include "nlohmann/json.hpp" -#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" -#include -#include -#include - -namespace FlexFlow { -struct V1OperatorGraph { - V1OperatorGraph() = delete; - explicit V1OperatorGraph( - std::vector const &nodes, - std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); - - std::vector nodes; - std::unordered_set<::FlexFlow::V1GraphEdge> edges; -}; -} // namespace FlexFlow - -namespace nlohmann { -template <> -struct adl_serializer<::FlexFlow::V1OperatorGraph> { - static ::FlexFlow::V1OperatorGraph from_json(json const &); - static void to_json(json &, ::FlexFlow::V1OperatorGraph const &); -}; -} // namespace nlohmann - -namespace FlexFlow { -std::string format_as(V1OperatorGraph const &); -std::ostream &operator<<(std::ostream &, V1OperatorGraph const &); -} // namespace FlexFlow - -#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml deleted file mode 100644 index 2715ae176b..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "V1OperatorGraph" -features = [ - # "eq", - # "ord", - # "hash", - "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "", - "", - "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", - "utils/fmt/unordered_set.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "nodes" -type = "std::vector" - -[[fields]] -name = "edges" -type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.h b/lib/pcg/include/pcg/layer_guid_t.dtg.h index 9b0e3338d9..b4e2012899 100644 --- a/lib/pcg/include/pcg/layer_guid_t.dtg.h +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/layer_guid_t.struct.toml /* proj-data { - "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" + "generated_from": "7876f785878716e3f2af2b4a5c1cab28" } */ @@ -11,7 +11,7 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H #include "fmt/format.h" -#include "utils/graph.h" +#include "utils/graph/node/node.dtg.h" #include #include #include diff --git a/lib/pcg/include/pcg/layer_guid_t.struct.toml b/lib/pcg/include/pcg/layer_guid_t.struct.toml index c6d4073f58..7f820cbd6d 100644 --- a/lib/pcg/include/pcg/layer_guid_t.struct.toml +++ b/lib/pcg/include/pcg/layer_guid_t.struct.toml @@ -8,7 +8,7 @@ features = [ ] includes = [ - "utils/graph.h", + "utils/graph/node/node.dtg.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h index c204a5f95c..a51aa0951c 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml /* proj-data { - "generated_from": "c31301efeb92e151b04943786aa7bec1" + "generated_from": "74a9d264b25676dd5dfd62538af8cf82" } */ @@ -11,7 +11,7 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_LAYER_GUID_T_DTG_H #include "fmt/format.h" -#include "utils/graph.h" +#include "utils/graph/node/node.dtg.h" #include #include #include diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml index 63fb25a45b..85436460aa 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "utils/graph.h" + "utils/graph/node/node.dtg.h" ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h index 55a1ebcc75..af1a40f60e 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml /* proj-data { - "generated_from": "de2c2d33bfa5cd72f0e51954d6879f38" + "generated_from": "ff4f90460638385dc94c7f0e87a0bf7f" } */ @@ -11,7 +11,7 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_DTG_H #include "fmt/format.h" -#include "utils/graph/multidiedge.h" +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" #include #include #include @@ -20,7 +20,7 @@ namespace FlexFlow { struct parallel_tensor_guid_t { parallel_tensor_guid_t() = delete; explicit parallel_tensor_guid_t( - ::FlexFlow::MultiDiOutput const &raw_graph_output); + ::FlexFlow::DataflowOutput const &raw_graph_output); bool operator==(parallel_tensor_guid_t const &) const; bool operator!=(parallel_tensor_guid_t const &) const; @@ -28,7 +28,7 @@ struct parallel_tensor_guid_t { bool operator>(parallel_tensor_guid_t const &) const; bool operator<=(parallel_tensor_guid_t const &) const; bool operator>=(parallel_tensor_guid_t const &) const; - ::FlexFlow::MultiDiOutput raw_graph_output; + ::FlexFlow::DataflowOutput raw_graph_output; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml index 7837d7b39b..a9e8bbc917 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml @@ -8,9 +8,9 @@ features = [ ] includes = [ - "utils/graph/multidiedge.h" + "utils/graph/dataflow_graph/dataflow_output.dtg.h" ] [[fields]] name = "raw_graph_output" -type = "::FlexFlow::MultiDiOutput" +type = "::FlexFlow::DataflowOutput" diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index 8317c9ec6e..a8930d54ec 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -1,139 +1,11 @@ #include "pcg/file_format/v1/graphs.h" #include "pcg/dataflow_graph/dataflow_graph.h" -#include "pcg/file_format/v1/graphs/v1_multidigraph.h" -#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" #include "utils/graph/algorithms.h" #include "utils/integer_conversions.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" namespace FlexFlow { -/* static V1OperatorGraph to_v1(OperatorGraphView const &g, bidict - * const &nodes) { */ -/* std::unordered_set edges; */ -/* for (MultiDiEdge const &e : get_edges(g)) { */ -/* size_t src_node = nodes.at_l(get_src_node(e)); */ -/* size_t dst_node = nodes.at_l(get_dst_node(e)); */ -/* size_t src_idx = size_t_from_int(get_src_idx(e)); */ -/* size_t dst_idx = size_t_from_int(get_dst_idx(e)); */ -/* V1GraphEdge v1_e = {src_node, src_idx, dst_node, dst_idx}; */ -/* edges.insert(v1_e); */ -/* } */ - -/* return V1OperatorGraph{ */ -/* count(nodes.size()), */ -/* edges, */ -/* }; */ -/* } */ - -static V1MultiDiGraph to_v1(MultiDiGraphView const &g, - bidict const &nodes, - bidict const &node_ports) { - std::unordered_set edges; - for (MultiDiEdge const &e : get_edges(g)) { - edges.insert(V1GraphEdge{nodes.at_l(e.src), - node_ports.at_l(e.src_idx), - nodes.at_l(e.dst), - node_ports.at_l(e.dst_idx)}); - } - - return V1MultiDiGraph{ - count(nodes.size()), - count(node_ports.size()), - edges, - }; -} - -/* static V1MultiDiGraph to_v1(MultiDiGraphView const &g) { */ -/* return to_v1(g, */ -/* enumerate(get_nodes(g)).reversed(), */ -/* enumerate(get_present_node_ports(g)).reversed()); */ -/* } */ - -/* template */ -/* static V1JsonableGraph */ -/* to_v1(LabelledOperatorGraphView const &g) { */ - -/* bidict nodes = enumerate(get_nodes(g)); */ - -/* V1OperatorGraph unlabelled = to_v1(g, nodes.reversed()); */ -/* std::unordered_map node_labels = */ -/* map_values(nodes, [&](Node const &n) { return g.at(n); }); */ - -/* bidict outputs_bidict = - * enumerate(get_outputs(g)); */ -/* std::unordered_map outputs = */ -/* map_values(outputs_bidict, [&](OperatorGraphOutput const &o) { */ -/* return V1GraphOutput{nodes.at_r(get_node(o)), - * size_t_from_int(get_idx(o))}; */ -/* }); */ - -/* std::unordered_map output_labels = map_values( */ -/* outputs_bidict, [&](OperatorGraphOutput const &o) { return g.at(o); }); - */ - -/* return {node_labels, outputs, output_labels, unlabelled}; */ -/* } */ - -template -static bidict - get_ports_by_idx(DataflowGraph const &g) { - bidict result; - for (NodePort const &p : get_present_node_ports(g.get_raw_graph())) { - size_t idx = size_t_from_int(g.idx_for_port(p)); - result.equate(idx, p); - } - return result; -} - -template -static V1JsonableGraph - to_v1(DataflowGraph const &g) { - - bidict nodes = enumerate(get_nodes(g.get_raw_graph())); - bidict node_ports = get_ports_by_idx(g); - - V1MultiDiGraph unlabelled = - to_v1(g.get_raw_graph(), nodes.reversed(), node_ports.reversed()); - std::unordered_map node_labels = - map_values(nodes, [&](Node const &n) { return g.at(n); }); - - bidict outputs_bidict = - enumerate(get_outputs(g.get_raw_graph())); - std::unordered_map outputs = - map_values(outputs_bidict, [&](MultiDiOutput const &o) { - return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; - }); - - std::unordered_map output_labels = map_values( - outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); - - return V1JsonableGraph{ - node_labels, outputs, output_labels, unlabelled}; -} - -template -static V1JsonableGraph - to_v1(OutputLabelledMultiDiGraphView const &g) { - bidict nodes = enumerate(get_nodes(g)); - bidict node_ports = enumerate(get_present_node_ports(g)); - - V1MultiDiGraph unlabelled = to_v1(g, nodes.reversed(), node_ports.reversed()); - std::unordered_map node_labels = - map_values(nodes, [&](Node const &n) { return g.at(n); }); - - bidict outputs_bidict = enumerate(get_outputs(g)); - std::unordered_map outputs = - map_values(outputs_bidict, [&](MultiDiOutput const &o) { - return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; - }); - - std::unordered_map output_labels = map_values( - outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); - - return V1JsonableGraph{ - node_labels, outputs, output_labels, unlabelled}; -} - V1ComputationGraph to_v1(ComputationGraph const &g) { return to_v1(g.raw_graph); } diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 4e0ce7d0a0..d40869a721 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,12 +1,18 @@ #include "pcg/computation_graph.h" #include "utils/containers.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { ComputationGraph make_empty_computation_graph() { - return ComputationGraph{LabelledDataflowGraph{}}; + return ComputationGraph{ + LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph + >() + }; } std::unordered_set get_layers(ComputationGraph const &cg) { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index d3dcf79ca6..6caa574e8e 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -38,7 +38,7 @@ std::vector ComputationGraphBuilder::add_layer( std::vector const &inputs, std::vector const &weights, std::vector const &outputs) { - std::vector raw_weight_tensors; + std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; TensorAttrs weight_tensor_attrs = kv.second; @@ -51,24 +51,24 @@ std::vector ComputationGraphBuilder::add_layer( ComputationGraphOpAttrs{WeightAttrs{}}, weight_name, }; - std::vector weight_layer_inputs = {}; + std::vector weight_layer_inputs = {}; std::vector weight_output_attrs = {weight_tensor_attrs}; raw_weight_tensors.push_back(get_only(this->computation_graph.raw_graph - .add_operator(weight_layer_attrs, + .add_node(weight_layer_attrs, weight_layer_inputs, weight_output_attrs) .outputs)); } - std::vector raw_inputs = transform( + std::vector raw_inputs = transform( inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = + std::vector raw_outputs = this->computation_graph.raw_graph - .add_operator( + .add_node( layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) .outputs; return transform(raw_outputs, - [](MultiDiOutput const &o) { return tensor_guid_t{o}; }); + [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } tensor_guid_t diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc index 211392cbaa..971378a1c1 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc @@ -1,6 +1,7 @@ #include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/integer_conversions.h" namespace FlexFlow { @@ -9,13 +10,13 @@ V1DataflowGraph to_v1(DataflowGraphView const &g) { } V1DataflowGraph to_v1(DataflowGraphView const &g, - bidict const &nodes) { + std::unordered_map const &nodes) { std::unordered_set edges; for (DataflowEdge const &e : get_edges(g)) { - edges.insert(V1GraphEdge{nodes.at_l(e.src.node), - e.src.idx, - nodes.at_l(e.dst.node), - e.dst.idx}); + edges.insert(V1GraphEdge{nodes.at(e.src.node), + size_t_from_int(e.src.idx), + nodes.at(e.dst.node), + size_t_from_int(e.dst.idx)}); } return V1DataflowGraph{ diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc deleted file mode 100644 index 6098235269..0000000000 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc +++ /dev/null @@ -1,10 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml -/* proj-data -{ - "generated_from": "ac98d063410ebe1c14f58ea8e17c272e" -} -*/ - -#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc new file mode 100644 index 0000000000..d353ccdda3 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.cc @@ -0,0 +1 @@ +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc index 89b69e024e..af33062d2f 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml /* proj-data { - "generated_from": "5b6ac94ce5ca0fe62b2309c7a87b583a" + "generated_from": "89120d1975d2e727327594b4ab8a4952" } */ diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc deleted file mode 100644 index d80e433b24..0000000000 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc +++ /dev/null @@ -1,49 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml -/* proj-data -{ - "generated_from": "fed215ca219af1bd375801eb2e33b473" -} -*/ - -#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" - -#include - -namespace FlexFlow { -V1OperatorGraph::V1OperatorGraph( - std::vector const &nodes, - std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) - : nodes(nodes), edges(edges) {} -} // namespace FlexFlow - -namespace nlohmann { -::FlexFlow::V1OperatorGraph - adl_serializer<::FlexFlow::V1OperatorGraph>::from_json(json const &j) { - return ::FlexFlow::V1OperatorGraph{ - j.at("nodes").template get>(), - j.at("edges") - .template get>()}; -} -void adl_serializer<::FlexFlow::V1OperatorGraph>::to_json( - json &j, ::FlexFlow::V1OperatorGraph const &v) { - j["__type"] = "V1OperatorGraph"; - j["nodes"] = v.nodes; - j["edges"] = v.edges; -} -} // namespace nlohmann - -namespace FlexFlow { -std::string format_as(V1OperatorGraph const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, V1OperatorGraph const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/layer_guid_t.dtg.cc b/lib/pcg/src/pcg/layer_guid_t.dtg.cc index 91343f704f..8fd71b43ed 100644 --- a/lib/pcg/src/pcg/layer_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/layer_guid_t.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/layer_guid_t.struct.toml /* proj-data { - "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" + "generated_from": "7876f785878716e3f2af2b4a5c1cab28" } */ diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 174ac07977..4cc152d7b3 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,12 +1,17 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers.h" -#include "pcg/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { ParallelComputationGraph empty_parallel_computation_graph() { return ParallelComputationGraph{ - DataflowGraph{}}; + LabelledDataflowGraph + ::create>() + }; } std::unordered_set @@ -19,14 +24,14 @@ ParallelLayerAddedResult add_parallel_layer(ParallelComputationGraph &pcg, ParallelLayerAttrs const &layer_attrs, std::vector const &inputs, std::vector const &output_labels) { - std::vector unwrapped_inputs = transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); - OperatorAddedResult op_added = pcg.raw_graph.add_operator(layer_attrs, + std::vector unwrapped_inputs = transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); + NodeAddedResult op_added = pcg.raw_graph.add_node(layer_attrs, unwrapped_inputs, output_labels); return ParallelLayerAddedResult{ parallel_layer_guid_t{op_added.node}, transform(op_added.outputs, - [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }), + [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }), }; } @@ -35,7 +40,7 @@ std::vector parallel_layer_guid_t const &l) { return transform( get_inputs(pcg.raw_graph, l.raw_graph_node), - [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); + [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } std::vector @@ -43,12 +48,12 @@ std::vector parallel_layer_guid_t const &l) { return transform( get_outputs(pcg.raw_graph, l.raw_graph_node), - [](MultiDiOutput const &o) { return parallel_tensor_guid_t{o}; }); + [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, parallel_tensor_guid_t const &t) { - return parallel_layer_guid_t{t.raw_graph_output.src}; + return parallel_layer_guid_t{t.raw_graph_output.node}; } ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, @@ -63,7 +68,7 @@ ParallelTensorAttrs } std::vector topological_ordering(ParallelComputationGraph const &pcg) { - return transform(topological_ordering(pcg.raw_graph), + return transform(get_topological_ordering(pcg.raw_graph), [](Node const &n) { return parallel_layer_guid_t{n}; }); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 29723ed078..8c15bfeb0a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -4,6 +4,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers.h" #include "utils/containers/concat_vectors.h" +#include "utils/containers/enumerate_vector.h" namespace FlexFlow { @@ -444,7 +445,7 @@ std::vector ParallelComputationGraphBuilder::add_layer( std::vector const &inputs, std::vector const &weights, std::vector const &outputs) { - std::vector raw_weight_tensors; + std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; ParallelTensorAttrs weight_tensor_attrs = kv.second; @@ -457,26 +458,26 @@ std::vector ParallelComputationGraphBuilder::add_layer( PCGOperatorAttrs{WeightAttrs{}}, weight_name, }; - std::vector weight_layer_inputs = {}; + std::vector weight_layer_inputs = {}; std::vector weight_output_attrs = { weight_tensor_attrs}; raw_weight_tensors.push_back(get_only(this->pcg.raw_graph - .add_operator(weight_layer_attrs, + .add_node(weight_layer_attrs, weight_layer_inputs, weight_output_attrs) .outputs)); } - std::vector raw_inputs = + std::vector raw_inputs = transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = + std::vector raw_outputs = this->pcg.raw_graph - .add_operator( + .add_node( layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) .outputs; - return transform(raw_outputs, [](MultiDiOutput const &o) { + return transform(raw_outputs, [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc index 794a2078e7..5dc2a11e5c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.struct.toml /* proj-data { - "generated_from": "c31301efeb92e151b04943786aa7bec1" + "generated_from": "74a9d264b25676dd5dfd62538af8cf82" } */ diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc index bc10f450c2..efc18b72ef 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.struct.toml /* proj-data { - "generated_from": "de2c2d33bfa5cd72f0e51954d6879f38" + "generated_from": "ff4f90460638385dc94c7f0e87a0bf7f" } */ @@ -13,7 +13,7 @@ namespace FlexFlow { parallel_tensor_guid_t::parallel_tensor_guid_t( - ::FlexFlow::MultiDiOutput const &raw_graph_output) + ::FlexFlow::DataflowOutput const &raw_graph_output) : raw_graph_output(raw_graph_output) {} bool parallel_tensor_guid_t::operator==( parallel_tensor_guid_t const &other) const { @@ -45,7 +45,7 @@ namespace std { size_t hash::operator()( ::FlexFlow::parallel_tensor_guid_t const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + + result ^= std::hash<::FlexFlow::DataflowOutput>{}(x.raw_graph_output) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } diff --git a/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc b/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc deleted file mode 100644 index 7032133cdb..0000000000 --- a/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/dataflow_graph/algorithms.h" -#include "test/utils/doctest.h" -#include "utils/fmt/unordered_set.h" - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_inputs/get_outputs") { - DataflowGraph g; - - int n1_label = 1; - int n2_label = 2; - int n3_label = 3; - int n4_label = 4; - - std::string o1_label = "o1"; - std::string o2_label = "o2"; - std::string o3_label = "o3"; - std::string o4_label = "o4"; - - OperatorAddedResult n1_added = g.add_operator(n1_label, {}, {o1_label}); - Node n1 = n1_added.node; - MultiDiOutput o1 = get_only(n1_added.outputs); - - OperatorAddedResult n2_added = g.add_operator(n2_label, {}, {o2_label}); - Node n2 = n2_added.node; - MultiDiOutput o2 = get_only(n2_added.outputs); - - OperatorAddedResult n3_added = g.add_operator(n3_label, {}, {o3_label}); - Node n3 = n3_added.node; - MultiDiOutput o3 = get_only(n3_added.outputs); - - OperatorAddedResult n4_added = - g.add_operator(n4_label, {o1, o2, o3}, {o4_label}); - Node n4 = n4_added.node; - MultiDiOutput o4 = get_only(n4_added.outputs); - - SUBCASE("get_inputs") { - std::vector result = get_inputs(g, n4); - std::vector correct = {o1, o2, o3}; - CHECK(result == correct); - } - - SUBCASE("get_outputs") { - std::vector result = get_outputs(g, n4); - std::vector correct = {o4}; - CHECK(result == correct); - } - } - - TEST_CASE("topological_ordering") { - DataflowGraph g; - - int n1_label = 1; - int n2_label = 2; - int n3_label = 3; - - std::string o1_label = "o1"; - std::string o2_label = "o2"; - std::string o3_label = "o3"; - - OperatorAddedResult n1_added = g.add_operator(n1_label, {}, {o1_label}); - Node n1 = n1_added.node; - MultiDiOutput o1 = get_only(n1_added.outputs); - - OperatorAddedResult n2_added = g.add_operator(n2_label, {o1}, {o2_label}); - Node n2 = n2_added.node; - MultiDiOutput o2 = get_only(n2_added.outputs); - - OperatorAddedResult n3_added = g.add_operator(n3_label, {o2}, {o3_label}); - Node n3 = n3_added.node; - MultiDiOutput o3 = get_only(n3_added.outputs); - - std::vector result = topological_ordering(g); - std::vector correct = { n1, n2, n3 }; - CHECK(result == correct); - } -} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 25d7e3afe7..2d27d1267d 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,6 +1,8 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "test/utils/rapidcheck.h" +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("topological_ordering") { // TODO(@lockshaw) should probably be replaced with a rapidcheck test that compares @@ -13,19 +15,19 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorAttrs tensor_label = some(); ParallelLayerAddedResult layer1_added = add_parallel_layer(pcg, layer_label, {}, {tensor_label}); - parallel_layer_guid_t layer1 = layer1_added.parallel_layer; - parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); - - ParallelLayerAddedResult layer2_added = add_parallel_layer(pcg, layer_label, {tensor1}, {tensor_label}); - parallel_layer_guid_t layer2 = layer2_added.parallel_layer; - parallel_tensor_guid_t tensor2 = get_only(layer2_added.outputs); - - ParallelLayerAddedResult layer3_added = add_parallel_layer(pcg, layer_label, {tensor2}, {tensor_label}); - parallel_layer_guid_t layer3 = layer3_added.parallel_layer; - parallel_tensor_guid_t tensor3 = get_only(layer3_added.outputs); - - std::vector result = topological_ordering(pcg); - std::vector correct = { layer1, layer2, layer3 }; - CHECK(result == correct); + // parallel_layer_guid_t layer1 = layer1_added.parallel_layer; + // parallel_tensor_guid_t tensor1 = get_only(layer1_added.outputs); + // + // ParallelLayerAddedResult layer2_added = add_parallel_layer(pcg, layer_label, {tensor1}, {tensor_label}); + // parallel_layer_guid_t layer2 = layer2_added.parallel_layer; + // parallel_tensor_guid_t tensor2 = get_only(layer2_added.outputs); + // + // ParallelLayerAddedResult layer3_added = add_parallel_layer(pcg, layer_label, {tensor2}, {tensor_label}); + // parallel_layer_guid_t layer3 = layer3_added.parallel_layer; + // parallel_tensor_guid_t tensor3 = get_only(layer3_added.outputs); + // + // std::vector result = topological_ordering(pcg); + // std::vector correct = { layer1, layer2, layer3 }; + // CHECK(result == correct); } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 50ad727c12..6c0b19d70e 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -5,6 +5,7 @@ #include "test/utils/doctest.h" #include "utils/containers.h" #include "utils/containers/without_nullopts.h" +#include "utils/hash/pair.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc index 34be83c281..08c7d5d879 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -2,6 +2,8 @@ #include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphBuilder") { ComputationGraphBuilder b; diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h index 4528847771..e63c03207b 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H -#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" #include "substitutions/operator_pattern/operator_attribute_value.dtg.h" #include diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h index 4a491af2f6..3f0cf87df4 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml /* proj-data { - "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" + "generated_from": "256aae1d067ff00dda6cf9a94032d17a" } */ @@ -13,7 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" -#include "utils/fmt.h" +#include "utils/fmt/unordered_set.h" +#include "utils/hash/unordered_set.h" #include #include #include diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml index 6facf7d3bc..8b7797af99 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml @@ -11,8 +11,9 @@ features = [ includes = [ "", - "utils/fmt.h", + "utils/fmt/unordered_set.h", "substitutions/operator_pattern/operator_attribute_constraint.dtg.h", + "utils/hash/unordered_set.h", ] [[fields]] diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h index 080909d147..9a3e0ebbbe 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml /* proj-data { - "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" + "generated_from": "c5c01fab8309c4abd9915570d2005390" } */ @@ -20,6 +20,7 @@ #include "op-attrs/pool_op.dtg.h" #include "op-attrs/regularizer_attrs.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/hash/vector.h" #include #include #include diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 9ab88e63c2..29ae87afb7 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -21,6 +21,7 @@ includes = [ "op-attrs/tensor_shape.dtg.h", "op-attrs/datatype.dtg.h", "", + "utils/hash/vector.h", ] [[values]] diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h index 1e78d76777..8a228b66ce 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml /* proj-data { - "generated_from": "9084c9afb2724504a6f4db4288a83a0d" + "generated_from": "9ce2d1b90d941d5362bdd9d671ff4349" } */ @@ -11,17 +11,18 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H #include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" -#include "utils/graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" namespace FlexFlow { struct OutputGraphExpr { OutputGraphExpr() = delete; - explicit OutputGraphExpr( - ::FlexFlow::NodeLabelledOpenMultiDiGraph< - ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph); + explicit OutputGraphExpr(::FlexFlow::LabelledOpenDataflowGraph< + ::FlexFlow::OutputOperatorAttrsAssignment, + std::nullopt_t> const &raw_graph); - ::FlexFlow::NodeLabelledOpenMultiDiGraph< - ::FlexFlow::OutputOperatorAttrsAssignment> + ::FlexFlow::LabelledOpenDataflowGraph< + ::FlexFlow::OutputOperatorAttrsAssignment, + std::nullopt_t> raw_graph; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml index 37d87f7820..5caeff92f5 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml @@ -3,10 +3,10 @@ name = "OutputGraphExpr" features = [] includes = [ - "utils/graph.h", + "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::NodeLabelledOpenMultiDiGraph<::FlexFlow::OutputOperatorAttrsAssignment>" +type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::nullopt_t>" diff --git a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h index 98aec04e61..bc780a276e 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/pcg_pattern.struct.toml /* proj-data { - "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" + "generated_from": "95b0a94000f16024bd541c492bf8a9b1" } */ @@ -12,18 +12,17 @@ #include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" -#include "utils/graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" namespace FlexFlow { struct PCGPattern { PCGPattern() = delete; - explicit PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + explicit PCGPattern(::FlexFlow::LabelledOpenDataflowGraph< ::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern> const &raw_graph); - ::FlexFlow::OutputLabelledOpenMultiDiGraph< - ::FlexFlow::OperatorAttributePattern, - ::FlexFlow::TensorAttributePattern> + ::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> raw_graph; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h similarity index 67% rename from lib/substitutions/include/substitutions/graph_pattern.h rename to lib/substitutions/include/substitutions/pcg_pattern.h index 5f03a6e92e..0d99818860 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -1,9 +1,9 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H -#define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_H #include "substitutions/pcg_pattern.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" #include "substitutions/unlabelled/pattern_matching.h" #include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" @@ -13,13 +13,13 @@ namespace FlexFlow { UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &); TensorAttributePattern get_tensor_pattern(PCGPattern const &, - PatternEdge const &); + PatternValue const &); OperatorAttributePattern get_operator_pattern(PCGPattern const &, PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, PCGPattern const &, - MultiDiGraphPatternMatch const &); + UnlabelledDataflowGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml index 191d66a38c..31e8820b09 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml +++ b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml @@ -2,11 +2,11 @@ namespace = "FlexFlow" name = "PCGPattern" features = [] includes = [ - "utils/graph.h", + "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", "substitutions/operator_pattern/operator_attribute_pattern.dtg.h", "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern>" +type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h index f0d6882dc9..8157f3bc70 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h @@ -3,27 +3,27 @@ // lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "0022d1b2c1447667695a120c154a0168" + "generated_from": "c8f31135c257713d2a44680af5eb7feb" } */ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H -#include "pcg/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" -#include "utils/graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" namespace FlexFlow { struct SubParallelComputationGraph { SubParallelComputationGraph() = delete; explicit SubParallelComputationGraph( - ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::LabelledOpenDataflowGraph< ::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs> const &raw_graph); - ::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> + ::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> raw_graph; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 5d40f3f975..9f45887206 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -11,7 +11,7 @@ PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, Node const &); ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &, - OpenMultiDiEdge const &); + OpenDataflowValue const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml index 1ba04b544c..bcd5e42fc0 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml @@ -3,11 +3,11 @@ name = "SubParallelComputationGraph" features = [ ] includes = [ - "pcg/parallel_layer_attrs.dtg.h", - "pcg/parallel_tensor_attrs.dtg.h", - "utils/graph.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" +type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/substitution.dtg.h b/lib/substitutions/include/substitutions/substitution.dtg.h index 3515299acb..9ead9f5a12 100644 --- a/lib/substitutions/include/substitutions/substitution.dtg.h +++ b/lib/substitutions/include/substitutions/substitution.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/substitution.struct.toml /* proj-data { - "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12" + "generated_from": "9e0ea4f3e23858068cc975534e6c4cf7" } */ @@ -18,19 +18,19 @@ struct Substitution { Substitution() = delete; explicit Substitution(::FlexFlow::PCGPattern const &pcg_pattern, ::FlexFlow::OutputGraphExpr const &output_graph_expr, - ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, - ::FlexFlow::InputMultiDiEdge> const + ::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, + ::FlexFlow::OpenDataflowValue> const &input_edge_match_to_output, - ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, - ::FlexFlow::OutputMultiDiEdge> const + ::FlexFlow::bidict<::FlexFlow::DataflowOutput, + ::FlexFlow::DataflowOutput> const &output_edge_match_to_output); ::FlexFlow::PCGPattern pcg_pattern; ::FlexFlow::OutputGraphExpr output_graph_expr; - ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge> + ::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, + ::FlexFlow::OpenDataflowValue> input_edge_match_to_output; - ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, - ::FlexFlow::OutputMultiDiEdge> + ::FlexFlow::bidict<::FlexFlow::DataflowOutput, ::FlexFlow::DataflowOutput> output_edge_match_to_output; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml index eb630e9308..f370ef80fd 100644 --- a/lib/substitutions/include/substitutions/substitution.struct.toml +++ b/lib/substitutions/include/substitutions/substitution.struct.toml @@ -17,8 +17,8 @@ type = "::FlexFlow::OutputGraphExpr" [[fields]] name = "input_edge_match_to_output" -type = "::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>" +type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::OpenDataflowValue>" [[fields]] name = "output_edge_match_to_output" -type = "::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::OutputMultiDiEdge>" +type = "::FlexFlow::bidict<::FlexFlow::DataflowOutput, ::FlexFlow::DataflowOutput>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h index b8b46669c6..e44a5ab0c7 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h index a69a5b5f6b..88ad6d0ef1 100644 --- a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml /* proj-data { - "generated_from": "f172b041a99f4de1d396e5d451a5e64d" + "generated_from": "5002edd65d1a15b8a2aae04c671d1a73" } */ @@ -17,19 +17,8 @@ namespace FlexFlow { struct UnlabelledPatternEdgeSplits { - UnlabelledPatternEdgeSplits() = delete; - explicit UnlabelledPatternEdgeSplits( - ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, - std::pair<::FlexFlow::OutputMultiDiEdge, - ::FlexFlow::InputMultiDiEdge>> const - &unwrapped); - bool operator==(UnlabelledPatternEdgeSplits const &) const; bool operator!=(UnlabelledPatternEdgeSplits const &) const; - ::FlexFlow::bidict< - ::FlexFlow::MultiDiEdge, - std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>> - unwrapped; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml index fa714296c8..07b6753958 100644 --- a/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml @@ -10,6 +10,7 @@ includes = [ "", ] -[[fields]] -name = "unwrapped" -type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>>" +fields = [] +# [[fields]] +# name = "unwrapped" +# type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>>" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h index f6c1df278a..cd57e8da3d 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h @@ -3,16 +3,17 @@ // lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml /* proj-data { - "generated_from": "2dff356c85dccda1fce8f714d41c6202" + "generated_from": "86e465bc7dbcbece46db9919f4a61a22" } */ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H -#include "substitutions/unlabelled/pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_node.dtg.h" -#include "utils/graph.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" #include namespace FlexFlow { @@ -21,15 +22,15 @@ struct MatchAdditionalCriterion { explicit MatchAdditionalCriterion( std::function const &node_criterion, - std::function const - &edge_criterion); + std::function const + &value_criterion); std::function node_criterion; - std::function - edge_criterion; + std::function + value_criterion; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml index c0107d84e9..9eb62933f1 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml @@ -4,9 +4,10 @@ features = [] includes = [ "", - "utils/graph.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", "substitutions/unlabelled/pattern_node.dtg.h", - "substitutions/unlabelled/pattern_edge.dtg.h", + "substitutions/unlabelled/pattern_value.dtg.h", ] [[fields]] @@ -14,5 +15,5 @@ name = "node_criterion" type = "std::function" [[fields]] -name = "edge_criterion" -type = "std::function" +name = "value_criterion" +type = "std::function" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h index e0c8f00969..ff1eccffee 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h @@ -3,27 +3,43 @@ // lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml /* proj-data { - "generated_from": "e44c4347e07263a493cbbd5caccedd22" + "generated_from": "bf7b8b9b9a9bad1d6c4d632d58ca82ab" } */ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "fmt/format.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" +#include +#include #include namespace FlexFlow { struct MatchSplit { MatchSplit() = delete; - explicit MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, - MultiDiGraphPatternMatch const &postfix_submatch); + explicit MatchSplit( + UnlabelledDataflowGraphPatternMatch const &prefix_submatch, + UnlabelledDataflowGraphPatternMatch const &postfix_submatch); bool operator==(MatchSplit const &) const; bool operator!=(MatchSplit const &) const; - MultiDiGraphPatternMatch prefix_submatch; - MultiDiGraphPatternMatch postfix_submatch; + UnlabelledDataflowGraphPatternMatch prefix_submatch; + UnlabelledDataflowGraphPatternMatch postfix_submatch; }; } // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::MatchSplit> { + size_t operator()(::FlexFlow::MatchSplit const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(MatchSplit const &); +std::ostream &operator<<(std::ostream &, MatchSplit const &); +} // namespace FlexFlow + #endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.h b/lib/substitutions/include/substitutions/unlabelled/match_split.h index a23bc3f89a..957ce6eaa0 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_split.h +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H #include "substitutions/unlabelled/match_split.dtg.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/pattern_split.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" @@ -10,7 +10,7 @@ namespace FlexFlow { MatchSplit empty_match_split(); MatchSplit apply_split(UnlabelledGraphPattern const &pattern, - MultiDiGraphPatternMatch const &match, + UnlabelledDataflowGraphPatternMatch const &match, PatternSplit const &split); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml index 3fd77e7b4a..05c7451351 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml @@ -3,16 +3,18 @@ name = "MatchSplit" features = [ "eq", # "ord", + "hash", + "fmt", ] includes = [ - "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" ] [[fields]] name = "prefix_submatch" -type = "MultiDiGraphPatternMatch" +type = "UnlabelledDataflowGraphPatternMatch" [[fields]] name = "postfix_submatch" -type = "MultiDiGraphPatternMatch" +type = "UnlabelledDataflowGraphPatternMatch" diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h deleted file mode 100644 index 32a5228a9b..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h +++ /dev/null @@ -1,36 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml -/* proj-data -{ - "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" -} -*/ - -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H - -#include "substitutions/unlabelled/pattern_edge.dtg.h" -#include "substitutions/unlabelled/pattern_node.dtg.h" -#include "utils/bidict.h" -#include "utils/graph.h" -#include - -namespace FlexFlow { -struct MultiDiGraphPatternMatch { - MultiDiGraphPatternMatch() = delete; - explicit MultiDiGraphPatternMatch( - ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const - &node_assignment, - ::FlexFlow::bidict<::FlexFlow::PatternEdge, - ::FlexFlow::OpenMultiDiEdge> const &edge_assignment); - - bool operator==(MultiDiGraphPatternMatch const &) const; - bool operator!=(MultiDiGraphPatternMatch const &) const; - ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> node_assignment; - ::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge> - edge_assignment; -}; -} // namespace FlexFlow - -#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml deleted file mode 100644 index 778767ab62..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml +++ /dev/null @@ -1,24 +0,0 @@ -namespace = "FlexFlow" -# TODO(@lockshaw): rename to UnlabelledGraphPatternMatch -name = "MultiDiGraphPatternMatch" -features = [ - "eq", - # "ord", - # "hash", - # "fmt", -] - -includes = [ - "utils/bidict.h", - "utils/graph.h", - "substitutions/unlabelled/pattern_edge.dtg.h", - "substitutions/unlabelled/pattern_node.dtg.h", -] - -[[fields]] -name = "node_assignment" -type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" - -[[fields]] -name = "edge_assignment" -type = "::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h deleted file mode 100644 index 8303cd8c9c..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h +++ /dev/null @@ -1,39 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml -/* proj-data -{ - "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" -} -*/ - -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H - -#include "utils/graph.h" -#include -#include - -namespace FlexFlow { -struct PatternEdge { - PatternEdge() = delete; - explicit PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge); - - bool operator==(PatternEdge const &) const; - bool operator!=(PatternEdge const &) const; - bool operator<(PatternEdge const &) const; - bool operator>(PatternEdge const &) const; - bool operator<=(PatternEdge const &) const; - bool operator>=(PatternEdge const &) const; - ::FlexFlow::OpenMultiDiEdge raw_edge; -}; -} // namespace FlexFlow - -namespace std { -template <> -struct hash<::FlexFlow::PatternEdge> { - size_t operator()(::FlexFlow::PatternEdge const &) const; -}; -} // namespace std - -#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml deleted file mode 100644 index 4abfa1c0db..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "PatternEdge" -features = [ - "eq", - "ord", - "hash", -] - -includes = [ - "utils/graph.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::OpenMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index 223886b411..baea34afeb 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -3,21 +3,22 @@ #include "substitutions/unlabelled/match_additional_criterion.dtg.h" #include "substitutions/unlabelled/match_split.dtg.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" #include "utils/graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" namespace FlexFlow { bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, + OpenDataflowGraphView const &graph, + UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion); -std::vector +std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, + OpenDataflowGraphView const &graph, MatchAdditionalCriterion const &additional_criterion); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.h new file mode 100644 index 0000000000..8a4c7b24a7 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml +/* proj-data +{ + "generated_from": "60f3d7ccfb3b61349ff2cf61a0bfd1c0" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct PatternValue { + PatternValue() = delete; + explicit PatternValue( + ::FlexFlow::OpenDataflowValue const &raw_dataflow_value); + + bool operator==(PatternValue const &) const; + bool operator!=(PatternValue const &) const; + bool operator<(PatternValue const &) const; + bool operator>(PatternValue const &) const; + bool operator<=(PatternValue const &) const; + bool operator>=(PatternValue const &) const; + ::FlexFlow::OpenDataflowValue raw_dataflow_value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::PatternValue> { + size_t operator()(::FlexFlow::PatternValue const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(PatternValue const &); +std::ostream &operator<<(std::ostream &, PatternValue const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml new file mode 100644 index 0000000000..c9b52b4c9e --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "PatternValue" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", +] + +[[fields]] +name = "raw_dataflow_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml new file mode 100644 index 0000000000..35630eac70 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "PatternValueUse" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_input" +type = "::FlexFlow::DataflowInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h new file mode 100644 index 0000000000..2209d9da09 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h @@ -0,0 +1,55 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "8e2550c2e4cd04bb1458f9e3f4ac05ba" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_DTG_H + +#include "fmt/format.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "utils/bidict.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct UnlabelledDataflowGraphPatternMatch { + UnlabelledDataflowGraphPatternMatch() = delete; + explicit UnlabelledDataflowGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternValue, + ::FlexFlow::OpenDataflowValue> const + &value_assignment); + + bool operator==(UnlabelledDataflowGraphPatternMatch const &) const; + bool operator!=(UnlabelledDataflowGraphPatternMatch const &) const; + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> node_assignment; + ::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::OpenDataflowValue> + value_assignment; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::UnlabelledDataflowGraphPatternMatch> { + size_t + operator()(::FlexFlow::UnlabelledDataflowGraphPatternMatch const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(UnlabelledDataflowGraphPatternMatch const &); +std::ostream &operator<<(std::ostream &, + UnlabelledDataflowGraphPatternMatch const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml new file mode 100644 index 0000000000..f6c8c475cb --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "UnlabelledDataflowGraphPatternMatch" +features = [ + "eq", + # "ord", + "hash", + "fmt", +] + +includes = [ + "utils/bidict.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "substitutions/unlabelled/pattern_value.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" + +[[fields]] +name = "value_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::OpenDataflowValue>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h index 972dda4200..af6fa5faa5 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h @@ -3,22 +3,22 @@ // lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml /* proj-data { - "generated_from": "f494ed79eb1ba4010155e456b452157f" + "generated_from": "7d8730b1ab76f6356bb09084d1c55f06" } */ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H -#include "utils/graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" namespace FlexFlow { struct UnlabelledGraphPattern { UnlabelledGraphPattern() = delete; explicit UnlabelledGraphPattern( - ::FlexFlow::OpenMultiDiGraphView const &raw_graph); + ::FlexFlow::OpenDataflowGraphView const &raw_graph); - ::FlexFlow::OpenMultiDiGraphView raw_graph; + ::FlexFlow::OpenDataflowGraphView raw_graph; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h index 9bb63037be..3de76f6ab2 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -1,25 +1,24 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H -#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" -#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" #include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" -#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" namespace FlexFlow { size_t num_nodes(UnlabelledGraphPattern const &); bool is_singleton_pattern(UnlabelledGraphPattern const &); std::unordered_set get_nodes(UnlabelledGraphPattern const &); -std::unordered_set get_edges(UnlabelledGraphPattern const &); +std::unordered_set get_values(UnlabelledGraphPattern const &); +std::unordered_set get_value_uses(UnlabelledGraphPattern const &, PatternValue const &); std::vector get_topological_ordering(UnlabelledGraphPattern const &); -std::unordered_set - get_incoming_edges(UnlabelledGraphPattern const &, PatternNode const &); -std::unordered_set - get_outgoing_edges(UnlabelledGraphPattern const &, PatternNode const &); +std::vector + get_inputs_to_pattern_node(UnlabelledGraphPattern const &, PatternNode const &); +std::vector + get_outputs_from_pattern_node(UnlabelledGraphPattern const &, PatternNode const &); UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &, std::unordered_set const &); diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml index 03f4bd5523..74371f21ef 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml @@ -2,9 +2,9 @@ namespace = "FlexFlow" name = "UnlabelledGraphPattern" features = [] includes = [ - "utils/graph.h" + "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" ] [[fields]] name = "raw_graph" -type = "::FlexFlow::OpenMultiDiGraphView" +type = "::FlexFlow::OpenDataflowGraphView" diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index fb3199979d..eeda1e325f 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -370,6 +370,16 @@ std::optional get_attribute(TransposeAttrs const &p, } } +std::optional get_attribute(WeightAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + std::optional get_attribute(PCGOperatorAttrs const &p, OperatorAttributeKey key) { return p.visit>( diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc index 8caa7bd720..853b7f15aa 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml /* proj-data { - "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" + "generated_from": "256aae1d067ff00dda6cf9a94032d17a" } */ diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc index 376a9c2ce8..73e443064c 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml /* proj-data { - "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" + "generated_from": "c5c01fab8309c4abd9915570d2005390" } */ diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc index 3e945beded..b613dc0ab6 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc @@ -3,15 +3,15 @@ // lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml /* proj-data { - "generated_from": "9084c9afb2724504a6f4db4288a83a0d" + "generated_from": "9ce2d1b90d941d5362bdd9d671ff4349" } */ #include "substitutions/output_graph/output_graph_expr.dtg.h" namespace FlexFlow { -OutputGraphExpr::OutputGraphExpr( - ::FlexFlow::NodeLabelledOpenMultiDiGraph< - ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph) +OutputGraphExpr::OutputGraphExpr(::FlexFlow::LabelledOpenDataflowGraph< + ::FlexFlow::OutputOperatorAttrsAssignment, + std::nullopt_t> const &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/graph_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc similarity index 97% rename from lib/substitutions/src/substitutions/graph_pattern.cc rename to lib/substitutions/src/substitutions/pcg_pattern.cc index 22cf12b4cf..bf59ab5080 100644 --- a/lib/substitutions/src/substitutions/graph_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -1,4 +1,4 @@ -#include "substitutions/graph_pattern.h" +#include "substitutions/pcg_pattern.h" #include "substitutions/operator_pattern/satisfies_pattern.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/satisfies_pattern.h" diff --git a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc index 9056a5ebdd..d55a386a7a 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc @@ -3,14 +3,14 @@ // lib/substitutions/include/substitutions/pcg_pattern.struct.toml /* proj-data { - "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" + "generated_from": "95b0a94000f16024bd541c492bf8a9b1" } */ #include "substitutions/pcg_pattern.dtg.h" namespace FlexFlow { -PCGPattern::PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< +PCGPattern::PCGPattern(::FlexFlow::LabelledOpenDataflowGraph< ::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern> const &raw_graph) : raw_graph(raw_graph) {} diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 7736113819..5cb4825d29 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -15,8 +15,8 @@ PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, - OpenMultiDiEdge const &e) { - return spcg.raw_graph.at(e); + OpenDataflowValue const &v) { + return spcg.raw_graph.at(v); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc index eabee4a906..252078c345 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "0022d1b2c1447667695a120c154a0168" + "generated_from": "c8f31135c257713d2a44680af5eb7feb" } */ @@ -11,8 +11,8 @@ namespace FlexFlow { SubParallelComputationGraph::SubParallelComputationGraph( - ::FlexFlow::OutputLabelledOpenMultiDiGraph< - ::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + ::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const + &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.dtg.cc b/lib/substitutions/src/substitutions/substitution.dtg.cc index 81c8a572df..e4383cb0db 100644 --- a/lib/substitutions/src/substitutions/substitution.dtg.cc +++ b/lib/substitutions/src/substitutions/substitution.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/substitution.struct.toml /* proj-data { - "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12" + "generated_from": "9e0ea4f3e23858068cc975534e6c4cf7" } */ @@ -13,11 +13,11 @@ namespace FlexFlow { Substitution::Substitution( ::FlexFlow::PCGPattern const &pcg_pattern, ::FlexFlow::OutputGraphExpr const &output_graph_expr, - ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, - ::FlexFlow::InputMultiDiEdge> const + ::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, + ::FlexFlow::OpenDataflowValue> const &input_edge_match_to_output, - ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, - ::FlexFlow::OutputMultiDiEdge> const + ::FlexFlow::bidict<::FlexFlow::DataflowOutput, + ::FlexFlow::DataflowOutput> const &output_edge_match_to_output) : pcg_pattern(pcg_pattern), output_graph_expr(output_graph_expr), input_edge_match_to_output(input_edge_match_to_output), diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc index 30e7b78725..ddc2723238 100644 --- a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc @@ -3,25 +3,19 @@ // lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml /* proj-data { - "generated_from": "f172b041a99f4de1d396e5d451a5e64d" + "generated_from": "5002edd65d1a15b8a2aae04c671d1a73" } */ #include "substitutions/unlabelled/edge_splits.dtg.h" namespace FlexFlow { -UnlabelledPatternEdgeSplits::UnlabelledPatternEdgeSplits( - ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, - std::pair<::FlexFlow::OutputMultiDiEdge, - ::FlexFlow::InputMultiDiEdge>> const - &unwrapped) - : unwrapped(unwrapped) {} bool UnlabelledPatternEdgeSplits::operator==( UnlabelledPatternEdgeSplits const &other) const { - return std::tie(this->unwrapped) == std::tie(other.unwrapped); + return std::tie() == std::tie(); } bool UnlabelledPatternEdgeSplits::operator!=( UnlabelledPatternEdgeSplits const &other) const { - return std::tie(this->unwrapped) != std::tie(other.unwrapped); + return std::tie() != std::tie(); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc index 650bc0ec68..d23eddfe9b 100644 --- a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml /* proj-data { - "generated_from": "2dff356c85dccda1fce8f714d41c6202" + "generated_from": "86e465bc7dbcbece46db9919f4a61a22" } */ @@ -13,8 +13,8 @@ namespace FlexFlow { MatchAdditionalCriterion::MatchAdditionalCriterion( std::function const &node_criterion, - std::function const - &edge_criterion) - : node_criterion(node_criterion), edge_criterion(edge_criterion) {} + std::function const + &value_criterion) + : node_criterion(node_criterion), value_criterion(value_criterion) {} } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc index a45186aa3f..e960b44dd7 100644 --- a/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc @@ -3,15 +3,18 @@ // lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml /* proj-data { - "generated_from": "e44c4347e07263a493cbbd5caccedd22" + "generated_from": "bf7b8b9b9a9bad1d6c4d632d58ca82ab" } */ #include "substitutions/unlabelled/match_split.dtg.h" +#include + namespace FlexFlow { -MatchSplit::MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, - MultiDiGraphPatternMatch const &postfix_submatch) +MatchSplit::MatchSplit( + UnlabelledDataflowGraphPatternMatch const &prefix_submatch, + UnlabelledDataflowGraphPatternMatch const &postfix_submatch) : prefix_submatch(prefix_submatch), postfix_submatch(postfix_submatch) {} bool MatchSplit::operator==(MatchSplit const &other) const { return std::tie(this->prefix_submatch, this->postfix_submatch) == @@ -22,3 +25,31 @@ bool MatchSplit::operator!=(MatchSplit const &other) const { std::tie(other.prefix_submatch, other.postfix_submatch); } } // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::MatchSplit const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.prefix_submatch) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.postfix_submatch) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(MatchSplit const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MatchSplit const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc deleted file mode 100644 index 822092fff8..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc +++ /dev/null @@ -1,29 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml -/* proj-data -{ - "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" -} -*/ - -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" - -namespace FlexFlow { -MultiDiGraphPatternMatch::MultiDiGraphPatternMatch( - ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const - &node_assignment, - ::FlexFlow::bidict<::FlexFlow::PatternEdge, - ::FlexFlow::OpenMultiDiEdge> const &edge_assignment) - : node_assignment(node_assignment), edge_assignment(edge_assignment) {} -bool MultiDiGraphPatternMatch::operator==( - MultiDiGraphPatternMatch const &other) const { - return std::tie(this->node_assignment, this->edge_assignment) == - std::tie(other.node_assignment, other.edge_assignment); -} -bool MultiDiGraphPatternMatch::operator!=( - MultiDiGraphPatternMatch const &other) const { - return std::tie(this->node_assignment, this->edge_assignment) != - std::tie(other.node_assignment, other.edge_assignment); -} -} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc deleted file mode 100644 index 51ea760af3..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc +++ /dev/null @@ -1,43 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml -/* proj-data -{ - "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" -} -*/ - -#include "substitutions/unlabelled/pattern_edge.dtg.h" - -namespace FlexFlow { -PatternEdge::PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge) - : raw_edge(raw_edge) {} -bool PatternEdge::operator==(PatternEdge const &other) const { - return std::tie(this->raw_edge) == std::tie(other.raw_edge); -} -bool PatternEdge::operator!=(PatternEdge const &other) const { - return std::tie(this->raw_edge) != std::tie(other.raw_edge); -} -bool PatternEdge::operator<(PatternEdge const &other) const { - return std::tie(this->raw_edge) < std::tie(other.raw_edge); -} -bool PatternEdge::operator>(PatternEdge const &other) const { - return std::tie(this->raw_edge) > std::tie(other.raw_edge); -} -bool PatternEdge::operator<=(PatternEdge const &other) const { - return std::tie(this->raw_edge) <= std::tie(other.raw_edge); -} -bool PatternEdge::operator>=(PatternEdge const &other) const { - return std::tie(this->raw_edge) >= std::tie(other.raw_edge); -} -} // namespace FlexFlow - -namespace std { -size_t hash::operator()( - ::FlexFlow::PatternEdge const &x) const { - size_t result = 0; - result ^= std::hash<::FlexFlow::OpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + - (result << 6) + (result >> 2); - return result; -} -} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 335b9664ea..d0b6956222 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -1,8 +1,5 @@ #include "substitutions/unlabelled/pattern_matching.h" -#include "substitutions/unlabelled/input_pattern_edge.h" #include "substitutions/unlabelled/match_split.h" -#include "substitutions/unlabelled/output_pattern_edge.h" -#include "substitutions/unlabelled/pattern_edge.h" #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include @@ -11,8 +8,8 @@ namespace FlexFlow { bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, + OpenDataflowGraphView const &graph, + UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { if (is_singleton_pattern(pattern)) { PatternNode pattern_node = get_only(get_nodes(pattern)); @@ -20,8 +17,10 @@ bool unlabelled_pattern_does_match( if (!additional_criterion.node_criterion(pattern_node, matched_node)) { return false; } - for (PatternEdge const &e : get_edges(pattern)) { - OpenMultiDiEdge matched_edge = match.edge_assignment.at_l(e); + for (PatternValue const &pattern_value : get_values(pattern)) { + OpenDataflowValue matched_value = match.value_assignment.at_l(v); + + assert(is_input_edge(e) || is_output_edge(e)); if (is_input_edge(e)) { @@ -48,7 +47,7 @@ bool unlabelled_pattern_does_match( } } - if (!additional_criterion.edge_criterion(e, matched_edge)) { + if (!additional_criterion.value_criterion(pattern_value, matched_value)) { return false; } } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_value.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_value.dtg.cc new file mode 100644 index 0000000000..5e9cb069c8 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_value.dtg.cc @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml +/* proj-data +{ + "generated_from": "60f3d7ccfb3b61349ff2cf61a0bfd1c0" +} +*/ + +#include "substitutions/unlabelled/pattern_value.dtg.h" + +#include + +namespace FlexFlow { +PatternValue::PatternValue( + ::FlexFlow::OpenDataflowValue const &raw_dataflow_value) + : raw_dataflow_value(raw_dataflow_value) {} +bool PatternValue::operator==(PatternValue const &other) const { + return std::tie(this->raw_dataflow_value) == + std::tie(other.raw_dataflow_value); +} +bool PatternValue::operator!=(PatternValue const &other) const { + return std::tie(this->raw_dataflow_value) != + std::tie(other.raw_dataflow_value); +} +bool PatternValue::operator<(PatternValue const &other) const { + return std::tie(this->raw_dataflow_value) < + std::tie(other.raw_dataflow_value); +} +bool PatternValue::operator>(PatternValue const &other) const { + return std::tie(this->raw_dataflow_value) > + std::tie(other.raw_dataflow_value); +} +bool PatternValue::operator<=(PatternValue const &other) const { + return std::tie(this->raw_dataflow_value) <= + std::tie(other.raw_dataflow_value); +} +bool PatternValue::operator>=(PatternValue const &other) const { + return std::tie(this->raw_dataflow_value) >= + std::tie(other.raw_dataflow_value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::PatternValue const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OpenDataflowValue>{}(x.raw_dataflow_value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(PatternValue const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, PatternValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc new file mode 100644 index 0000000000..3d5598539b --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "8e2550c2e4cd04bb1458f9e3f4ac05ba" +} +*/ + +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" + +#include + +namespace FlexFlow { +UnlabelledDataflowGraphPatternMatch::UnlabelledDataflowGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternValue, + ::FlexFlow::OpenDataflowValue> const &value_assignment) + : node_assignment(node_assignment), value_assignment(value_assignment) {} +bool UnlabelledDataflowGraphPatternMatch::operator==( + UnlabelledDataflowGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->value_assignment) == + std::tie(other.node_assignment, other.value_assignment); +} +bool UnlabelledDataflowGraphPatternMatch::operator!=( + UnlabelledDataflowGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->value_assignment) != + std::tie(other.node_assignment, other.value_assignment); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::UnlabelledDataflowGraphPatternMatch const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>>{}( + x.node_assignment) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::bidict<::FlexFlow::PatternValue, + ::FlexFlow::OpenDataflowValue>>{}( + x.value_assignment) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(UnlabelledDataflowGraphPatternMatch const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + UnlabelledDataflowGraphPatternMatch const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index df10507a04..5f9d5b7a1e 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,5 +1,9 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/containers.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" namespace FlexFlow { @@ -13,40 +17,39 @@ bool is_singleton_pattern(UnlabelledGraphPattern const &pattern) { std::unordered_set get_nodes(UnlabelledGraphPattern const &p) { return transform(get_nodes(p.raw_graph), - [](Node const &n) { - return PatternNode{n}; }}); + [](Node const &n) { return PatternNode{n}; }); } -std::unordered_set get_edges(UnlabelledGraphPattern const &p) { - return transform(get_nodes(p.raw_graph), - [](OpenMultiDiEdge const &e) { - return PatternEdge{e}; }}); +std::unordered_set get_values(UnlabelledGraphPattern const &p) { + return transform(get_open_dataflow_values(p.raw_graph), + [](OpenDataflowValue const &v) { return PatternValue{v}; }); } std::vector get_topological_ordering(UnlabelledGraphPattern const &p) { - return transform(get_topological_ordering(p), - [](Node const &n) { - return PatternNode{n}; }}); + return transform(get_topological_ordering(p.raw_graph), + [](Node const &n) { return PatternNode{n}; }); } -UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, - std::unordered_set const &n) { - return { - get_subgraph(p.raw_graph, - transform(n, [](PatternNode const &n) { return n.raw_node; })); - }; +std::vector + get_inputs_to_pattern_node(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_inputs(p.raw_graph, n.raw_node), + [](OpenDataflowValue const &v) { return PatternValue{v}; }); } -std::unordered_set - get_incoming_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_incoming_edges(p.raw_graph, n.raw_node), - [](Node const &n) { return PatternNode{n}; }); +std::vector + get_outputs_from_pattern_node(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_outputs(p.raw_graph, n.raw_node), + [](DataflowOutput const &o) { return PatternValue{OpenDataflowValue{o}}; }); } -std::unordered_set - get_outgoing_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_outgoing_edges(p.raw_graph, n.raw_node), - [](Node const &n) { return PatternNode{n}; }); +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, + std::unordered_set const &n) { + NOT_IMPLEMENTED(); + // return UnlabelledGraphPattern{ + // get_subgraph(p.raw_graph, + // transform(n, [](PatternNode const &n) { return n.raw_node; })); + // }; } + } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc index 0bebd8dd91..c9c912474c 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml /* proj-data { - "generated_from": "f494ed79eb1ba4010155e456b452157f" + "generated_from": "7d8730b1ab76f6356bb09084d1c55f06" } */ @@ -11,6 +11,6 @@ namespace FlexFlow { UnlabelledGraphPattern::UnlabelledGraphPattern( - ::FlexFlow::OpenMultiDiGraphView const &raw_graph) + ::FlexFlow::OpenDataflowGraphView const &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 6af18c2a4a..11e4ba8b05 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -166,6 +166,11 @@ struct bidict { operator std::unordered_map const &() const { return this->fwd_map; } + + std::unordered_map const &as_unordered_map() const { + return this->fwd_map; + } + bidict(std::unordered_map const &fwd_map, std::unordered_map const &bwd_map) : fwd_map(fwd_map), bwd_map(bwd_map) {} diff --git a/lib/utils/include/utils/containers/group_by.h b/lib/utils/include/utils/containers/group_by.h new file mode 100644 index 0000000000..6f367f8875 --- /dev/null +++ b/lib/utils/include/utils/containers/group_by.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GROUP_BY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GROUP_BY_H + +#include +#include +#include + +namespace FlexFlow { + +template > +std::unordered_map> group_by(std::unordered_set const &vs, F f) { + std::unordered_map> result; + for (V const &v : vs) { + result[f(v)].insert(v); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/set_minus.h b/lib/utils/include/utils/containers/set_minus.h new file mode 100644 index 0000000000..b5ce6aee4a --- /dev/null +++ b/lib/utils/include/utils/containers/set_minus.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H + +#include + +namespace FlexFlow { + +template +std::unordered_set set_minus(std::unordered_set const &l, std::unordered_set const &r) { + std::unordered_set result = l; + for (T const &t : r) { + result.erase(t); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/without_nullopts.h b/lib/utils/include/utils/containers/without_nullopts.h index f888654b60..ae6d5109cd 100644 --- a/lib/utils/include/utils/containers/without_nullopts.h +++ b/lib/utils/include/utils/containers/without_nullopts.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -17,6 +18,17 @@ std::vector without_nullopts(std::vector> const &v) { return result; } +template +std::unordered_set without_nullopts(std::unordered_set> const &s) { + std::unordered_set result; + for (std::optional const &t : s) { + if (t.has_value()) { + result.insert(t.value()); + } + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph.h b/lib/utils/include/utils/graph.h index 91f0ea6eb5..3bd61b3f91 100644 --- a/lib/utils/include/utils/graph.h +++ b/lib/utils/include/utils/graph.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_H #define _FLEXFLOW_UTILS_GRAPH_H -#include "graph/digraph/adjacency_digraph.h" -#include "graph/algorithms.h" -#include "graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" // #include "graph/labelled_graphs.h" -#include "graph/node/node.dtg.h" +#include "utils/graph/node/node.dtg.h" // #include "graph/open_graphs.h" -#include "graph/serial_parallel/serialparallel.h" -#include "graph/traversal.h" -#include "graph/undirected/undirected_graph.h" +#include "utils/graph/serial_parallel/serialparallel.h" +#include "utils/graph/traversal.h" +#include "utils/graph/undirected/undirected_graph.h" #endif diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 15c10f68e3..ec58f4362c 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -61,7 +61,6 @@ DiGraphView apply_contraction(DiGraphView const &, UndirectedGraphView apply_contraction(UndirectedGraphView const &, std::unordered_map const &); -std::size_t num_nodes(GraphView const &); bool empty(GraphView const &); // void add_edges(MultiDiGraph &, std::vector const &); @@ -97,10 +96,6 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, // std::unordered_set // get_open_inputs(OpenMultiDiGraphView const &); -// std::unordered_set get_incoming_edges(MultiDiGraphView const &, -// Node const &); -std::unordered_set get_incoming_edges(DiGraphView const &, - Node const &); // std::unordered_set // get_incoming_edges(UpwardOpenMultiDiGraphView const &, Node const &); // std::unordered_set @@ -110,13 +105,9 @@ std::unordered_set get_incoming_edges(DiGraphView const &, // std::unordered_set get_incoming_edges(MultiDiGraphView const &, // std::unordered_set); -std::unordered_set - get_incoming_edges(DiGraphView const &, std::unordered_set const &); // std::unordered_set get_outgoing_edges(MultiDiGraphView const &, // Node const &); -std::unordered_set get_outgoing_edges(DiGraphView const &, - Node const &); // std::unordered_set // get_outgoing_edges(UpwardOpenMultiDiGraphView const &, Node const &); // std::unordered_set @@ -127,8 +118,6 @@ std::unordered_set get_outgoing_edges(DiGraphView const &, // std::unordered_set // get_outgoing_edges(MultiDiGraphView const &, // std::unordered_set const &); -std::unordered_set - get_outgoing_edges(DiGraphView const &, std::unordered_set const &); std::unordered_set get_node_edges(UndirectedGraphView const &, Node const &); @@ -136,10 +125,6 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set const &); -std::unordered_set get_predecessors(DiGraphView const &, Node const &); -std::unordered_map> - get_predecessors(DiGraphView const &, std::unordered_set const &); - // Node get_src_node(MultiDiEdge const &); // Node get_dst_node(MultiDiEdge const &); // Node get_dst_node(InputMultiDiEdge const &); @@ -161,10 +146,6 @@ std::unordered_set get_sinks(DiGraphView const &); // std::unordered_set get_open_sources(OpenMultiDiGraphView const &g); // std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g); -// bool is_acyclic(MultiDiGraphView const &, std::unordered_set const &); -std::optional is_acyclic(DiGraphView const &); -// std::optional is_acyclic(MultiDiGraphView const &); - std::unordered_map> get_dominators(DiGraphView const &); std::unordered_set get_dominators(DiGraphView const &, Node const &); diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h index 247460f4df..bb1d591daf 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h @@ -10,6 +10,7 @@ std::unordered_set get_edges(DataflowGraphView const &); std::vector get_incoming_edges(DataflowGraphView const &, Node const &); std::vector get_inputs(DataflowGraphView const &, Node const &); std::vector get_outputs(DataflowGraphView const &, Node const &); +std::unordered_set get_all_dataflow_outputs(DataflowGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h index d6d44ce49a..0385cf8a20 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_EDGE_QUERY_H +#include "utils/graph/dataflow_graph/dataflow_edge.dtg.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.dtg.h" namespace FlexFlow { DataflowEdgeQuery dataflow_edge_query_all(); DataflowEdgeQuery dataflow_edge_query_none(); +bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &, DataflowEdge const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h index f373a06dae..f09d5aed04 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_DATAFLOW_OUTPUT_QUERY_H +#include "utils/graph/dataflow_graph/dataflow_output.dtg.h" #include "utils/graph/dataflow_graph/dataflow_output_query.dtg.h" namespace FlexFlow { DataflowOutputQuery dataflow_output_query_all(); DataflowOutputQuery dataflow_output_query_none(); +bool dataflow_output_query_includes_dataflow_output(DataflowOutputQuery const &, DataflowOutput const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h index ed615fd7f8..21208f06f7 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -5,12 +5,21 @@ namespace FlexFlow { -std::unordered_set get_nodes(DiGraph const &); -std::unordered_set get_edges(DirectedEdge const &); -std::unordered_set get_incoming_edges(DiGraph const &, Node const &); -std::unordered_set get_outgoing_edges(DiGraph const &, Node const &); -std::unordered_set get_sources(DiGraph const &); -std::vector get_topological_ordering(DiGraph const &); +std::unordered_set get_edges(DiGraphView const &); +std::unordered_set get_incoming_edges(DiGraphView const &, Node const &); +std::unordered_map> get_incoming_edges(DiGraphView const &, std::unordered_set const &); +std::unordered_set get_outgoing_edges(DiGraphView const &, Node const &); +std::unordered_map> get_outgoing_edges(DiGraphView const &, std::unordered_set const &); +std::unordered_set get_sources(DiGraphView const &); +std::unordered_set get_sinks(DiGraphView const &); +std::vector get_topological_ordering(DiGraphView const &); +std::optional is_acyclic(DiGraphView const &); + +DiGraphView flipped(DiGraphView const &g); + +std::unordered_set get_predecessors(DiGraphView const &, Node const &); +std::unordered_map> + get_predecessors(DiGraphView const &, std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/adjacency_digraph.h b/lib/utils/include/utils/graph/instances/adjacency_digraph.h similarity index 88% rename from lib/utils/include/utils/graph/digraph/adjacency_digraph.h rename to lib/utils/include/utils/graph/instances/adjacency_digraph.h index 9a2e13a3a5..e08bd350bc 100644 --- a/lib/utils/include/utils/graph/digraph/adjacency_digraph.h +++ b/lib/utils/include/utils/graph/instances/adjacency_digraph.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_GRAPH_ADJACENCY_DIGRAPH_H -#define _FLEXFLOW_UTILS_GRAPH_ADJACENCY_DIGRAPH_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_ADJACENCY_DIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_ADJACENCY_DIGRAPH_H #include "utils/graph/digraph/digraph.h" #include diff --git a/lib/utils/include/utils/graph/dataflow_graph/unordered_set_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h similarity index 100% rename from lib/utils/include/utils/graph/dataflow_graph/unordered_set_dataflow_graph.h rename to lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h new file mode 100644 index 0000000000..579aeda83e --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -0,0 +1,114 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_UNORDERED_SET_LABELLED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_UNORDERED_SET_LABELLED_DATAFLOW_GRAPH_H + +#include "utils/containers/enumerate_vector.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h" +#include "utils/graph/node/node_source.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" +#include "utils/containers/zip_vectors.h" +#include "utils/containers/without_nullopts.h" + +namespace FlexFlow { + +template +struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflowGraph { +public: + UnorderedSetLabelledOpenDataflowGraph() = default; + + NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) override { + Node new_node = this->node_source.new_node(); + this->nodes.insert({new_node, node_label}); + + for (auto const &[input_idx, input] : enumerate_vector(inputs)) { + this->edges.insert(open_dataflow_edge_from_src_and_dst(input, DataflowInput{new_node, input_idx})); + } + + std::vector new_outputs = transform( + count(output_labels.size()), + [&](int output_idx) { return DataflowOutput{new_node, output_idx}; }); + + for (auto const &[output, output_label] : zip(new_outputs, output_labels)) { + this->values.insert({OpenDataflowValue{output}, output_label}); + } + + return NodeAddedResult { + new_node, + new_outputs, + }; + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return filter(keys(this->nodes), + [&](Node const &n) { return includes(q.nodes, n); }); + } + + std::unordered_set query_edges(OpenDataflowEdgeQuery const &q) const override { + return filter(this->edges, + [&](OpenDataflowEdge const &e) { + return open_dataflow_edge_query_includes(q, e); + }); + } + + std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { + return without_nullopts(transform(keys(this->values), + [&](OpenDataflowValue const &v) -> std::optional { + DataflowOutput o = v.get(); + if (dataflow_output_query_includes_dataflow_output(q, o)) { + return o; + } else { + return std::nullopt; + } + })); + } + + std::vector get_inputs() const override { + return this->inputs; + } + + NodeLabel const &at(Node const &n) const override { + return this->nodes.at(n); + } + + ValueLabel const &at(OpenDataflowValue const &v) const override { + return this->values.at(v); + } + + UnorderedSetLabelledOpenDataflowGraph *clone() const override { + return new UnorderedSetLabelledOpenDataflowGraph{ + this->node_source, + this->inputs, + this->nodes, + this->edges, + this->values, + }; + } +private: + UnorderedSetLabelledOpenDataflowGraph(NodeSource const &node_source, + std::vector const &inputs, + std::unordered_map const &nodes, + std::unordered_set const &edges, + std::unordered_map const &values) + : node_source(node_source), + inputs(inputs), + nodes(nodes), + edges(edges), + values(values) + {} + +private: + NodeSource node_source; + std::vector inputs; + std::unordered_map nodes; + std::unordered_set edges; + std::unordered_map values; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(UnorderedSetLabelledOpenDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h index 00eb0250b5..ea9b463790 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h @@ -11,12 +11,24 @@ struct LabelledDataflowGraph : virtual LabelledDataflowGraphView; public: + LabelledDataflowGraph(LabelledDataflowGraph const &) = default; + LabelledDataflowGraph &operator=(LabelledDataflowGraph const &) = default; + NodeAddedResult add_node(NodeLabel const &node_label, std::vector const &inputs, std::vector const &output_labels) { return this->get_interface().add_node(node_label, inputs, output_labels); } + template + static typename std::enable_if::value, + LabelledDataflowGraph>::type + create() { + return LabelledDataflowGraph(make_cow_ptr()); + } +protected: + using LabelledDataflowGraphView::LabelledDataflowGraphView; + private: Interface &get_interface() { return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h index 0b372a0f70..94113a62ae 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h @@ -11,6 +11,9 @@ struct LabelledDataflowGraphView : virtual public DataflowGraphView { private: using Interface = ILabelledDataflowGraphView; public: + LabelledDataflowGraphView(LabelledDataflowGraphView const &) = default; + LabelledDataflowGraphView &operator=(LabelledDataflowGraphView const &) = default; + NodeLabel const &at(Node const &n) const { return this->get_interface().at(n); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h new file mode 100644 index 0000000000..4326d06283 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_I_LABELLED_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_I_LABELLED_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/node_added_result.dtg.h" +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct ILabelledOpenDataflowGraph : virtual public ILabelledOpenDataflowGraphView + , virtual public ILabelledDataflowGraph { + virtual NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) = 0; + + NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) override final { + return this->add_node(node_label, transform(inputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }), output_labels); + } + + virtual ~ILabelledOpenDataflowGraph() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledOpenDataflowGraph); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h new file mode 100644 index 0000000000..5c691fe225 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_I_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_I_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +template +struct ILabelledOpenDataflowGraphView : virtual public ILabelledDataflowGraphView, + virtual public IOpenDataflowGraphView { +public: + virtual ValueLabel const &at(OpenDataflowValue const &) const = 0; + + ValueLabel const &at(DataflowOutput const &o) const override final { + return this->at(OpenDataflowValue{o}); + } + + virtual ~ILabelledOpenDataflowGraphView() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledOpenDataflowGraphView); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h new file mode 100644 index 0000000000..c41749e333 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h" + +namespace FlexFlow { + +template +struct LabelledOpenDataflowGraph : virtual public LabelledOpenDataflowGraphView { +private: + using Interface = ILabelledOpenDataflowGraph; + +public: + LabelledOpenDataflowGraph(LabelledOpenDataflowGraph const &) = default; + LabelledOpenDataflowGraph &operator=(LabelledOpenDataflowGraph const &) = default; + + NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) { + return this->get_interface().add_node(node_label, inputs, output_labels); + } +protected: + using LabelledOpenDataflowGraphView::LabelledOpenDataflowGraphView; +private: + Interface &get_interface() { + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); + } + + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h new file mode 100644 index 0000000000..e69b5958be --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledOpenDataflowGraphView : virtual public LabelledDataflowGraphView, + virtual public OpenDataflowGraphView { +private: + using Interface = ILabelledOpenDataflowGraphView; +public: + LabelledOpenDataflowGraphView(LabelledOpenDataflowGraphView const &) = default; + LabelledOpenDataflowGraphView &operator=(LabelledOpenDataflowGraphView const &) = default; + + ValueLabel const &at(OpenDataflowValue const &v) const { + return this->get_interface().at(v); + } + + template + static typename std::enable_if::value, + LabelledOpenDataflowGraphView>::type + create(Args &&... args) { + return LabelledOpenDataflowGraphView(make_cow_ptr(std::forward(args)...)); + } +protected: + using LabelledDataflowGraphView::LabelledDataflowGraphView; +private: + Interface const &get_interface() const { + return *std::dynamic_pointer_cast(GraphView::ptr.get()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/algorithms.h b/lib/utils/include/utils/graph/node/algorithms.h index fbc18d562f..5637c622b0 100644 --- a/lib/utils/include/utils/graph/node/algorithms.h +++ b/lib/utils/include/utils/graph/node/algorithms.h @@ -6,6 +6,8 @@ namespace FlexFlow { std::unordered_set get_nodes(GraphView const &); +size_t num_nodes(GraphView const &); +bool empty(GraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h index e456420094..44533452bc 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { std::unordered_set get_edges(OpenDataflowGraphView const &); std::vector get_inputs(OpenDataflowGraphView const &); -std::vector get_incoming_edges(OpenDataflowGraphView const &); +std::vector get_inputs(OpenDataflowGraphView const &, Node const &); +std::vector get_incoming_edges(OpenDataflowGraphView const &, Node const &); +std::unordered_set get_open_dataflow_values(OpenDataflowGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h index 0a5c45013d..0244698041 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_INPUT_EDGE_QUERY_H +#include "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h" #include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.dtg.h" namespace FlexFlow { DataflowInputEdgeQuery dataflow_input_edge_query_all(); DataflowInputEdgeQuery dataflow_input_edge_query_none(); +bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &, DataflowInputEdge const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h index e4f9704ce4..0ffb067223 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h @@ -2,10 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_H #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &); +OpenDataflowValue get_open_dataflow_edge_source(OpenDataflowEdge const &); +OpenDataflowEdge open_dataflow_edge_from_src_and_dst(OpenDataflowValue const &src, DataflowInput const &dst); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h index d6b13ab2a0..aa5d82b2f2 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_EDGE_QUERY_H +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h" namespace FlexFlow { OpenDataflowEdgeQuery open_dataflow_edge_query_all(); OpenDataflowEdgeQuery open_dataflow_edge_query_none(); +bool open_dataflow_edge_query_includes(OpenDataflowEdgeQuery const &q, OpenDataflowEdge const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/views/views.h b/lib/utils/include/utils/graph/views/views.h index f251912103..1dba4cba44 100644 --- a/lib/utils/include/utils/graph/views/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_VIEWS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_VIEWS_VIEWS_H -#include "utils/graph/digraph/adjacency_digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/digraph/digraph_view.h" #include "utils/graph/undirected/undirected_graph_view.h" // #include "utils/graph/multidigraph/multidigraph_view.h" diff --git a/lib/utils/src/utils/containers/group_by.cc b/lib/utils/src/utils/containers/group_by.cc new file mode 100644 index 0000000000..ac05cee861 --- /dev/null +++ b/lib/utils/src/utils/containers/group_by.cc @@ -0,0 +1 @@ +#include "utils/containers/group_by.h" diff --git a/lib/utils/src/utils/containers/set_minus.cc b/lib/utils/src/utils/containers/set_minus.cc new file mode 100644 index 0000000000..e5d5a1468e --- /dev/null +++ b/lib/utils/src/utils/containers/set_minus.cc @@ -0,0 +1 @@ +#include "utils/containers/set_minus.h" diff --git a/lib/utils/src/utils/containers/without_nullopts.cc b/lib/utils/src/utils/containers/without_nullopts.cc new file mode 100644 index 0000000000..5d85a7bc2f --- /dev/null +++ b/lib/utils/src/utils/containers/without_nullopts.cc @@ -0,0 +1,8 @@ +#include "utils/containers/without_nullopts.h" + +namespace FlexFlow { + +template std::unordered_set without_nullopts(std::unordered_set> const &); +template std::vector without_nullopts(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 0e45056ce3..73abeb7864 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -14,6 +14,7 @@ #include "utils/hash-utils.h" #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/undirected/undirected_edge_query.h" +#include "utils/graph/digraph/algorithms.h" namespace FlexFlow { @@ -145,14 +146,6 @@ void remove_node_if_unused(UndirectedGraph &g, Node const &n) { g.remove_node_unsafe(n); } -std::size_t num_nodes(GraphView const &g) { - return get_nodes(g).size(); -} - -bool empty(GraphView const &g) { - return num_nodes(g) == 0; -} - DiGraphView contract_node(DiGraphView const &g, Node const &from, Node const &into) { return DiGraphView::create(g, from, into); @@ -233,10 +226,6 @@ std::unordered_set get_endpoints(UndirectedEdge const &e) { // return g.query_edges(MultiDiEdgeQuery::all()); // } -std::unordered_set get_edges(DiGraphView const &g) { - return g.query_edges(directed_edge_query_all()); -} - std::unordered_set get_edges(UndirectedGraphView const &g) { return g.query_edges(undirected_edge_query_all()); } @@ -267,25 +256,12 @@ std::unordered_set get_node_edges(UndirectedGraphView const &g, // return get_incoming_edges(g, std::unordered_set{n}); // } -std::unordered_set get_incoming_edges(DiGraphView const &g, - Node const &n) { - return get_incoming_edges(g, std::unordered_set{n}); -} - // std::unordered_set // get_incoming_edges(MultiDiGraphView const &g, // std::unordered_set dsts) { // return g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(dsts)); // } -std::unordered_set - get_incoming_edges(DiGraphView const &g, - std::unordered_set const &dsts) { - NOT_IMPLEMENTED(); - // auto multidigraph_view = as_multidigraph(g); - // return to_directed_edges(get_incoming_edges(multidigraph_view, dsts)); -} - // std::unordered_set // get_outgoing_edges(MultiDiGraphView const &g, // std::unordered_set const &srcs) { @@ -297,19 +273,6 @@ std::unordered_set // return get_outgoing_edges(g, std::unordered_set{n}); // } -std::unordered_set - get_outgoing_edges(DiGraphView const &g, - std::unordered_set const &dsts) { - NOT_IMPLEMENTED(); - // auto multidigraph_view = as_multidigraph(g); - // return to_directed_edges(get_outgoing_edges(multidigraph_view, dsts)); -} - -std::unordered_set get_outgoing_edges(DiGraphView const &g, - Node const &n) { - return get_outgoing_edges(g, std::unordered_set{n}); -} - // std::unordered_map> // get_incoming_edges_by_idx(MultiDiGraphView const &g, Node const &n) { // std::unordered_set edges = get_incoming_edges(g, n); @@ -358,23 +321,6 @@ std::unordered_set get_outgoing_edges(DiGraphView const &g, // return narrow(g.query_edges(InputMultiDiEdgeQuery::all())); // } -std::unordered_map> - get_predecessors(DiGraphView const &g, - std::unordered_set const &nodes) { - std::unordered_map> predecessors; - for (Node const &n : nodes) { - predecessors[n]; - } - for (DirectedEdge const &e : get_incoming_edges(g, nodes)) { - predecessors.at(e.dst).insert(e.src); - } - return predecessors; -} - -std::unordered_set get_predecessors(DiGraphView const &g, Node const &n) { - return get_predecessors(g, std::unordered_set{n}).at(n); -} - std::vector get_unchecked_dfs_ordering( DiGraphView const &g, std::unordered_set const &starting_points) { UncheckedDFSView dfs_view = unchecked_dfs(g, starting_points); @@ -395,84 +341,10 @@ std::vector return {bfs_view.begin(), bfs_view.end()}; } -std::unordered_set get_sinks(DiGraphView const &g) { - return filter(get_nodes(g), [&](Node const &n) { - return get_outgoing_edges(g, n).size() == 0; - }); -} - -DiGraphView flipped(DiGraphView const &g) { - return DiGraphView::create(g); -} - -std::unordered_set get_sources(DiGraphView const &g) { - return filter(get_nodes(g), [&](Node const &n) { - return get_incoming_edges(g, n).size() == 0; - }); -} - -std::optional is_acyclic(DiGraphView const &g) { - if (num_nodes(g) == 0) { - return std::nullopt; - } - std::unordered_set sources = get_sources(g); - if (sources.size() == 0) { - return false; - } - auto dfs_view = unchecked_dfs(g, sources); - std::unordered_set seen; - for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); - it++) { - if (contains(seen, *it)) { - return false; - } else { - seen.insert(*it); - } - } - if (seen != get_nodes(g)) { - return false; - } - return true; -} - // std::optional is_acyclic(MultiDiGraph const &g) { // return is_acyclic(g); // } -std::vector get_unchecked_topological_ordering(DiGraphView const &g) { - auto dfs_view = unchecked_dfs(g, get_sources(g)); - std::vector order; - std::unordered_set seen; - std::unordered_map> predecessors = - get_predecessors(g, get_nodes(g)); - - auto all_predecessors_seen = [&](Node const &n) -> bool { - bool result = true; - for (Node const &pred : predecessors.at(n)) { - result &= contains(seen, pred); - } - return result; - }; - - unchecked_dfs_iterator it = dfs_view.cbegin(); - while (it != dfs_view.cend()) { - if (all_predecessors_seen(*it)) { - order.push_back(*it); - seen.insert(*it); - it++; - } else { - it.skip(); - } - } - - return order; -} - -std::vector get_topological_ordering(DiGraphView const &g) { - assert(is_acyclic(g)); - return get_unchecked_topological_ordering(g); -} - std::vector get_edge_topological_ordering(DiGraphView const &g) { std::vector result; for (Node const &n : get_topological_ordering(g)) { diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc index f7db16dfe2..e878b4deee 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc @@ -1,5 +1,6 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" namespace FlexFlow { @@ -29,7 +30,10 @@ std::vector get_outputs(DataflowGraphView const &g, Node const & [](DataflowOutput const &l, DataflowOutput const &r) { return l.idx < r.idx; }); +} +std::unordered_set get_all_dataflow_outputs(DataflowGraphView const &g) { + return g.query_outputs(dataflow_output_query_all()); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc index 03aaae8559..9840a7d15e 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc @@ -20,4 +20,11 @@ DataflowEdgeQuery dataflow_edge_query_none() { }; } +bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &q, DataflowEdge const &e) { + return includes(q.src_nodes, e.src.node) + && includes(q.dst_nodes, e.dst.node) + && includes(q.src_idxs, e.src.idx) + && includes(q.dst_idxs, e.dst.idx); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc index b739c7da68..9d75378281 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc @@ -16,4 +16,9 @@ DataflowOutputQuery dataflow_output_query_none() { }; } +bool dataflow_output_query_includes_dataflow_output(DataflowOutputQuery const &q, DataflowOutput const &o) { + return includes(q.nodes, o.node) + && includes(q.output_idxs, o.idx); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc new file mode 100644 index 0000000000..ea58982a7d --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -0,0 +1,130 @@ +#include "utils/graph/digraph/algorithms.h" +#include "utils/containers/group_by.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/traversal.h" +#include "utils/graph/views/views.h" + +namespace FlexFlow { + +std::unordered_set get_edges(DiGraphView const &g) { + return g.query_edges(directed_edge_query_all()); +} + +std::unordered_set get_incoming_edges(DiGraphView const &g, Node const &n) { + return g.query_edges(DirectedEdgeQuery{ + query_set::matchall(), + query_set{n}, + }); +} + +std::unordered_map> get_incoming_edges(DiGraphView const &g, std::unordered_set const &ns) { + return group_by(g.query_edges(DirectedEdgeQuery{ + query_set::matchall(), + query_set{ns}, + }), [](DirectedEdge const &e) { return e.dst; }); +} + +std::unordered_set get_outgoing_edges(DiGraphView const &g, Node const &n) { + return g.query_edges(DirectedEdgeQuery{ + query_set{n}, + query_set::matchall(), + }); +} + +std::unordered_map> get_outgoing_edges(DiGraphView const &g, std::unordered_set const &ns) { + return group_by(g.query_edges(DirectedEdgeQuery{ + query_set::matchall(), + query_set{ns}, + }), [](DirectedEdge const &e) { return e.src; }); +} + +std::unordered_set get_sources(DiGraphView const &g) { + std::unordered_set all_nodes = get_nodes(g); + std::unordered_set with_incoming_edge = transform(get_edges(g), + [](DirectedEdge const &e) { return e.dst; }); + + return set_minus(all_nodes, with_incoming_edge); +} + +std::unordered_set get_sinks(DiGraphView const &g) { + return get_sources(flipped(g)); +} + +static std::vector get_unchecked_topological_ordering(DiGraphView const &g) { + auto dfs_view = unchecked_dfs(g, get_sources(g)); + std::vector order; + std::unordered_set seen; + std::unordered_map> predecessors = + get_predecessors(g, get_nodes(g)); + + auto all_predecessors_seen = [&](Node const &n) -> bool { + bool result = true; + for (Node const &pred : predecessors.at(n)) { + result &= contains(seen, pred); + } + return result; + }; + + unchecked_dfs_iterator it = dfs_view.cbegin(); + while (it != dfs_view.cend()) { + if (all_predecessors_seen(*it)) { + order.push_back(*it); + seen.insert(*it); + it++; + } else { + it.skip(); + } + } + + return order; +} + + +std::vector get_topological_ordering(DiGraphView const &g) { + assert(is_acyclic(g)); + return get_unchecked_topological_ordering(g); +} + +std::optional is_acyclic(DiGraphView const &g) { + if (num_nodes(g) == 0) { + return std::nullopt; + } + std::unordered_set sources = get_sources(g); + if (sources.size() == 0) { + return false; + } + auto dfs_view = unchecked_dfs(g, sources); + std::unordered_set seen; + for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); + it++) { + if (contains(seen, *it)) { + return false; + } else { + seen.insert(*it); + } + } + if (seen != get_nodes(g)) { + return false; + } + return true; +} + +DiGraphView flipped(DiGraphView const &g) { + return DiGraphView::create(g); +} + +std::unordered_set get_predecessors(DiGraphView const &g, Node const &n) { + return get_predecessors(g, std::unordered_set{n}).at(n); +} + +std::unordered_map> + get_predecessors(DiGraphView const &g, std::unordered_set const &ns) { + return map_values(get_incoming_edges(g, ns), + [](std::unordered_set const &es) { + return transform(es, [](DirectedEdge const &e) { return e.src; }); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc similarity index 97% rename from lib/utils/src/utils/graph/digraph/adjacency_digraph.cc rename to lib/utils/src/utils/graph/instances/adjacency_digraph.cc index 4a54986832..758d7e299f 100644 --- a/lib/utils/src/utils/graph/digraph/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/digraph/adjacency_digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/undirected/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc similarity index 97% rename from lib/utils/src/utils/graph/undirected/hashmap_undirected_graph.cc rename to lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 78788a6454..a4af66e9fe 100644 --- a/lib/utils/src/utils/graph/undirected/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/undirected/hashmap_undirected_graph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" #include "utils/containers.h" #include "utils/exception.h" diff --git a/lib/utils/include/utils/graph/undirected/hashmap_undirected_graph.h b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h similarity index 100% rename from lib/utils/include/utils/graph/undirected/hashmap_undirected_graph.h rename to lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h diff --git a/lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc similarity index 97% rename from lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc rename to lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc index 6fd7177f4d..4fe36aecaf 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc @@ -1,4 +1,4 @@ -#include "utils/graph/dataflow_graph/unordered_set_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/containers/enumerate_vector.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.cc new file mode 100644 index 0000000000..7e39d9462e --- /dev/null +++ b/lib/utils/src/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.cc @@ -0,0 +1,7 @@ +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" + +namespace FlexFlow { + +template class UnorderedSetLabelledOpenDataflowGraph; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.cc new file mode 100644 index 0000000000..9278b7a14b --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.cc new file mode 100644 index 0000000000..9be315766e --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.cc new file mode 100644 index 0000000000..8044a82177 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.cc new file mode 100644 index 0000000000..132fd5c3ef --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/node/algorithms.cc b/lib/utils/src/utils/graph/node/algorithms.cc index 69fcdfa067..d92fa2a7ef 100644 --- a/lib/utils/src/utils/graph/node/algorithms.cc +++ b/lib/utils/src/utils/graph/node/algorithms.cc @@ -7,5 +7,12 @@ std::unordered_set get_nodes(GraphView const &g) { return g.query_nodes(node_query_all()); } +size_t num_nodes(GraphView const &g) { + return get_nodes(g).size(); +} + +bool empty(GraphView const &g) { + return num_nodes(g) == 0; +} } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc index 38dff4510e..55c958c58e 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc @@ -1,6 +1,7 @@ #include "utils/graph/open_dataflow_graph/algorithms.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -12,6 +13,11 @@ std::vector get_inputs(OpenDataflowGraphView const &g) { return g.get_inputs(); } +std::vector get_inputs(OpenDataflowGraphView const &g, Node const &n) { + return transform(get_incoming_edges(g, n), + [](OpenDataflowEdge const &e) { return get_open_dataflow_edge_source(e); }); +} + std::vector get_incoming_edges(OpenDataflowGraphView const &g, Node const &n) { return sorted_by(g.query_edges(OpenDataflowEdgeQuery{ DataflowInputEdgeQuery{ @@ -28,4 +34,13 @@ std::vector get_incoming_edges(OpenDataflowGraphView const &g, }), [](OpenDataflowEdge const &l, OpenDataflowEdge const &r) { return get_open_dataflow_edge_dst_idx(l) < get_open_dataflow_edge_dst_idx(r); }); } +std::unordered_set get_open_dataflow_values(OpenDataflowGraphView const &g) { + return set_union( + transform(without_order(g.get_inputs()), + [](DataflowGraphInput const &gi) { return OpenDataflowValue{gi}; }), + transform(get_all_dataflow_outputs(g), + [](DataflowOutput const &o) { return OpenDataflowValue{o}; }) + ); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc index c3c6711304..3ae28fa828 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc @@ -17,4 +17,10 @@ DataflowInputEdgeQuery dataflow_input_edge_query_none() { }; } +bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &q, DataflowInputEdge const &e) { + return includes(q.srcs, e.src) + && includes(q.dst_nodes, e.dst.node) + && includes(q.dst_idxs, e.dst.idx); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc index 632c77df2c..256d16ea90 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc @@ -10,4 +10,18 @@ int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &e) { }); } +OpenDataflowValue get_open_dataflow_edge_source_value(OpenDataflowEdge const &open_e) { + return open_e.visit(overload { + [](DataflowEdge const &e) { return OpenDataflowValue{e.src}; }, + [](DataflowInputEdge const &e) { return OpenDataflowValue{e.src}; }, + }); +} + +OpenDataflowEdge open_dataflow_edge_from_src_and_dst(OpenDataflowValue const &src, DataflowInput const &dst) { + return src.visit(overload { + [&](DataflowOutput const &o) { return OpenDataflowEdge{DataflowEdge{o, dst}}; }, + [&](DataflowGraphInput const &gi) { return OpenDataflowEdge{DataflowInputEdge{gi, dst}}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc index 9d72c8a009..70c5f131ac 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc @@ -1,6 +1,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" #include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/overload.h" namespace FlexFlow { @@ -18,4 +19,11 @@ OpenDataflowEdgeQuery open_dataflow_edge_query_none() { }; } +bool open_dataflow_edge_query_includes(OpenDataflowEdgeQuery const &q, OpenDataflowEdge const &open_e) { + return open_e.visit(overload { + [&](DataflowEdge const &e) { return dataflow_edge_query_includes_dataflow_edge(q.standard_edge_query, e); }, + [&](DataflowInputEdge const &e) { return dataflow_input_edge_query_includes(q.input_edge_query, e); }, + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc index 9fa4492bb6..7945ee6273 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc @@ -3,6 +3,7 @@ #include "utils/graph/serial_parallel/sink_settings.dtg.h" #include "utils/graph/serial_parallel/source_settings.dtg.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/digraph/algorithms.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/traversal.cc b/lib/utils/src/utils/graph/traversal.cc index aed0eb81ef..2e8e73d632 100644 --- a/lib/utils/src/utils/graph/traversal.cc +++ b/lib/utils/src/utils/graph/traversal.cc @@ -1,6 +1,7 @@ #include "utils/graph/traversal.h" #include "utils/containers.h" #include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -34,7 +35,7 @@ udi &udi::operator++() { Node const last = this->operator*(); this->stack.pop_back(); - std::unordered_set outgoing = get_outgoing_edges(graph, {last}); + std::unordered_set outgoing = get_outgoing_edges(graph, last); for (DirectedEdge const &e : outgoing) { auto it = std::find(stack.begin(), stack.end(), e.dst); if (it == stack.end()) { diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 0fb258bf15..bf8e135df6 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -1,6 +1,5 @@ #include "test/utils/doctest.h" -#include "utils/graph/adjacency_digraph.h" -#include "utils/graph/adjacency_multidigraph.h" +#include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/algorithms.h" #include "utils/graph/construction.h" #include "utils/graph/hashmap_undirected_graph.h" diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc new file mode 100644 index 0000000000..6312c33111 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc @@ -0,0 +1,59 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/digraph/algorithms.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_inputs/get_outputs") { + DataflowGraph g; + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = + g.add_node({o1, o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + SUBCASE("get_inputs") { + std::vector result = get_inputs(g, n4); + std::vector correct = {o1, o2, o3}; + CHECK(result == correct); + } + + SUBCASE("get_outputs") { + std::vector result = get_outputs(g, n4); + std::vector correct = {o4}; + CHECK(result == correct); + } + } + + TEST_CASE("topological_ordering") { + DataflowGraph g; + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + std::vector result = get_topological_ordering(g); + std::vector correct = { n1, n2, n3 }; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc similarity index 97% rename from lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc rename to lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc index a82e3afbb5..d3db57a369 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc @@ -1,5 +1,5 @@ #include "test/utils/doctest.h" -#include "utils/graph/dataflow_graph/unordered_set_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" #include "utils/graph/node/node_query.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" From 7d4c7be92ec0cc5e735c638a9869780d2df33fa8 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 24 Jun 2024 22:38:56 -0700 Subject: [PATCH 12/71] Add utility functions to support pattern matching --- .../unlabelled/pattern_matching.h | 6 + .../unlabelled/pattern_value_use.dtg.h | 46 +++++++ ...abelled_dataflow_graph_pattern_match.dtg.h | 9 +- ...d_dataflow_graph_pattern_match.struct.toml | 6 +- .../unlabelled/pattern_matching.cc | 119 ++++++++++-------- .../unlabelled/pattern_value_use.dtg.cc | 65 ++++++++++ ...belled_dataflow_graph_pattern_match.dtg.cc | 19 +-- .../algorithms/rewrite_labels.h | 24 ++++ .../algorithms/with_labelling.h | 74 +++++++++++ .../graph/open_dataflow_graph/algorithms.h | 1 + .../algorithms/get_subgraph.h | 14 +++ .../algorithms/get_subgraph_inputs.h | 14 +++ lib/utils/include/utils/graph/rewriting.h | 34 ----- .../algorithms/rewrite_labels.cc | 14 +++ .../algorithms/with_labelling.cc | 1 + .../graph/open_dataflow_graph/algorithms.cc | 4 + .../algorithms/get_subgraph.cc | 95 ++++++++++++++ .../algorithms/get_subgraph_inputs.cc | 18 +++ 18 files changed, 450 insertions(+), 113 deletions(-) create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.h create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_value_use.dtg.cc create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h delete mode 100644 lib/utils/include/utils/graph/rewriting.h create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index baea34afeb..2af7bbf138 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -10,6 +10,12 @@ namespace FlexFlow { +// OpenDataflowGraphView apply_match(UnlabelledGraphPattern const &pattern, +// UnlabelledDataflowGraphPatternMatch const &match); + +OpenDataflowGraphView subgraph_matched(UnlabelledGraphPattern const &pattern, + UnlabelledDataflowGraphPatternMatch const &match); + bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &graph, diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.h new file mode 100644 index 0000000000..a2ca7d8a41 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value_use.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml +/* proj-data +{ + "generated_from": "7a2a514e11987e06022337d234fb32c8" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_USE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_USE_DTG_H + +#include "fmt/format.h" +#include "utils/graph/dataflow_graph/dataflow_input.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct PatternValueUse { + PatternValueUse() = delete; + explicit PatternValueUse(::FlexFlow::DataflowInput const &raw_dataflow_input); + + bool operator==(PatternValueUse const &) const; + bool operator!=(PatternValueUse const &) const; + bool operator<(PatternValueUse const &) const; + bool operator>(PatternValueUse const &) const; + bool operator<=(PatternValueUse const &) const; + bool operator>=(PatternValueUse const &) const; + ::FlexFlow::DataflowInput raw_dataflow_input; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::PatternValueUse> { + size_t operator()(::FlexFlow::PatternValueUse const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(PatternValueUse const &); +std::ostream &operator<<(std::ostream &, PatternValueUse const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_USE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h index 2209d9da09..e0ac05c236 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml /* proj-data { - "generated_from": "8e2550c2e4cd04bb1458f9e3f4ac05ba" + "generated_from": "a640c8f9530a44d78c1bce32d801360d" } */ @@ -25,16 +25,11 @@ struct UnlabelledDataflowGraphPatternMatch { UnlabelledDataflowGraphPatternMatch() = delete; explicit UnlabelledDataflowGraphPatternMatch( ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const - &node_assignment, - ::FlexFlow::bidict<::FlexFlow::PatternValue, - ::FlexFlow::OpenDataflowValue> const - &value_assignment); + &node_assignment); bool operator==(UnlabelledDataflowGraphPatternMatch const &) const; bool operator!=(UnlabelledDataflowGraphPatternMatch const &) const; ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> node_assignment; - ::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::OpenDataflowValue> - value_assignment; }; } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml index f6c8c475cb..af28609478 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml @@ -19,6 +19,6 @@ includes = [ name = "node_assignment" type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" -[[fields]] -name = "value_assignment" -type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::OpenDataflowValue>" +# [[fields]] +# name = "value_assignment" +# type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::OpenDataflowValue>" diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index d0b6956222..65e278686e 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -2,72 +2,81 @@ #include "substitutions/unlabelled/match_split.h" #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include namespace FlexFlow { +OpenDataflowGraphView subgraph_matched(OpenDataflowGraphView const &g, + UnlabelledDataflowGraphPatternMatch const &match) { + std::unordered_set matched_nodes = keys(match.node_assignment.reversed()); + std::vector subgraph_inputs = sorted(get_subgraph_inputs(g, matched_nodes)); + return get_subgraph(g, matched_nodes, subgraph_inputs); +} + bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &graph, UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { - if (is_singleton_pattern(pattern)) { - PatternNode pattern_node = get_only(get_nodes(pattern)); - Node matched_node = match.node_assignment.at_l(pattern_node); - if (!additional_criterion.node_criterion(pattern_node, matched_node)) { - return false; - } - for (PatternValue const &pattern_value : get_values(pattern)) { - OpenDataflowValue matched_value = match.value_assignment.at_l(v); - - - - assert(is_input_edge(e) || is_output_edge(e)); - if (is_input_edge(e)) { - if (is_output_edge(matched_edge)) { - return false; - } - UpwardOpenMultiDiEdge matched_edge = - narrow(matched_edge).value(); - InputPatternEdge input_edge = require_input_edge(e); - if (match.node_assignment.at_l(get_dst_node(input_edge)) != - get_dst_node(matched_edge)) { - return false; - } - } else { - if (is_input_edge(matched_edge)) { - return false; - } - DownwardOpenMultiDiEdge matched_edge = - narrow(matched_edge).value(); - OutputPatternEdge output_edge = require_output_edge(e); - if (match.node_assignment.at_l(get_src_node(output_edge)) != - get_src_node(matched_edge)) { - return false; - } - } - - if (!additional_criterion.value_criterion(pattern_value, matched_value)) { - return false; - } - } - - return true; - } - PatternSplit split = find_even_split(pattern); - std::pair subpatterns = - apply_split(pattern, split); - auto submatches = apply_split(pattern, match, split); - return unlabelled_pattern_does_match(subpatterns.first, - graph, - submatches.prefix_submatch, - additional_criterion) && - unlabelled_pattern_does_match(subpatterns.second, - graph, - submatches.postfix_submatch, - additional_criterion); + // PatternNode pattern_node = get_only(get_nodes(pattern)); + // Node matched_node = match.node_assignment.at_l(pattern_node); + // if (!additional_criterion.node_criterion(pattern_node, matched_node)) { + // return false; + // } + // + // for (PatternValue const &pattern_value : get_values(pattern)) { + // OpenDataflowValue matched_value = match.value_assignment.at_l(v); + // + // assert(is_input_edge(e) || is_output_edge(e)); + // if (is_input_edge(e)) { + // if (is_output_edge(matched_edge)) { + // return false; + // } + // UpwardOpenMultiDiEdge matched_edge = + // narrow(matched_edge).value(); + // InputPatternEdge input_edge = require_input_edge(e); + // if (match.node_assignment.at_l(get_dst_node(input_edge)) != + // get_dst_node(matched_edge)) { + // return false; + // } + // } else { + // if (is_input_edge(matched_edge)) { + // return false; + // } + // DownwardOpenMultiDiEdge matched_edge = + // narrow(matched_edge).value(); + // OutputPatternEdge output_edge = require_output_edge(e); + // if (match.node_assignment.at_l(get_src_node(output_edge)) != + // get_src_node(matched_edge)) { + // return false; + // } + // } + // + // if (!additional_criterion.value_criterion(pattern_value, matched_value)) { + // return false; + // } + // } + // + // return true; + // } + // + // PatternSplit split = find_even_split(pattern); + // std::pair subpatterns = + // apply_split(pattern, split); + // auto submatches = apply_split(pattern, match, split); + // + // return unlabelled_pattern_does_match(subpatterns.first, + // graph, + // submatches.prefix_submatch, + // additional_criterion) && + // unlabelled_pattern_does_match(subpatterns.second, + // graph, + // submatches.postfix_submatch, + // additional_criterion); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_value_use.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_value_use.dtg.cc new file mode 100644 index 0000000000..68714d99a0 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_value_use.dtg.cc @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_value_use.struct.toml +/* proj-data +{ + "generated_from": "7a2a514e11987e06022337d234fb32c8" +} +*/ + +#include "substitutions/unlabelled/pattern_value_use.dtg.h" + +#include + +namespace FlexFlow { +PatternValueUse::PatternValueUse( + ::FlexFlow::DataflowInput const &raw_dataflow_input) + : raw_dataflow_input(raw_dataflow_input) {} +bool PatternValueUse::operator==(PatternValueUse const &other) const { + return std::tie(this->raw_dataflow_input) == + std::tie(other.raw_dataflow_input); +} +bool PatternValueUse::operator!=(PatternValueUse const &other) const { + return std::tie(this->raw_dataflow_input) != + std::tie(other.raw_dataflow_input); +} +bool PatternValueUse::operator<(PatternValueUse const &other) const { + return std::tie(this->raw_dataflow_input) < + std::tie(other.raw_dataflow_input); +} +bool PatternValueUse::operator>(PatternValueUse const &other) const { + return std::tie(this->raw_dataflow_input) > + std::tie(other.raw_dataflow_input); +} +bool PatternValueUse::operator<=(PatternValueUse const &other) const { + return std::tie(this->raw_dataflow_input) <= + std::tie(other.raw_dataflow_input); +} +bool PatternValueUse::operator>=(PatternValueUse const &other) const { + return std::tie(this->raw_dataflow_input) >= + std::tie(other.raw_dataflow_input); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::PatternValueUse const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataflowInput>{}(x.raw_dataflow_input) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(PatternValueUse const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, PatternValueUse const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc index 3d5598539b..63504cd6c0 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.cc @@ -3,7 +3,7 @@ // lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml /* proj-data { - "generated_from": "8e2550c2e4cd04bb1458f9e3f4ac05ba" + "generated_from": "a640c8f9530a44d78c1bce32d801360d" } */ @@ -14,19 +14,15 @@ namespace FlexFlow { UnlabelledDataflowGraphPatternMatch::UnlabelledDataflowGraphPatternMatch( ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const - &node_assignment, - ::FlexFlow::bidict<::FlexFlow::PatternValue, - ::FlexFlow::OpenDataflowValue> const &value_assignment) - : node_assignment(node_assignment), value_assignment(value_assignment) {} + &node_assignment) + : node_assignment(node_assignment) {} bool UnlabelledDataflowGraphPatternMatch::operator==( UnlabelledDataflowGraphPatternMatch const &other) const { - return std::tie(this->node_assignment, this->value_assignment) == - std::tie(other.node_assignment, other.value_assignment); + return std::tie(this->node_assignment) == std::tie(other.node_assignment); } bool UnlabelledDataflowGraphPatternMatch::operator!=( UnlabelledDataflowGraphPatternMatch const &other) const { - return std::tie(this->node_assignment, this->value_assignment) != - std::tie(other.node_assignment, other.value_assignment); + return std::tie(this->node_assignment) != std::tie(other.node_assignment); } } // namespace FlexFlow @@ -39,10 +35,6 @@ size_t hash::operator()( ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>>{}( x.node_assignment) + 0x9e3779b9 + (result << 6) + (result >> 2); - result ^= std::hash<::FlexFlow::bidict<::FlexFlow::PatternValue, - ::FlexFlow::OpenDataflowValue>>{}( - x.value_assignment) + - 0x9e3779b9 + (result << 6) + (result >> 2); return result; } } // namespace std @@ -52,7 +44,6 @@ std::string format_as(UnlabelledDataflowGraphPatternMatch const &x) { std::ostringstream oss; oss << ""; return oss.str(); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h new file mode 100644 index 0000000000..969f0701cf --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" + +namespace FlexFlow { + +template , + typename NewValueLabel = std::invoke_result_t> +LabelledOpenDataflowGraphView + rewrite_labels(LabelledOpenDataflowGraphView const &g, F f) { + std::unordered_map node_labels = generate_map(get_nodes(g), f); + std::unordered_map value_labels = generate_map(get_nodes(g), f); + return with_labelling(g, node_labels, value_labels); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h new file mode 100644 index 0000000000..dd55b7ddf2 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h @@ -0,0 +1,74 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_WITH_LABELLING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_WITH_LABELLING_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +template +struct OpenDataflowGraphLabellingWrapper final : public ILabelledOpenDataflowGraphView { +public: + OpenDataflowGraphLabellingWrapper() = delete; + OpenDataflowGraphLabellingWrapper(OpenDataflowGraphView const &unlabelled, + std::unordered_map const &node_labels, + std::unordered_map const &value_labels) + : unlabelled(unlabelled), + node_labels(node_labels), + value_labels(value_labels) + { } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->unlabelled.query_nodes(q); + } + + std::unordered_set query_edges(OpenDataflowEdgeQuery const &q) const override { + return this->unlabelled.query_edges(q); + } + + std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { + return this->unlabelled.query_outputs(q); + } + + std::vector get_inputs() const override { + return this->unlabelled.get_inputs(); + } + + NodeLabel const &at(Node const &n) const override { + return this->node_labels.at(n); + } + + ValueLabel const &at(OpenDataflowValue const &v) const override { + return this->value_labels.at(v); + } + + OpenDataflowGraphLabellingWrapper *clone() const override { + return new OpenDataflowGraphLabellingWrapper{ + this->unlabelled, + this->node_labels, + this->value_labels, + }; + } + +private: + OpenDataflowGraphView unlabelled; + std::unordered_map node_labels; + std::unordered_map value_labels; +}; + +template +LabelledOpenDataflowGraphView + with_labelling(OpenDataflowGraphView const &g, + std::unordered_map const &node_labels, + std::unordered_map const &value_labels) { + return LabelledOpenDataflowGraphView::template create< + OpenDataflowGraphLabellingWrapper>( + g, + node_labels, + value_labels + ); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h index 44533452bc..a8f6c32490 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h @@ -10,6 +10,7 @@ std::unordered_set get_edges(OpenDataflowGraphView const &); std::vector get_inputs(OpenDataflowGraphView const &); std::vector get_inputs(OpenDataflowGraphView const &, Node const &); std::vector get_incoming_edges(OpenDataflowGraphView const &, Node const &); +std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &, std::unordered_set const &); std::unordered_set get_open_dataflow_values(OpenDataflowGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h new file mode 100644 index 0000000000..835bea4731 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +namespace FlexFlow { + +OpenDataflowGraphView get_subgraph(OpenDataflowGraphView const &, + std::unordered_set const &, + std::vector const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h new file mode 100644 index 0000000000..4a70492766 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_subgraph_inputs(OpenDataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/rewriting.h b/lib/utils/include/utils/graph/rewriting.h deleted file mode 100644 index f411d6c5ea..0000000000 --- a/lib/utils/include/utils/graph/rewriting.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_REWRITING_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_REWRITING_H - -#include "labelled_graphs.h" - -namespace FlexFlow { - -template ()(std::declval()))> -NodeLabelledMultiDiGraph rewrite(NodeLabelledMultiDiGraph const &, - F const &f); - -template ()(std::declval())), - typename OE = decltype(std::declval()(std::declval()))> -LabelledMultiDiGraph rewrite(LabelledMultiDiGraph const &, - F const &f); - -template ()(std::declval(), - std::declval())), - typename OE = decltype(std::declval()( - std::declval(), std::declval()))> -OutputLabelledMultiDiGraph - rewrite(OutputLabelledMultiDiGraph const &, F const &f); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc new file mode 100644 index 0000000000..2e9c4ac6e4 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc @@ -0,0 +1,14 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h" + +namespace FlexFlow { + +// TODO(@lockshaw) eventually move this over to tests + +struct Visitor { + std::string operator()(Node const &, int); + float operator()(OpenDataflowValue const &, int); +}; + +template LabelledOpenDataflowGraphView rewrite_labels(LabelledOpenDataflowGraphView const &, Visitor); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.cc new file mode 100644 index 0000000000..a13ffa36f2 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc index 55c958c58e..23fea301f6 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc @@ -34,6 +34,10 @@ std::vector get_incoming_edges(OpenDataflowGraphView const &g, }), [](OpenDataflowEdge const &l, OpenDataflowEdge const &r) { return get_open_dataflow_edge_dst_idx(l) < get_open_dataflow_edge_dst_idx(r); }); } +std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &, std::unordered_set const &) { + NOT_IMPLEMENTED(); +} + std::unordered_set get_open_dataflow_values(OpenDataflowGraphView const &g) { return set_union( transform(without_order(g.get_inputs()), diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc new file mode 100644 index 0000000000..7a1a9a0c89 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -0,0 +1,95 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/containers/enumerate_vector.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include "utils/overload.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { + OpenDataflowSubgraph(OpenDataflowGraphView const &full_graph, + std::unordered_set const &subgraph_nodes, + bidict const &full_graph_values_to_subgraph_inputs, + std::vector const &subgraph_inputs) + : full_graph(full_graph), + subgraph_nodes(subgraph_nodes), + full_graph_values_to_subgraph_inputs(full_graph_values_to_subgraph_inputs), + subgraph_inputs(subgraph_inputs) + { + assert(is_subseteq_of(this->subgraph_nodes, get_nodes(full_graph))); + assert(without_order(subgraph_inputs) == without_order(values(full_graph_values_to_subgraph_inputs))); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return intersection(this->full_graph.query_nodes(q), this->subgraph_nodes); + } + + std::unordered_set query_edges(OpenDataflowEdgeQuery const &q) const override { + std::unordered_set result; + for (OpenDataflowEdge const &open_e : this->full_graph.query_edges(q)) { + open_e.visit(overload { + [&](DataflowEdge const &e) { + bool contains_src = contains(this->subgraph_nodes, e.src.node); + bool contains_dst = contains(this->subgraph_nodes, e.dst.node); + if (contains_src && contains_dst) { + result.insert(OpenDataflowEdge{e}); + } else if (contains_dst && !contains_src) { + result.insert(OpenDataflowEdge{DataflowInputEdge{this->full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{e.src}), e.dst}}); + } + }, + [&](DataflowInputEdge const &e) { + if (contains(this->subgraph_nodes, e.dst.node)) { + result.insert(OpenDataflowEdge{DataflowInputEdge{this->full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{e.src}), e.dst}}); + } + } + }); + } + return result; + } + + std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { + return filter(this->full_graph.query_outputs(q), + [&](DataflowOutput const &o) { + return contains(this->subgraph_nodes, o.node); + }); + } + + std::vector get_inputs() const override { + return this->subgraph_inputs; + }; + + OpenDataflowSubgraph *clone() const override { + return new OpenDataflowSubgraph{ + this->full_graph, + this->subgraph_nodes, + this->full_graph_values_to_subgraph_inputs, + this->subgraph_inputs, + }; + } +private: + OpenDataflowGraphView full_graph; + std::unordered_set subgraph_nodes; + bidict full_graph_values_to_subgraph_inputs; + std::vector subgraph_inputs; +}; + + +OpenDataflowGraphView get_subgraph(OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes, + std::vector const &input_ordering) { + std::vector subgraph_inputs; + bidict full_graph_values_to_subgraph_inputs; + for (auto const &[idx, full_graph_value] : enumerate_vector(input_ordering)) { + DataflowGraphInput subgraph_input = DataflowGraphInput{idx}; + subgraph_inputs.push_back(subgraph_input); + full_graph_values_to_subgraph_inputs.equate({full_graph_value, subgraph_input}); + } + + return OpenDataflowGraphView::create( + g, + subgraph_nodes, + subgraph_inputs, + input_ordering); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc new file mode 100644 index 0000000000..c0d537925a --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -0,0 +1,18 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::unordered_set get_subgraph_inputs(OpenDataflowGraphView const &g, + std::unordered_set const &subgraph_nodes) { + std::unordered_set relevant_edges; + for (std::vector const &incoming : values(get_incoming_edges(g, subgraph_nodes))) { + extend(relevant_edges, incoming); + } + + return transform(relevant_edges, get_open_dataflow_edge_source); +} + +} // namespace FlexFlow From 9ab9eb229eea005da36d2724e89cb7ab554059bf Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 26 Jun 2024 13:34:04 -0700 Subject: [PATCH 13/71] Pre-refactor inputs --- .../unlabelled/pattern_matching.cc | 5 ++- .../algorithms/get_subgraph.h | 4 ++- .../open_dataflow_subgraph_result.dtg.h | 34 +++++++++++++++++++ .../open_dataflow_subgraph_result.struct.toml | 18 ++++++++++ .../algorithms/get_subgraph.cc | 15 ++++---- .../open_dataflow_subgraph_result.dtg.cc | 20 +++++++++++ 6 files changed, 88 insertions(+), 8 deletions(-) create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.cc diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 65e278686e..15af188d5a 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -8,7 +8,10 @@ namespace FlexFlow { -OpenDataflowGraphView subgraph_matched(OpenDataflowGraphView const &g, +std::pair< + OpenDataflowGraphView, + bidict +> subgraph_matched(OpenDataflowGraphView const &g, UnlabelledDataflowGraphPatternMatch const &match) { std::unordered_set matched_nodes = keys(match.node_assignment.reversed()); std::vector subgraph_inputs = sorted(get_subgraph_inputs(g, matched_nodes)); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h index 835bea4731..0c89906a3f 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h @@ -3,9 +3,11 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h" + namespace FlexFlow { -OpenDataflowGraphView get_subgraph(OpenDataflowGraphView const &, +OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &, std::unordered_set const &, std::vector const &); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h new file mode 100644 index 0000000000..125e81af83 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h @@ -0,0 +1,34 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml +/* proj-data +{ + "generated_from": "f04bbd620df1cdf83703ff43b9bd6c40" +} +*/ + +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_SUBGRAPH_RESULT_DTG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_SUBGRAPH_RESULT_DTG_H + +#include "utils/bidict.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { +struct OpenDataflowSubgraphResult { + OpenDataflowSubgraphResult() = delete; + explicit OpenDataflowSubgraphResult( + ::FlexFlow::OpenDataflowGraphView const &graph, + ::FlexFlow::bidict<::FlexFlow::OpenDataflowValue, + ::FlexFlow::DataflowGraphInput> const + &full_graph_values_to_subgraph_inputs); + + ::FlexFlow::OpenDataflowGraphView graph; + ::FlexFlow::bidict<::FlexFlow::OpenDataflowValue, + ::FlexFlow::DataflowGraphInput> + full_graph_values_to_subgraph_inputs; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_SUBGRAPH_RESULT_DTG_H diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml new file mode 100644 index 0000000000..913f54f147 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OpenDataflowSubgraphResult" +features = [] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h", + "utils/bidict.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "graph" +type = "::FlexFlow::OpenDataflowGraphView" + +[[fields]] +name = "full_graph_values_to_subgraph_inputs" +type = "::FlexFlow::bidict<::FlexFlow::OpenDataflowValue, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc index 7a1a9a0c89..c582117166 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -74,7 +74,7 @@ struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { }; -OpenDataflowGraphView get_subgraph(OpenDataflowGraphView const &g, +OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &g, std::unordered_set const &subgraph_nodes, std::vector const &input_ordering) { std::vector subgraph_inputs; @@ -85,11 +85,14 @@ OpenDataflowGraphView get_subgraph(OpenDataflowGraphView const &g, full_graph_values_to_subgraph_inputs.equate({full_graph_value, subgraph_input}); } - return OpenDataflowGraphView::create( - g, - subgraph_nodes, - subgraph_inputs, - input_ordering); + return OpenDataflowSubgraphResult{ + OpenDataflowGraphView::create( + g, + subgraph_nodes, + subgraph_inputs, + input_ordering), + full_graph_values_to_subgraph_inputs, + }; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.cc new file mode 100644 index 0000000000..0496e4270a --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.cc @@ -0,0 +1,20 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.struct.toml +/* proj-data +{ + "generated_from": "f04bbd620df1cdf83703ff43b9bd6c40" +} +*/ + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h" + +namespace FlexFlow { +OpenDataflowSubgraphResult::OpenDataflowSubgraphResult( + ::FlexFlow::OpenDataflowGraphView const &graph, + ::FlexFlow::bidict<::FlexFlow::OpenDataflowValue, + ::FlexFlow::DataflowGraphInput> const + &full_graph_values_to_subgraph_inputs) + : graph(graph), full_graph_values_to_subgraph_inputs( + full_graph_values_to_subgraph_inputs) {} +} // namespace FlexFlow From f9b129e7b5a713522386a3470ab815e64bfe1257 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 26 Jun 2024 13:47:08 -0700 Subject: [PATCH 14/71] Fix proj url --- flake.nix | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index ae973038ff..3599304ed0 100644 --- a/flake.nix +++ b/flake.nix @@ -18,8 +18,7 @@ flake-utils.url = "github:numtide/flake-utils"; proj-repo = { - # url = "github:lockshaw/proj"; - url = "git+file:///home/lockshaw/x/proj/proj"; + url = "github:lockshaw/proj"; inputs.nixpkgs.follows = "nixpkgs"; inputs.flake-utils.follows = "flake-utils"; }; From cf73f08e930f1a4bae7a544f746c62e2c85f1c7e Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 6 Jul 2024 23:37:32 -0700 Subject: [PATCH 15/71] Get back to substitutions, now with unordered graph inputs --- lib/op-attrs/test/src/ops/conv_2d.cc | 2 + lib/pcg/include/pcg/create_grad.h | 8 +- .../include/pcg/dataflow_graph/algorithms.h | 37 --- .../operator_added_result.struct.toml | 22 -- lib/pcg/src/pcg/computation_graph_builder.cc | 4 +- lib/pcg/src/pcg/create_grad.cc | 17 ++ lib/pcg/src/pcg/dataflow_graph/algorithms.cc | 1 - .../test/src/pcg/dataflow_graph/algorithms.cc | 76 ------- .../parallel_computation_graph.cc | 4 +- ...tput_operator_attrs_assignment.struct.toml | 1 + .../include/substitutions/substitution.h | 4 +- .../tensor_pattern/eval_list_access.h | 2 +- .../tensor_pattern/eval_list_size.h | 2 +- .../tensor_pattern/get_attribute.h | 2 +- .../tensor_pattern/satisfies_constraint.h | 2 +- .../tensor_pattern/tensor_attribute_expr.h | 2 +- .../tensor_attribute_pattern.struct.toml | 3 +- .../closed_pattern_edge.struct.toml | 15 -- .../unlabelled/downward_open_pattern_edge.h | 12 - .../downward_open_pattern_edge.struct.toml | 15 -- .../substitutions/unlabelled/edge_splits.h | 20 +- .../unlabelled/find_pattern_matches.h | 6 +- .../unlabelled/input_pattern_edge.h | 3 + .../unlabelled/input_pattern_edge.struct.toml | 4 +- .../unlabelled/match_additional_criterion.h | 12 + .../unlabelled/multidigraph_pattern_match.h | 14 +- .../substitutions/unlabelled/pattern_edge.h | 2 + .../unlabelled/pattern_edge.variant.toml | 19 ++ .../unlabelled/pattern_input.struct.toml | 16 ++ .../unlabelled/pattern_node_output.h | 13 ++ .../pattern_node_output.struct.toml | 16 ++ .../substitutions/unlabelled/pattern_split.h | 3 +- .../pattern_split_result.struct.toml | 22 ++ .../substitutions/unlabelled/pattern_value.h | 13 ++ .../unlabelled/pattern_value.struct.toml | 16 -- .../unlabelled/pattern_value.variant.toml | 19 ++ .../unlabelled/standard_pattern_edge.h | 16 ++ .../standard_pattern_edge.struct.toml | 16 ++ .../unlabelled_dataflow_graph_pattern_match.h | 17 ++ ...d_dataflow_graph_pattern_match.struct.toml | 12 +- .../unlabelled/unlabelled_graph_pattern.h | 8 +- .../unlabelled/upward_open_pattern_edge.h | 12 - .../upward_open_pattern_edge.struct.toml | 15 -- .../src/substitutions/pcg_pattern.cc | 13 +- .../sub_parallel_computation_graph.cc | 2 +- .../src/substitutions/substitution.cc | 2 +- .../unlabelled/downward_open_pattern_edge.cc | 9 - .../substitutions/unlabelled/edge_splits.cc | 58 ++--- .../unlabelled/find_pattern_matches.cc | 162 +++++-------- .../unlabelled/input_pattern_edge.cc | 7 +- .../unlabelled/match_additional_criterion.cc | 12 + .../unlabelled/multidigraph_pattern_match.cc | 98 ++++---- .../unlabelled/pattern_matching.cc | 214 +++++++++++++----- .../unlabelled/pattern_node_output.cc | 13 ++ .../substitutions/unlabelled/pattern_value.cc | 13 ++ .../unlabelled/standard_pattern_edge.cc | 13 ++ ...unlabelled_dataflow_graph_pattern_match.cc | 36 +++ .../unlabelled/upward_open_pattern_edge.cc | 9 - lib/utils/include/utils/containers.h | 1 + .../graph/dataflow_graph/dataflow_graph.h | 13 ++ .../graph/dataflow_graph/i_dataflow_graph.h | 8 + .../instances/unordered_set_dataflow_graph.h | 8 +- ...ordered_set_labelled_open_dataflow_graph.h | 17 +- .../labelled_dataflow_graph_view.h | 6 +- .../algorithms/rewrite_labels.h | 14 +- .../algorithms/with_labelling.h | 2 +- .../i_labelled_open_dataflow_graph.h | 2 + .../i_labelled_open_dataflow_graph_view.h | 1 + .../labelled_open_dataflow_graph.h | 4 + .../labelled_open_dataflow_graph_view.h | 12 +- .../include/utils/graph/node/graph_view.h | 2 +- .../graph/open_dataflow_graph/algorithms.h | 2 +- .../algorithms/get_subgraph.h | 3 +- .../dataflow_graph_input.struct.toml | 6 +- .../dataflow_graph_input_source.h | 19 ++ .../i_open_dataflow_graph_view.h | 2 +- .../open_dataflow_graph/open_dataflow_edge.h | 1 + .../open_dataflow_graph_view.h | 2 +- .../unordered_set_open_dataflow_graph.h | 9 +- .../graph/serial_parallel/graph_generation.h | 26 +++ .../graph/serial_parallel/parallel.fwd.h | 10 - .../serial_parallel/parallel.struct.toml | 29 --- .../utils/graph/serial_parallel/serial.fwd.h | 10 - .../graph/serial_parallel/serial.struct.toml | 29 --- ...serial_parallel_decomposition.variant.toml | 7 +- .../serial_parallel/serial_parallel_splits.h | 74 ++++++ .../graph/serial_parallel/serialparallel.h | 25 +- .../graph/dataflow_graph/dataflow_graph.cc | 6 + .../src/utils/graph/digraph/algorithms.cc | 16 +- .../instances/unordered_set_dataflow_graph.cc | 35 ++- .../algorithms/rewrite_labels.cc | 4 +- .../graph/open_dataflow_graph/algorithms.cc | 2 +- .../algorithms/get_subgraph.cc | 37 ++- .../dataflow_graph_input_source.cc | 15 ++ .../open_dataflow_graph/open_dataflow_edge.cc | 9 +- .../open_dataflow_graph_view.cc | 2 +- .../unordered_set_open_dataflow_graph.cc | 14 +- .../graph/serial_parallel/graph_generation.cc | 48 ++++ .../serial_parallel/serial_parallel_splits.cc | 109 +++++++++ .../graph/serial_parallel/serialparallel.cc | 131 +---------- .../serialparallel_internal.cc | 50 +++- .../serial_parallel/serialparallel_internal.h | 21 +- 102 files changed, 1182 insertions(+), 847 deletions(-) delete mode 100644 lib/pcg/include/pcg/dataflow_graph/algorithms.h delete mode 100644 lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml create mode 100644 lib/pcg/src/pcg/create_grad.cc delete mode 100644 lib/pcg/src/pcg/dataflow_graph/algorithms.cc delete mode 100644 lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc delete mode 100644 lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml delete mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml delete mode 100644 lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_value.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc delete mode 100644 lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input_source.h create mode 100644 lib/utils/include/utils/graph/serial_parallel/graph_generation.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml delete mode 100644 lib/utils/include/utils/graph/serial_parallel/serial.fwd.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/serial.struct.toml create mode 100644 lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input_source.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/graph_generation.cc create mode 100644 lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc diff --git a/lib/op-attrs/test/src/ops/conv_2d.cc b/lib/op-attrs/test/src/ops/conv_2d.cc index 6f5028cfeb..c4462eb7ec 100644 --- a/lib/op-attrs/test/src/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/ops/conv_2d.cc @@ -2,6 +2,8 @@ #include "doctest/doctest.h" #include "utils/integer_conversions.h" +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Conv2D shape inference") { int out_channels = 4; diff --git a/lib/pcg/include/pcg/create_grad.h b/lib/pcg/include/pcg/create_grad.h index 5a12d310c2..b2f753eaec 100644 --- a/lib/pcg/include/pcg/create_grad.h +++ b/lib/pcg/include/pcg/create_grad.h @@ -1,8 +1,12 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H #define _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H -#include "pcg/create_grad_t.h" +#include "pcg/create_grad.dtg.h" -namespace FlexFlow {} +namespace FlexFlow { + +bool bool_from_create_grad(CreateGrad); + +} #endif diff --git a/lib/pcg/include/pcg/dataflow_graph/algorithms.h b/lib/pcg/include/pcg/dataflow_graph/algorithms.h deleted file mode 100644 index 7673bae41f..0000000000 --- a/lib/pcg/include/pcg/dataflow_graph/algorithms.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_ALGORITHMS_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_GRAPH_ALGORITHMS_H - -#include "pcg/dataflow_graph/dataflow_graph.h" - -namespace FlexFlow { - -template -std::vector - get_inputs(DataflowGraph const &g, Node const &n) { - std::vector> input_edges = - transform(as_vector(get_incoming_edges(g.get_raw_graph(), - std::unordered_set{n})), - [&](MultiDiEdge const &e) { - int idx = g.idx_for_port(e.dst_idx); - MultiDiOutput val = static_cast(e); - return std::make_pair(idx, val); - }); - - return vector_from_indexed_set(input_edges); -} - -template -std::vector - get_outputs(DataflowGraph const &g, Node const &n) { - return g.get_output_map().at(n); -} - -template -std::vector - topological_ordering(DataflowGraph const &g) { - return get_topological_ordering(g.get_raw_graph()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml b/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml deleted file mode 100644 index 3c9cb87e85..0000000000 --- a/lib/pcg/include/pcg/dataflow_graph/operator_added_result.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "OperatorAddedResult" - -features = [ - "eq", - "ord", - "fmt", -] - -includes = [ - "", - "utils/graph.h", - "utils/fmt/vector.h", -] - -[[fields]] -name = "node" -type = "::FlexFlow::Node" - -[[fields]] -name = "outputs" -type = "std::vector<::FlexFlow::MultiDiOutput>" diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 9dea39e5c3..0fb6a90c38 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -22,9 +22,9 @@ TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { } tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, - bool create_grad) { + CreateGrad create_grad) { TensorAttrs tensor_attrs = - TensorAttrs{shape, std::nullopt, create_grad, std::nullopt}; + TensorAttrs{shape, std::nullopt, std::nullopt, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, std::nullopt, diff --git a/lib/pcg/src/pcg/create_grad.cc b/lib/pcg/src/pcg/create_grad.cc new file mode 100644 index 0000000000..00029aa3fd --- /dev/null +++ b/lib/pcg/src/pcg/create_grad.cc @@ -0,0 +1,17 @@ +#include "pcg/create_grad.h" +#include "utils/exception.h" + +namespace FlexFlow { + +bool bool_from_create_grad(CreateGrad cg) { + switch (cg) { + case CreateGrad::YES: + return true; + case CreateGrad::NO: + return false; + default: + throw mk_runtime_error(fmt::format("Unknown CreateGrad value {}", cg)); + } +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/dataflow_graph/algorithms.cc b/lib/pcg/src/pcg/dataflow_graph/algorithms.cc deleted file mode 100644 index 3ef04c95a3..0000000000 --- a/lib/pcg/src/pcg/dataflow_graph/algorithms.cc +++ /dev/null @@ -1 +0,0 @@ -#include "pcg/dataflow_graph/algorithms.h" diff --git a/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc b/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc deleted file mode 100644 index f47151e76a..0000000000 --- a/lib/pcg/test/src/pcg/dataflow_graph/algorithms.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/dataflow_graph/algorithms.h" -#include "test/utils/doctest.h" -#include "utils/fmt/unordered_set.h" - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_inputs/get_outputs") { - DataflowGraph g; - - int n1_label = 1; - int n2_label = 2; - int n3_label = 3; - int n4_label = 4; - - std::string o1_label = "o1"; - std::string o2_label = "o2"; - std::string o3_label = "o3"; - std::string o4_label = "o4"; - - OperatorAddedResult n1_added = g.add_operator(n1_label, {}, {o1_label}); - Node n1 = n1_added.node; - MultiDiOutput o1 = get_only(n1_added.outputs); - - OperatorAddedResult n2_added = g.add_operator(n2_label, {}, {o2_label}); - Node n2 = n2_added.node; - MultiDiOutput o2 = get_only(n2_added.outputs); - - OperatorAddedResult n3_added = g.add_operator(n3_label, {}, {o3_label}); - Node n3 = n3_added.node; - MultiDiOutput o3 = get_only(n3_added.outputs); - - OperatorAddedResult n4_added = - g.add_operator(n4_label, {o1, o2, o3}, {o4_label}); - Node n4 = n4_added.node; - MultiDiOutput o4 = get_only(n4_added.outputs); - - SUBCASE("get_inputs") { - std::vector result = get_inputs(g, n4); - std::vector correct = {o1, o2, o3}; - CHECK(result == correct); - } - - SUBCASE("get_outputs") { - std::vector result = get_outputs(g, n4); - std::vector correct = {o4}; - CHECK(result == correct); - } - } - - TEST_CASE("topological_ordering") { - DataflowGraph g; - - int n1_label = 1; - int n2_label = 2; - int n3_label = 3; - - std::string o1_label = "o1"; - std::string o2_label = "o2"; - std::string o3_label = "o3"; - - OperatorAddedResult n1_added = g.add_operator(n1_label, {}, {o1_label}); - Node n1 = n1_added.node; - MultiDiOutput o1 = get_only(n1_added.outputs); - - OperatorAddedResult n2_added = g.add_operator(n2_label, {o1}, {o2_label}); - Node n2 = n2_added.node; - MultiDiOutput o2 = get_only(n2_added.outputs); - - OperatorAddedResult n3_added = g.add_operator(n3_label, {o2}, {o3_label}); - Node n3 = n3_added.node; - MultiDiOutput o3 = get_only(n3_added.outputs); - - std::vector result = topological_ordering(g); - std::vector correct = {n1, n2, n3}; - CHECK(result == correct); - } -} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 7b188d1b66..1c603d9778 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -31,7 +31,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t tensor3 = get_only(layer3_added.outputs); std::vector result = topological_ordering(pcg); - std::vector correct = {layer1, layer2, layer3}; - CHECK(result == correct); + // std::vector correct = {layer1, layer2, layer3}; + // CHECK(result == correct); } } diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml index 9781515803..f52d0f2b3e 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml @@ -12,6 +12,7 @@ includes = [ "substitutions/operator_pattern/operator_attribute_key.dtg.h", "substitutions/output_graph/output_operator_attribute_expr.dtg.h", "", + "utils/hash/unordered_map.h", ] # NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 1aa2b2946b..9b61f8fb17 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -3,7 +3,7 @@ #include "sub_parallel_computation_graph.dtg.h" #include "substitutions/substitution.dtg.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" namespace FlexFlow { @@ -12,7 +12,7 @@ bool is_valid_substitution(Substitution const &); SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &, Substitution const &, - MultiDiGraphPatternMatch const &); + UnlabelledDataflowGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h index e245e800b2..94eb00f74d 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h index de0d58e14f..99a4063d0a 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h index eedca2da82..08615207bb 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h index 6c11b421a8..ba57ff5300 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h index 98d4394530..d40e9dad47 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H -#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml index 43f45e95b9..139774979e 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml @@ -12,7 +12,8 @@ features = [ includes = [ "", "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h", - "utils/hash-utils.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", ] [[fields]] diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml deleted file mode 100644 index d609ca1c27..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "ClosedPatternEdge" -features = [ - "eq", - "ord", - "hash", -] - -includes = [ - "utils/graph.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::MultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h deleted file mode 100644 index 9855d96e46..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H - -#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" - -namespace FlexFlow { - -int get_src_idx(DownwardOpenPatternEdge const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml deleted file mode 100644 index 2dda7498f0..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "DownwardOpenPatternEdge" -features = [ - "eq", - "ord", - "hash", -] - -includes = [ - "utils/graph.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::DownwardOpenMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h index 58704500ac..81a0a320db 100644 --- a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h @@ -1,20 +1,20 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H -#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" -#include "substitutions/unlabelled/edge_splits.dtg.h" -#include "substitutions/unlabelled/input_pattern_edge.dtg.h" -#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +// #include "substitutions/unlabelled/closed_pattern_edge.dtg.h" +// #include "substitutions/unlabelled/edge_splits.dtg.h" +// #include "substitutions/unlabelled/input_pattern_edge.dtg.h" +// #include "substitutions/unlabelled/output_pattern_edge.dtg.h" #include namespace FlexFlow { -std::pair - get_split_edges(UnlabelledPatternEdgeSplits const &, - ClosedPatternEdge const &); - -std::vector> - as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &); +// std::pair +// get_split_edges(UnlabelledPatternEdgeSplits const &, +// ClosedPatternEdge const &); +// +// std::vector> +// as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h index 29c5740c0e..154ce183a4 100644 --- a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h +++ b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h @@ -2,15 +2,15 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H #include "substitutions/unlabelled/match_additional_criterion.dtg.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" #include "utils/graph.h" namespace FlexFlow { -std::vector +std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, + OpenDataflowGraphView const &graph, MatchAdditionalCriterion const &additional_criterion); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h index b05fa479db..7a7c9c3c28 100644 --- a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h @@ -2,11 +2,14 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H #include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_input.dtg.h" #include "substitutions/unlabelled/pattern_node.dtg.h" namespace FlexFlow { +PatternInput get_src_input(InputPatternEdge const &); PatternNode get_dst_node(InputPatternEdge const &); +int get_dst_idx(InputPatternEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml index 6da52b58aa..abe9c6f768 100644 --- a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml @@ -7,9 +7,9 @@ features = [ ] includes = [ - "utils/graph.h" + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" ] [[fields]] name = "raw_edge" -type = "::FlexFlow::InputMultiDiEdge" +type = "::FlexFlow::OpenDataflowEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.h b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.h new file mode 100644 index 0000000000..445c5cb26e --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_H + +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" + +namespace FlexFlow { + +MatchAdditionalCriterion match_additional_crition_always_true(); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h index aacae6d42a..1b30f274f9 100644 --- a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H -#include "substitutions/unlabelled/edge_splits.dtg.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +// #include "substitutions/unlabelled/edge_splits.dtg.h" +// #include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" namespace FlexFlow { -MultiDiGraphPatternMatch empty_multidigraph_pattern_match(); -std::optional - unsplit_matches(MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - UnlabelledPatternEdgeSplits const &edge_splits); +// MultiDiGraphPatternMatch empty_multidigraph_pattern_match(); +// std::optional +// unsplit_matches(MultiDiGraphPatternMatch const &prefix, +// MultiDiGraphPatternMatch const &postfix, +// UnlabelledPatternEdgeSplits const &edge_splits); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h index 79db533d4e..12405af184 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -9,6 +9,8 @@ namespace FlexFlow { +PatternNode get_dst_node(PatternEdge const &); + std::unordered_set get_nodes(PatternEdge const &); bool is_closed_edge(PatternEdge const &); bool is_input_edge(PatternEdge const &); diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml new file mode 100644 index 0000000000..143ea78ac1 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "PatternEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/unlabelled/input_pattern_edge.dtg.h", + "substitutions/unlabelled/standard_pattern_edge.dtg.h", +] + +[[values]] +type = "::FlexFlow::InputPatternEdge" + +[[values]] +type = "::FlexFlow::StandardPatternEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml new file mode 100644 index 0000000000..e91e5673af --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "PatternInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h new file mode 100644 index 0000000000..3dd5b262c9 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_OUTPUT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_OUTPUT_H + +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/pattern_node_output.dtg.h" +namespace FlexFlow { + +PatternNode get_src_node(PatternNodeOutput const &); +int get_idx(PatternNodeOutput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml new file mode 100644 index 0000000000..c2b85ae4fb --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node_output.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "PatternNodeOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", +] + +[[fields]] +name = "raw_dataflow_output" +type = "::FlexFlow::DataflowOutput" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h index 3fcc5cb12f..ff67c882df 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h @@ -4,6 +4,7 @@ #include "substitutions/unlabelled/edge_splits.dtg.h" #include "substitutions/unlabelled/pattern_split.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_split_result.dtg.h" namespace FlexFlow { @@ -15,7 +16,7 @@ UnlabelledPatternEdgeSplits get_edge_splits(UnlabelledGraphPattern const &pattern, PatternSplit const &split); -std::pair +PatternSplitResult apply_split(UnlabelledGraphPattern const &, PatternSplit const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml new file mode 100644 index 0000000000..168e38e180 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split_result.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PatternSplitResult" +features = [ ] + +includes = [ + "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h", + "substitutions/unlabelled/pattern_value.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "utils/bidict.h", +] + +[[fields]] +name = "subpattern_1" +type = "::FlexFlow::UnlabelledGraphPattern" + +[[fields]] +name = "subpattern_2" +type = "::FlexFlow::UnlabelledGraphPattern" + +[[fields]] +name = "subpattern_1_outputs_to_subpattern_2_inputs" +type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::PatternInput>" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.h b/lib/substitutions/include/substitutions/unlabelled/pattern_value.h new file mode 100644 index 0000000000..1ae391f080 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_VALUE_H + +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +OpenDataflowValue raw_dataflow_value_from_pattern_value(PatternValue const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml deleted file mode 100644 index c9b52b4c9e..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_value.struct.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "PatternValue" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", -] - -[[fields]] -name = "raw_dataflow_value" -type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml new file mode 100644 index 0000000000..f9abc85c4b --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "PatternValue" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/unlabelled/pattern_input.dtg.h", + "substitutions/unlabelled/pattern_node_output.dtg.h", +] + +[[values]] +type = "::FlexFlow::PatternNodeOutput" + +[[values]] +type = "::FlexFlow::PatternInput" diff --git a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h new file mode 100644 index 0000000000..27e3429585 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_STANDARD_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_STANDARD_PATTERN_EDGE_H + +#include "substitutions/unlabelled/standard_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +PatternNode get_src_node(StandardPatternEdge const &); +PatternNode get_dst_node(StandardPatternEdge const &); +int get_src_idx(StandardPatternEdge const &); +int get_dst_idx(StandardPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml new file mode 100644 index 0000000000..4a2e193544 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/standard_pattern_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "StandardPatternEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h new file mode 100644 index 0000000000..3a459e69b4 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H + +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" + +namespace FlexFlow { + +UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match(); +std::unordered_set matched_nodes(UnlabelledDataflowGraphPatternMatch const &); +std::optional merge_unlabelled_dataflow_graph_pattern_matches(UnlabelledDataflowGraphPatternMatch const &subpattern_1, + UnlabelledDataflowGraphPatternMatch const &subpattern_2, + bidict const &outputs_of_1_to_inputs_of_2); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml index af28609478..064bc85d2a 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml @@ -10,15 +10,17 @@ features = [ includes = [ "utils/bidict.h", "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", - "substitutions/unlabelled/pattern_value.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", "substitutions/unlabelled/pattern_node.dtg.h", + "", + "utils/fmt/unordered_map.h", ] [[fields]] name = "node_assignment" type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" -# [[fields]] -# name = "value_assignment" -# type = "::FlexFlow::bidict<::FlexFlow::PatternValue, ::FlexFlow::OpenDataflowValue>" +[[fields]] +name = "input_assignment" +type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h index 3de76f6ab2..e87e261dc0 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -1,9 +1,11 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H +#include "substitutions/unlabelled/pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_value.dtg.h" #include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_input.dtg.h" namespace FlexFlow { @@ -11,10 +13,14 @@ size_t num_nodes(UnlabelledGraphPattern const &); bool is_singleton_pattern(UnlabelledGraphPattern const &); std::unordered_set get_nodes(UnlabelledGraphPattern const &); std::unordered_set get_values(UnlabelledGraphPattern const &); -std::unordered_set get_value_uses(UnlabelledGraphPattern const &, PatternValue const &); +// std::unordered_set get_value_uses(UnlabelledGraphPattern const &, PatternValue const &); std::vector get_topological_ordering(UnlabelledGraphPattern const &); +std::unordered_set get_inputs(UnlabelledGraphPattern const &); + +std::unordered_set get_edges(UnlabelledGraphPattern const &); + std::vector get_inputs_to_pattern_node(UnlabelledGraphPattern const &, PatternNode const &); std::vector diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h deleted file mode 100644 index 998cf1a519..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H - -#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" - -namespace FlexFlow { - -int get_dst_idx(UpwardOpenPatternEdge const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml deleted file mode 100644 index a4c3bad809..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "UpwardOpenPatternEdge" -features = [ - "eq", - "ord", - "hash", -] - -includes = [ - "utils/graph.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::UpwardOpenMultiDiEdge" diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index bf59ab5080..fe78ebd266 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -2,6 +2,7 @@ #include "substitutions/operator_pattern/satisfies_pattern.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/satisfies_pattern.h" +#include "substitutions/unlabelled/pattern_value.h" namespace FlexFlow { @@ -10,8 +11,8 @@ UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { } TensorAttributePattern get_tensor_pattern(PCGPattern const &p, - PatternEdge const &e) { - return p.raw_graph.at(e.raw_edge); + PatternValue const &v) { + return p.raw_graph.at(raw_dataflow_value_from_pattern_value(v)); } OperatorAttributePattern get_operator_pattern(PCGPattern const &p, @@ -21,7 +22,7 @@ OperatorAttributePattern get_operator_pattern(PCGPattern const &p, bool assignment_satisfies(SubParallelComputationGraph const &pcg, PCGPattern const &pattern, - MultiDiGraphPatternMatch const &patternMatch) { + UnlabelledDataflowGraphPatternMatch const &patternMatch) { return unlabelled_pattern_does_match( get_unlabelled_pattern(pattern), pcg.raw_graph, @@ -32,10 +33,10 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, get_operator_attrs(pcg, pcgNode), get_operator_pattern(pattern, patternNode)); }, - [&](PatternEdge const &patternEdge, OpenMultiDiEdge const &pcgEdge) { + [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { return parallel_tensor_satisfies_pattern( - get_parallel_tensor_attrs(pcg, pcgEdge), - get_tensor_pattern(pattern, patternEdge)); + get_parallel_tensor_attrs(pcg, pcgValue), + get_tensor_pattern(pattern, patternValue)); }}); } diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 5cb4825d29..965d77f3d1 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -10,7 +10,7 @@ ParallelLayerAttrs PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, Node const &n) { - return get_parallel_layer_attrs(spcg, n).attrs; + return get_parallel_layer_attrs(spcg, n).op_attrs; } ParallelTensorAttrs diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc index e900175bc6..b4e6709a73 100644 --- a/lib/substitutions/src/substitutions/substitution.cc +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -147,7 +147,7 @@ bool is_valid_substitution(Substitution const &) { SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &, Substitution const &, - MultiDiGraphPatternMatch const &) { + UnlabelledDataflowGraphPatternMatch const &) { NOT_IMPLEMENTED(); } diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc deleted file mode 100644 index 704e0aea1a..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "substitutions/unlabelled/downward_open_pattern_edge.h" - -namespace FlexFlow { - -int get_src_idx(DownwardOpenPatternEdge const &e) { - return get_src_idx(e.raw_edge); -} - -} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc index 33ea7dc9f6..761bd3b5a1 100644 --- a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc @@ -2,34 +2,34 @@ namespace FlexFlow { -std::pair - get_split_edges(UnlabelledPatternEdgeSplits const &splits, - ClosedPatternEdge const &e) { - std::pair raw_result = - splits.unwrapped.at_l(e.raw_edge); - return { - OutputPatternEdge{raw_result.first}, - InputPatternEdge{raw_result.second}, - }; -} - -std::vector> - as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { - std::vector< - std::tuple> - result; - - for (auto const &kv : s.unwrapped) { - MultiDiEdge standard_edge = kv.first; - OutputMultiDiEdge output_edge = kv.second.first; - InputMultiDiEdge input_edge = kv.second.second; - - result.push_back({ClosedPatternEdge{standard_edge}, - OutputPatternEdge{output_edge}, - InputPatternEdge{input_edge}}); - } - - return result; -} +// std::pair +// get_split_edges(UnlabelledPatternEdgeSplits const &splits, +// ClosedPatternEdge const &e) { +// std::pair raw_result = +// splits.unwrapped.at_l(e.raw_edge); +// return { +// OutputPatternEdge{raw_result.first}, +// InputPatternEdge{raw_result.second}, +// }; +// } +// +// std::vector> +// as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { +// std::vector< +// std::tuple> +// result; +// +// for (auto const &kv : s.unwrapped) { +// MultiDiEdge standard_edge = kv.first; +// OutputMultiDiEdge output_edge = kv.second.first; +// InputMultiDiEdge input_edge = kv.second.second; +// +// result.push_back({ClosedPatternEdge{standard_edge}, +// OutputPatternEdge{output_edge}, +// InputPatternEdge{input_edge}}); +// } +// +// return result; +// } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 8c787ca255..84d0fae324 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -1,153 +1,95 @@ #include "substitutions/unlabelled/find_pattern_matches.h" -#include "substitutions/unlabelled/downward_open_pattern_edge.h" #include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/pattern_split.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "substitutions/unlabelled/upward_open_pattern_edge.h" #include "utils/containers.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/containers/zip_vectors.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "substitutions/unlabelled/match_additional_criterion.h" +#include "substitutions/unlabelled/pattern_matching.h" namespace FlexFlow { -static std::vector - sorted_by_dst_idx(std::unordered_set const &in) { - return sorted_by( - in, compare_by([](UpwardOpenPatternEdge const &e) { - return get_dst_idx(e); - })); -} - -static std::vector - sorted_by_src_idx(std::unordered_set const &in) { - return sorted_by( - in, - compare_by( - [](DownwardOpenPatternEdge const &e) { return get_src_idx(e); })); -} - -static std::vector - sorted_by_dst_idx(std::unordered_set const &in) { - return sorted_by( - in, compare_by([](UpwardOpenPatternEdge const &e) { - return get_dst_idx(e); - })); -} - -static std::vector - sorted_by_src_idx(std::unordered_set const &in) { - return sorted_by( - in, - compare_by( - [](DownwardOpenMultiDiEdge const &e) { return get_src_idx(e); })); -} - -static std::optional +static std::optional get_candidate_singleton_match(UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, + OpenDataflowGraphView const &graph, Node const &graph_node) { assert(is_singleton_pattern(pattern)); PatternNode pattern_node = get_only(get_nodes(pattern)); - MultiDiGraphPatternMatch match = empty_multidigraph_pattern_match(); + UnlabelledDataflowGraphPatternMatch match = empty_unlabelled_pattern_match(); match.node_assignment.equate(pattern_node, graph_node); - std::unordered_set incoming = - get_incoming_edges(graph, graph_node); - std::unordered_set outgoing = - get_outgoing_edges(graph, graph_node); - - std::unordered_set pattern_incoming = - get_incoming_edges(pattern, pattern_node); - std::unordered_set pattern_outgoing = - get_outgoing_edges(pattern, pattern_node); - - if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { + std::vector pattern_outputs = get_outputs_from_pattern_node(pattern, pattern_node); + std::vector graph_outputs = transform(get_outputs(graph, graph_node), + [](DataflowOutput const &o) { + return OpenDataflowValue{o}; + }); + + if (pattern_outputs.size() != graph_outputs.size()) { return std::nullopt; } - if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { + std::vector pattern_node_inputs = get_inputs_to_pattern_node(pattern, pattern_node); + std::unordered_set pattern_graph_inputs = get_inputs(pattern); + + assert (without_order(pattern_node_inputs) == transform(pattern_graph_inputs, [](PatternInput const &i) { return PatternValue{i}; })); + + std::vector graph_node_inputs = get_inputs(graph, graph_node); + + if (graph_node_inputs.size() != pattern_node_inputs.size()) { return std::nullopt; } - std::vector incoming_ordered = - sorted_by_dst_idx(incoming); - std::vector outgoing_ordered = - sorted_by_src_idx(outgoing); - - std::vector pattern_incoming_ordered = - sorted_by_dst_idx(pattern_incoming); - std::vector pattern_outgoing_ordered = - sorted_by_src_idx(pattern_outgoing); - - if (pattern_incoming.size() > 0) { - std::unordered_map node_port_mapping; - for (int i = 0; i < incoming_ordered.size(); ++i) { - UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i]; - UpwardOpenPatternEdge pattern_edge = pattern_incoming_ordered[i]; - NodePort graph_port = get_dst_idx(graph_edge), - pattern_port = get_dst_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.emplace(graph_port, pattern_port); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } + for (auto const &[pattern_node_input, graph_node_input] : zip(pattern_node_inputs, graph_node_inputs)) { + assert (pattern_node_input.has()); + assert (graph_node_input.has()); - if (pattern_outgoing.size() > 0) { - std::unordered_map node_port_mapping; - for (int i = 0; i < outgoing_ordered.size(); ++i) { - DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], - DownwardOpenPatternEdge pattern_edge = - pattern_outgoing_ordered[i]; - - NodePort graph_port = get_src_idx(graph_edge), - pattern_port = get_src_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.insert({graph_port, pattern_port}); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } + match.input_assignment.insert({ + pattern_node_input.get(), + graph_node_input.get(), + }); } + assert (unlabelled_pattern_does_match(pattern, graph, match, match_additional_crition_always_true())); + return match; } -std::vector +std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, + OpenDataflowGraphView const &graph, MatchAdditionalCriterion const &additional_criterion) { - std::vector matches; + std::vector matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - std::optional candidate = + std::optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() && - pattern_does_match( + unlabelled_pattern_does_match( pattern, graph, candidate.value(), additional_criterion)) { matches.push_back(candidate.value()); } } } else { - GraphSplit split = split_pattern(pattern); + PatternSplit split = find_even_split(pattern); auto subpatterns = apply_split(pattern, split); - auto prefix_matches = - find_pattern_matches(subpatterns.first, graph, additional_criterion); - auto postfix_matches = - find_pattern_matches(subpatterns.second, graph, additional_criterion); + std::vector prefix_matches = + find_pattern_matches(subpatterns.subpattern_1, graph, additional_criterion); + std::vector postfix_matches = + find_pattern_matches(subpatterns.subpattern_2, graph, additional_criterion); + auto edge_splits = get_edge_splits(pattern, split); - for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { - for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - std::optional unsplit = - unsplit_matches(prefix_match, postfix_match, edge_splits); + for (UnlabelledDataflowGraphPatternMatch const &prefix_match : prefix_matches) { + for (UnlabelledDataflowGraphPatternMatch const &postfix_match : postfix_matches) { + std::optional unsplit = + merge_unlabelled_dataflow_graph_pattern_matches(prefix_match, + postfix_match, + subpatterns.subpattern_1_outputs_to_subpattern_2_inputs); if (unsplit.has_value()) { matches.push_back(unsplit.value()); } diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc index 2eff39bb1e..a45154cf05 100644 --- a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc @@ -1,9 +1,14 @@ #include "substitutions/unlabelled/input_pattern_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" namespace FlexFlow { +PatternInput get_src_input(InputPatternEdge const &) { + NOT_IMPLEMENTED(); +} + PatternNode get_dst_node(InputPatternEdge const &e) { - return PatternNode{e.raw_edge.dst}; + return PatternNode{get_open_dataflow_edge_dst_node(e.raw_edge)}; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc new file mode 100644 index 0000000000..7d4e11d1a7 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.cc @@ -0,0 +1,12 @@ +#include "substitutions/unlabelled/match_additional_criterion.h" + +namespace FlexFlow { + +MatchAdditionalCriterion match_additional_crition_always_true() { + return MatchAdditionalCriterion{ + [](PatternNode const &, Node const &) { return true; }, + [](PatternValue const &, OpenDataflowValue const &) { return true; }, + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc index 8f4fd7f535..8ce60fab4f 100644 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc @@ -1,56 +1,56 @@ #include "substitutions/unlabelled/multidigraph_pattern_match.h" -#include "substitutions/unlabelled/edge_splits.h" -#include "substitutions/unlabelled/pattern_edge.h" +// #include "substitutions/unlabelled/edge_splits.h" +// #include "substitutions/unlabelled/pattern_edge.h" #include "utils/containers.h" namespace FlexFlow { -MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { - return MultiDiGraphPatternMatch{ - bidict{}, - bidict{}, - }; -} - -std::optional - unsplit_matches(MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - UnlabelledPatternEdgeSplits const &edge_splits) { - - MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); - - std::unordered_set handled; - for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { - ClosedPatternEdge closed_edge = std::get(coi); - OutputPatternEdge output_edge = std::get(coi); - InputPatternEdge input_edge = std::get(coi); - - handled.insert(pattern_edge_from_output_edge(output_edge)); - handled.insert(pattern_edge_from_input_edge(input_edge)); - - OpenMultiDiEdge output_graph_edge = - prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); - OpenMultiDiEdge input_graph_edge = - postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); - if (output_graph_edge == input_graph_edge) { - result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), - output_graph_edge); - } else { - return std::nullopt; - } - } - - for (auto const &kv : - merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { - if (!contains(handled, kv.first)) { - result.edge_assignment.equate(kv.first, kv.second); - } - } - - result.node_assignment = - merge_maps(prefix.node_assignment, postfix.node_assignment); - - return result; -} +// MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { +// return MultiDiGraphPatternMatch{ +// bidict{}, +// bidict{}, +// }; +// } + +// std::optional +// unsplit_matches(MultiDiGraphPatternMatch const &prefix, +// MultiDiGraphPatternMatch const &postfix, +// UnlabelledPatternEdgeSplits const &edge_splits) { +// +// MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); +// +// std::unordered_set handled; +// for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { +// ClosedPatternEdge closed_edge = std::get(coi); +// OutputPatternEdge output_edge = std::get(coi); +// InputPatternEdge input_edge = std::get(coi); +// +// handled.insert(pattern_edge_from_output_edge(output_edge)); +// handled.insert(pattern_edge_from_input_edge(input_edge)); +// +// OpenMultiDiEdge output_graph_edge = +// prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); +// OpenMultiDiEdge input_graph_edge = +// postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); +// if (output_graph_edge == input_graph_edge) { +// result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), +// output_graph_edge); +// } else { +// return std::nullopt; +// } +// } +// +// for (auto const &kv : +// merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { +// if (!contains(handled, kv.first)) { +// result.edge_assignment.equate(kv.first, kv.second); +// } +// } +// +// result.node_assignment = +// merge_maps(prefix.node_assignment, postfix.node_assignment); +// +// return result; +// } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 15af188d5a..89888a7c2f 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -1,21 +1,161 @@ #include "substitutions/unlabelled/pattern_matching.h" #include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/pattern_node_output.h" #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include +#include "utils/graph/node/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/overload.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" +#include "substitutions/unlabelled/input_pattern_edge.h" +#include "substitutions/unlabelled/standard_pattern_edge.h" namespace FlexFlow { -std::pair< - OpenDataflowGraphView, - bidict -> subgraph_matched(OpenDataflowGraphView const &g, - UnlabelledDataflowGraphPatternMatch const &match) { +OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &g, + UnlabelledDataflowGraphPatternMatch const &match) { std::unordered_set matched_nodes = keys(match.node_assignment.reversed()); - std::vector subgraph_inputs = sorted(get_subgraph_inputs(g, matched_nodes)); - return get_subgraph(g, matched_nodes, subgraph_inputs); + return get_subgraph(g, matched_nodes); +} + +// bool are_dataflow_graphs_equal_under(DataflowGraphView const &l, +// DataflowGraphView const &r, +// bidict const &matching) { +// std::unordered_set l_nodes = get_nodes(l); +// +// auto l_from_r = [&](Node const &r_node) { +// return matching.at_r(r_node); +// }; +// +// std::unordered_set l_from_r_nodes = transform(get_nodes(r), l_from_r); +// +// if (l_nodes != l_from_r_nodes) { +// return false; +// } +// +// std::unordered_set l_edges = get_edges(l); +// std::unordered_set l_from_r_edges = transform(get_edges(r), +// [&](DataflowEdge const &r_edge) { +// return DataflowEdge{ +// DataflowOutput{ +// l_from_r(r_edge.src.node), +// r_edge.src.idx, +// }, +// DataflowInput{ +// l_from_r(r_edge.dst.node), +// r_edge.dst.idx, +// } +// }; +// }); +// +// if (l_edges != l_from_r_edges) { +// return false; +// } +// +// return true; +// } + +struct ConcreteFromPattern { + ConcreteFromPattern(UnlabelledDataflowGraphPatternMatch const &match) + : match(match) + { } + + UnlabelledDataflowGraphPatternMatch const &match; + + + + Node operator()(PatternNode const &n) const { + return match.node_assignment.at_l(n); + } + + DataflowGraphInput operator()(PatternInput const &i) { + return match.input_assignment.at(i); + } + + DataflowInputEdge operator()(InputPatternEdge const &e) { + return DataflowInputEdge{ + this->operator()(get_src_input(e)), + DataflowInput{ + this->operator()(get_dst_node(e)), + get_dst_idx(e), + }, + }; + } + + DataflowEdge operator()(StandardPatternEdge const &e) { + return DataflowEdge{ + DataflowOutput{ + this->operator()(get_src_node(e)), + get_src_idx(e), + }, + DataflowInput{ + this->operator()(get_dst_node(e)), + get_dst_idx(e), + }, + }; + } + + OpenDataflowEdge operator()(PatternEdge const &pattern_e) { + return pattern_e.visit([&](auto const &e) { return OpenDataflowEdge{this->operator()(e)}; }); + } + + OpenDataflowValue operator()(PatternValue const &pattern_v) { + return pattern_v.visit([&](auto const &v) { return OpenDataflowValue{this->operator()(v)}; }); + } + + DataflowOutput operator()(PatternNodeOutput const &o) { + return DataflowOutput{ + this->operator()(get_src_node(o)), + get_idx(o), + }; + } +}; + +static bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern, + OpenDataflowGraphView const &subgraph, + UnlabelledDataflowGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion) { + ConcreteFromPattern concrete_from_pattern{match}; + + std::unordered_set concrete_nodes = get_nodes(subgraph); + std::unordered_set concrete_nodes_from_match = transform(get_nodes(pattern), concrete_from_pattern); + + if (concrete_nodes != concrete_nodes_from_match) { + return false; + } + + for (PatternNode const &pattern_node : get_nodes(pattern)) { + if (!additional_criterion.node_criterion(pattern_node, concrete_from_pattern(pattern_node))) { + return false; + } + } + + std::unordered_set concrete_edges = get_edges(subgraph); + std::unordered_set concrete_edge_from_match = transform(get_edges(pattern), concrete_from_pattern); + + if (concrete_edges != concrete_edge_from_match) { + return false; + } + + std::unordered_set concrete_values = get_open_dataflow_values(subgraph); + std::unordered_set concrete_values_from_match = transform(get_values(pattern), concrete_from_pattern); + + if (concrete_values != concrete_values_from_match) { + return false; + } + + for (PatternValue const &pattern_value : get_values(pattern)) { + if (!additional_criterion.value_criterion(pattern_value, concrete_from_pattern(pattern_value))) { + return false; + } + } + + return true; } bool unlabelled_pattern_does_match( @@ -24,62 +164,12 @@ bool unlabelled_pattern_does_match( UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { + OpenDataflowGraphView matched_subgraph = subgraph_matched(matched_subgraph, match).graph; + + assert (keys(match.node_assignment) == get_nodes(pattern)); + assert (keys(match.node_assignment.reversed()) == get_nodes(matched_subgraph)); - // PatternNode pattern_node = get_only(get_nodes(pattern)); - // Node matched_node = match.node_assignment.at_l(pattern_node); - // if (!additional_criterion.node_criterion(pattern_node, matched_node)) { - // return false; - // } - // - // for (PatternValue const &pattern_value : get_values(pattern)) { - // OpenDataflowValue matched_value = match.value_assignment.at_l(v); - // - // assert(is_input_edge(e) || is_output_edge(e)); - // if (is_input_edge(e)) { - // if (is_output_edge(matched_edge)) { - // return false; - // } - // UpwardOpenMultiDiEdge matched_edge = - // narrow(matched_edge).value(); - // InputPatternEdge input_edge = require_input_edge(e); - // if (match.node_assignment.at_l(get_dst_node(input_edge)) != - // get_dst_node(matched_edge)) { - // return false; - // } - // } else { - // if (is_input_edge(matched_edge)) { - // return false; - // } - // DownwardOpenMultiDiEdge matched_edge = - // narrow(matched_edge).value(); - // OutputPatternEdge output_edge = require_output_edge(e); - // if (match.node_assignment.at_l(get_src_node(output_edge)) != - // get_src_node(matched_edge)) { - // return false; - // } - // } - // - // if (!additional_criterion.value_criterion(pattern_value, matched_value)) { - // return false; - // } - // } - // - // return true; - // } - // - // PatternSplit split = find_even_split(pattern); - // std::pair subpatterns = - // apply_split(pattern, split); - // auto submatches = apply_split(pattern, match, split); - // - // return unlabelled_pattern_does_match(subpatterns.first, - // graph, - // submatches.prefix_submatch, - // additional_criterion) && - // unlabelled_pattern_does_match(subpatterns.second, - // graph, - // submatches.postfix_submatch, - // additional_criterion); + return pattern_matches_subgraph_under(pattern, matched_subgraph, match, additional_criterion); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc new file mode 100644 index 0000000000..9abdc4e83c --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node_output.cc @@ -0,0 +1,13 @@ +#include "substitutions/unlabelled/pattern_node_output.h" + +namespace FlexFlow { + +PatternNode get_src_node(PatternNodeOutput const &o) { + return PatternNode{o.raw_dataflow_output.node}; +} + +int get_idx(PatternNodeOutput const &o) { + return o.raw_dataflow_output.idx; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc new file mode 100644 index 0000000000..287139dc30 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc @@ -0,0 +1,13 @@ +#include "substitutions/unlabelled/pattern_value.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OpenDataflowValue raw_dataflow_value_from_pattern_value(PatternValue const &v) { + return v.visit(overload { + [](PatternNodeOutput const &o) { return o.raw_dataflow_output; }, + [](PatternInput const &i) { return i.raw_dataflow_graph_input; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc new file mode 100644 index 0000000000..2e1a47f9e5 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc @@ -0,0 +1,13 @@ +#include "substitutions/unlabelled/standard_pattern_edge.h" + +namespace FlexFlow { + +PatternNode get_src_node(StandardPatternEdge const &) { + NOT_IMPLEMENTED(); +} + +PatternNode get_dst_node(StandardPatternEdge const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc new file mode 100644 index 0000000000..3f163aedb6 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc @@ -0,0 +1,36 @@ +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" + +namespace FlexFlow { + +UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match() { + return UnlabelledDataflowGraphPatternMatch{ + bidict{}, + bidict{}, + }; +} + +template +std::optional> try_merge_nondisjoint_bidicts(bidict const &d1, + bidict const &d2) { + for (L const &l : intersection(keys(d1), keys(d2))) { + return + } +} + + +std::optional + merge_unlabelled_dataflow_graph_pattern_matches(UnlabelledDataflowGraphPatternMatch const &subpattern_1, + UnlabelledDataflowGraphPatternMatch const &subpattern_2, + bidict const &outputs_of_1_to_inputs_of_2) { + if (!are_disjoint(matched_nodes(subpattern_1), matched_nodes(subpattern_2))) { + return std::nullopt; + } + + bidict merged_node_assignment = merge_maps(subpattern_1.node_assignment, subpattern_2.node_assignment); + + // if (!are_disjoint(matched_values(subpattern_1), matched_values(subpattern_2))) { + // + // } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc deleted file mode 100644 index 8664f3c66c..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "substitutions/unlabelled/upward_open_pattern_edge.h" - -namespace FlexFlow { - -int get_dst_idx(UpwardOpenPatternEdge const &e) { - return get_src_idx(e.raw_edge); -} - -} // namespace FlexFlow diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 60df0caca3..b3d0db8822 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -20,6 +20,7 @@ #include #include #include +#include "utils/hash/pair.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index d79983d8ec..e04f1a92e1 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -11,6 +11,10 @@ struct DataflowGraph : virtual DataflowGraphView { public: NodeAddedResult add_node(std::vector const &inputs, int num_outputs); + + void add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs); std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(DataflowEdgeQuery const &) const; @@ -23,6 +27,15 @@ struct DataflowGraph : virtual DataflowGraphView { return DataflowGraph(make_cow_ptr()); } + template + static typename std::enable_if::value, + DataflowGraph>::type + create_copy_of(DataflowGraphView const &view) { + cow_ptr_t impl = make_cow_ptr(); + impl.get_mutable()->inplace_materialize_from(view); + return DataflowGraph(std::move(impl)); + } + protected: using DataflowGraphView::DataflowGraphView; diff --git a/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h index 94fd54802b..1833fca35f 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/i_dataflow_graph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_I_DATAFLOW_GRAPH_H +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" #include "utils/graph/dataflow_graph/node_added_result.dtg.h" #include "utils/graph/dataflow_graph/i_dataflow_graph_view.h" @@ -9,6 +10,13 @@ namespace FlexFlow { struct IDataflowGraph : virtual public IDataflowGraphView { virtual NodeAddedResult add_node(std::vector const &inputs, int num_outputs) = 0; + + virtual void add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs) = 0; + + virtual void inplace_materialize_from(DataflowGraphView const &) = 0; + virtual IDataflowGraph *clone() const = 0; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IDataflowGraph); diff --git a/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h index 4e9e508e39..5f552f6d66 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h @@ -6,7 +6,7 @@ namespace FlexFlow { -struct UnorderedSetDataflowGraph : public IDataflowGraph { +struct UnorderedSetDataflowGraph final : public IDataflowGraph { public: UnorderedSetDataflowGraph(); @@ -17,6 +17,12 @@ struct UnorderedSetDataflowGraph : public IDataflowGraph { std::unordered_set query_edges(DataflowEdgeQuery const &) const override; std::unordered_set query_outputs(DataflowOutputQuery const &) const override; + void add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs) override; + + void inplace_materialize_from(DataflowGraphView const &view) override; + UnorderedSetDataflowGraph *clone() const override; private: UnorderedSetDataflowGraph(NodeSource const &node_source, diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index 579aeda83e..1812024f95 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -10,6 +10,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" #include "utils/containers/zip_vectors.h" #include "utils/containers/without_nullopts.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" namespace FlexFlow { @@ -41,6 +42,12 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo new_outputs, }; } + + DataflowGraphInput add_input(ValueLabel const &value_label) override { + DataflowGraphInput new_input = this->input_source.new_dataflow_graph_input(); + this->values.insert({OpenDataflowValue{new_input}, value_label}); + return new_input; + } std::unordered_set query_nodes(NodeQuery const &q) const override { return filter(keys(this->nodes), @@ -66,7 +73,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo })); } - std::vector get_inputs() const override { + std::unordered_set get_inputs() const override { return this->inputs; } @@ -81,6 +88,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo UnorderedSetLabelledOpenDataflowGraph *clone() const override { return new UnorderedSetLabelledOpenDataflowGraph{ this->node_source, + this->input_source, this->inputs, this->nodes, this->edges, @@ -89,11 +97,13 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo } private: UnorderedSetLabelledOpenDataflowGraph(NodeSource const &node_source, - std::vector const &inputs, + DataflowGraphInputSource const &input_source, + std::unordered_set const &inputs, std::unordered_map const &nodes, std::unordered_set const &edges, std::unordered_map const &values) : node_source(node_source), + input_source(input_source), inputs(inputs), nodes(nodes), edges(edges), @@ -102,7 +112,8 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo private: NodeSource node_source; - std::vector inputs; + DataflowGraphInputSource input_source; + std::unordered_set inputs; std::unordered_map nodes; std::unordered_set edges; std::unordered_map values; diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h index 94113a62ae..cb65138437 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h @@ -21,11 +21,11 @@ struct LabelledDataflowGraphView : virtual public DataflowGraphView { return this->get_interface().at(o); } - template + template static typename std::enable_if::value, LabelledDataflowGraphView>::type - create() { - return LabelledDataflowGraphView(make_cow_ptr()); + create(Args &&... args) { + return LabelledDataflowGraphView(make_cow_ptr(std::forward(args)...)); } protected: using DataflowGraphView::DataflowGraphView; diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h index 969f0701cf..8c67ea879d 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h @@ -3,6 +3,8 @@ #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" +#include "utils/containers.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" namespace FlexFlow { @@ -13,8 +15,16 @@ template > LabelledOpenDataflowGraphView rewrite_labels(LabelledOpenDataflowGraphView const &g, F f) { - std::unordered_map node_labels = generate_map(get_nodes(g), f); - std::unordered_map value_labels = generate_map(get_nodes(g), f); + auto get_new_node_label = [&](Node const &n) -> NewNodeLabel { + return f(n, g.at(n)); + }; + + auto get_new_value_label = [&](OpenDataflowValue const &v) -> NewValueLabel { + return f(v, g.at(v)); + }; + + std::unordered_map node_labels = generate_map(get_nodes(g), get_new_node_label); + std::unordered_map value_labels = generate_map(get_open_dataflow_values(g), get_new_value_label); return with_labelling(g, node_labels, value_labels); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h index dd55b7ddf2..976879754d 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h @@ -30,7 +30,7 @@ struct OpenDataflowGraphLabellingWrapper final : public ILabelledOpenDataflowGra return this->unlabelled.query_outputs(q); } - std::vector get_inputs() const override { + std::unordered_set get_inputs() const override { return this->unlabelled.get_inputs(); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h index 4326d06283..1fe84179c2 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h @@ -14,6 +14,8 @@ struct ILabelledOpenDataflowGraph : virtual public ILabelledOpenDataflowGraphVie std::vector const &inputs, std::vector const &output_labels) = 0; + virtual DataflowGraphInput add_input(ValueLabel const &value_label) = 0; + NodeAddedResult add_node(NodeLabel const &node_label, std::vector const &inputs, std::vector const &output_labels) override final { diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h index 5c691fe225..93663d3615 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h @@ -11,6 +11,7 @@ template struct ILabelledOpenDataflowGraphView : virtual public ILabelledDataflowGraphView, virtual public IOpenDataflowGraphView { public: + virtual NodeLabel const &at(Node const &) const override = 0; virtual ValueLabel const &at(OpenDataflowValue const &) const = 0; ValueLabel const &at(DataflowOutput const &o) const override final { diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h index c41749e333..f0d6b6bd8f 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h @@ -20,6 +20,10 @@ struct LabelledOpenDataflowGraph : virtual public LabelledOpenDataflowGraphView< std::vector const &output_labels) { return this->get_interface().add_node(node_label, inputs, output_labels); } + + DataflowGraphInput add_input(ValueLabel const &value_label) { + return this->get_interface().add_input(value_label); + } protected: using LabelledOpenDataflowGraphView::LabelledOpenDataflowGraphView; private: diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h index e69b5958be..a9f74f5d41 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h @@ -16,18 +16,24 @@ struct LabelledOpenDataflowGraphView : virtual public LabelledDataflowGraphView< LabelledOpenDataflowGraphView(LabelledOpenDataflowGraphView const &) = default; LabelledOpenDataflowGraphView &operator=(LabelledOpenDataflowGraphView const &) = default; + NodeLabel const &at(Node const &n) const { + return this->get_interface().at(n); + } + ValueLabel const &at(OpenDataflowValue const &v) const { return this->get_interface().at(v); } template static typename std::enable_if::value, - LabelledOpenDataflowGraphView>::type + LabelledOpenDataflowGraphView>::type create(Args &&... args) { - return LabelledOpenDataflowGraphView(make_cow_ptr(std::forward(args)...)); + return LabelledOpenDataflowGraphView( + static_cast>(make_cow_ptr(std::forward(args)...))); } protected: - using LabelledDataflowGraphView::LabelledDataflowGraphView; + using OpenDataflowGraphView::OpenDataflowGraphView; + // using LabelledDataflowGraphView::LabelledDataflowGraphView; private: Interface const &get_interface() const { return *std::dynamic_pointer_cast(GraphView::ptr.get()); diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index ad8001b8e4..56b8a3f6b0 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -19,7 +19,7 @@ struct GraphView { } protected: - GraphView() : ptr(nullptr) {} + GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h index a8f6c32490..96ff08a08c 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h @@ -7,7 +7,7 @@ namespace FlexFlow { std::unordered_set get_edges(OpenDataflowGraphView const &); -std::vector get_inputs(OpenDataflowGraphView const &); +std::unordered_set get_inputs(OpenDataflowGraphView const &); std::vector get_inputs(OpenDataflowGraphView const &, Node const &); std::vector get_incoming_edges(OpenDataflowGraphView const &, Node const &); std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h index 0c89906a3f..ee1ddb4af9 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph.h @@ -8,8 +8,7 @@ namespace FlexFlow { OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &, - std::unordered_set const &, - std::vector const &); + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml index 6d047ed878..e9e52be893 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input.struct.toml @@ -7,6 +7,10 @@ features = [ "fmt", ] +includes = [ + "", +] + [[fields]] name = "idx" -type = "int" +type = "size_t" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input_source.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input_source.h new file mode 100644 index 0000000000..86fe66f052 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_graph_input_source.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_INPUT_SOURCE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_DATAFLOW_GRAPH_INPUT_SOURCE_H + +#include "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h" + +namespace FlexFlow { + +struct DataflowGraphInputSource { +public: + DataflowGraphInputSource(); + + DataflowGraphInput new_dataflow_graph_input(); +private: + static size_t next_available_uid; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h index c485e32c68..02799795ce 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h @@ -9,7 +9,7 @@ namespace FlexFlow { struct IOpenDataflowGraphView : virtual public IDataflowGraphView { - virtual std::vector get_inputs() const = 0; + virtual std::unordered_set get_inputs() const = 0; virtual std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const = 0; std::unordered_set query_edges(DataflowEdgeQuery const &) const override final; diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h index 0ffb067223..5f08464bbf 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h @@ -6,6 +6,7 @@ namespace FlexFlow { +Node get_open_dataflow_edge_dst_node(OpenDataflowEdge const &); int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &); OpenDataflowValue get_open_dataflow_edge_source(OpenDataflowEdge const &); OpenDataflowEdge open_dataflow_edge_from_src_and_dst(OpenDataflowValue const &src, DataflowInput const &dst); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h index b3875eb10d..cfd6d7f5dd 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h @@ -11,7 +11,7 @@ struct OpenDataflowGraphView : virtual DataflowGraphView { OpenDataflowGraphView(OpenDataflowGraphView const &) = default; OpenDataflowGraphView &operator=(OpenDataflowGraphView const &) = default; - std::vector get_inputs() const; + std::unordered_set get_inputs() const; std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const; template diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h b/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h index f251136eb0..d56178e122 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_UNORDERED_SET_OPEN_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_UNORDERED_SET_OPEN_DATAFLOW_GRAPH_H +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/i_open_dataflow_graph.h" #include "utils/graph/node/node_source.h" @@ -16,25 +17,27 @@ struct UnorderedSetOpenDataflowGraph : public IOpenDataflowGraph { std::unordered_set query_nodes(NodeQuery const &) const override; std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const override; std::unordered_set query_outputs(DataflowOutputQuery const &) const override; - std::vector get_inputs() const override; + std::unordered_set get_inputs() const override; DataflowGraphInput add_input() override; UnorderedSetOpenDataflowGraph *clone() const override; private: UnorderedSetOpenDataflowGraph(NodeSource const &node_source, + DataflowGraphInputSource const &input_source, std::unordered_set const &nodes, std::unordered_set const &standard_edges, std::unordered_set const &input_edges, std::unordered_set const &outputs, - std::vector const &graph_inputs); + std::unordered_set const &graph_inputs); private: NodeSource node_source; + DataflowGraphInputSource input_source; std::unordered_set nodes; std::unordered_set standard_edges; std::unordered_set input_edges; std::unordered_set outputs; - std::vector graph_inputs; + std::unordered_set graph_inputs; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(UnorderedSetOpenDataflowGraph); diff --git a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h b/lib/utils/include/utils/graph/serial_parallel/graph_generation.h new file mode 100644 index 0000000000..5c351fb2a5 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/graph_generation.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H + +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +void parallel_extend_unsafe(DataflowGraph &g, + DataflowGraphView const &ext); + +void serial_extend(DataflowGraph &g, + DataflowGraphView const &ext); + +DataflowGraph serial_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); + +DataflowGraph parallel_composition(DataflowGraphView const &g1, + DataflowGraphView const &g2); + +DataflowGraph dataflow_graph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition); + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h b/lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h deleted file mode 100644 index c82a8ec6b3..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/parallel.fwd.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_FWD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_FWD_H - -namespace FlexFlow { - -struct Parallel; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml b/lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml deleted file mode 100644 index b8358a96c2..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/parallel.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "Parallel" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/serial_parallel/serial.fwd.h", - "", - "", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/fmt/variant.h", - "utils/hash/vector.h", -] - -trailing_includes = [ - "utils/graph/serial_parallel/serial.dtg.h", -] - -[[fields]] -name = "children" -type = "std::vector>" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial.fwd.h b/lib/utils/include/utils/graph/serial_parallel/serial.fwd.h deleted file mode 100644 index 913b81434c..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial.fwd.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_FWD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_FWD_H - -namespace FlexFlow { - -struct Serial; - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial.struct.toml b/lib/utils/include/utils/graph/serial_parallel/serial.struct.toml deleted file mode 100644 index 1a5fd2408e..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial.struct.toml +++ /dev/null @@ -1,29 +0,0 @@ -namespace = "FlexFlow" -name = "Serial" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/serial_parallel/parallel.fwd.h", - "", - "", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/fmt/vector.h", - "utils/fmt/variant.h", - "utils/hash/vector.h", -] - -trailing_includes = [ - "utils/graph/serial_parallel/parallel.dtg.h", -] - -[[fields]] -name = "children" -type = "std::vector>" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml index cd80f2dd3e..68d0af3c63 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml @@ -8,16 +8,15 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial.dtg.h", - "utils/graph/serial_parallel/parallel.dtg.h", + "utils/graph/serial_parallel/serial_parallel_splits.h", "utils/graph/node/node.dtg.h", ] [[values]] -type = "::FlexFlow::Serial" +type = "::FlexFlow::SerialSplit" [[values]] -type = "::FlexFlow::Parallel" +type = "::FlexFlow::ParallelSplit" [[values]] type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h new file mode 100644 index 0000000000..73e8f82b95 --- /dev/null +++ b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h @@ -0,0 +1,74 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H + +#include "utils/graph/node/node.dtg.h" +#include + +namespace FlexFlow { + +struct SerialSplit; +struct ParallelSplit; + +struct SerialSplit { +public: + SerialSplit() = delete; + explicit SerialSplit(std::vector> const &); + + bool operator==(SerialSplit const &) const; + bool operator!=(SerialSplit const &) const; + bool operator<(SerialSplit const &) const; + bool operator<=(SerialSplit const &) const; + bool operator>(SerialSplit const &) const; + bool operator>=(SerialSplit const &) const; + +public: + std::vector> children; + +private: + using Tie = std::tuple; + Tie tie() const; +}; + +std::string format_as(SerialSplit const &); +std::ostream &operator<<(std::ostream &, SerialSplit const &); + +struct ParallelSplit { +public: + ParallelSplit() = delete; + explicit ParallelSplit(std::vector> const &); + + bool operator==(ParallelSplit const &) const; + bool operator!=(ParallelSplit const &) const; + bool operator<(ParallelSplit const &) const; + bool operator<=(ParallelSplit const &) const; + bool operator>(ParallelSplit const &) const; + bool operator>=(ParallelSplit const &) const; + +public: + std::vector> children; + +private: + using Tie = std::tuple; + Tie tie() const; +}; + +std::string format_as(ParallelSplit const &); +std::ostream &operator<<(std::ostream &, ParallelSplit const &); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::SerialSplit> { + size_t operator()(::FlexFlow::SerialSplit const &) const; +}; + +template <> +struct hash<::FlexFlow::ParallelSplit> { + size_t operator()(::FlexFlow::ParallelSplit const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serialparallel.h b/lib/utils/include/utils/graph/serial_parallel/serialparallel.h index d032707efc..c6ed84e172 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serialparallel.h +++ b/lib/utils/include/utils/graph/serial_parallel/serialparallel.h @@ -9,35 +9,14 @@ namespace FlexFlow { -Node find_source_node(DiGraphView const &); -Node find_sink_node(DiGraphView const &); - -std::optional find_bottleneck_node(DiGraphView const &); - -struct Parallel; - SerialParallelDecomposition get_serial_parallel_decomposition(DiGraphView const &); std::unordered_set get_nodes(SerialParallelDecomposition const &sp); -std::unordered_set get_nodes(Serial const &); -std::unordered_set get_nodes(Parallel const &); +std::unordered_set get_nodes(SerialSplit const &); +std::unordered_set get_nodes(ParallelSplit const &); std::unordered_set get_nodes(Node const &); -// std::unordered_map parallel_extend(MultiDiGraph &g, -// MultiDiGraph const &ext); - -// std::unordered_map serial_extend(MultiDiGraph &g, -// MultiDiGraph const &ext); - -// MultiDiGraph serial_composition(MultiDiGraph const &g1, MultiDiGraph const &g2); - -// MultiDiGraph parallel_composition(MultiDiGraph const &g1, -// MultiDiGraph const &g2); - -// MultiDiGraph multidigraph_from_sp_decomposition( -// SerialParallelDecomposition const &sp_decomposition); - } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc index 18dc7516e8..6a3d804a39 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_graph.cc @@ -7,6 +7,12 @@ NodeAddedResult DataflowGraph::add_node(std::vector const &input return this->get_interface().add_node(inputs, num_outputs); } +void DataflowGraph::add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs) { + return this->get_interface().add_node_unsafe(node, inputs, outputs); +} + std::unordered_set DataflowGraph::query_nodes(NodeQuery const &q) const { return this->get_interface().query_nodes(q); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc index ea58982a7d..4f4f7e8bd1 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -20,10 +20,16 @@ std::unordered_set get_incoming_edges(DiGraphView const &g, Node c } std::unordered_map> get_incoming_edges(DiGraphView const &g, std::unordered_set const &ns) { - return group_by(g.query_edges(DirectedEdgeQuery{ + std::unordered_map> result = group_by(g.query_edges(DirectedEdgeQuery{ query_set::matchall(), query_set{ns}, }), [](DirectedEdge const &e) { return e.dst; }); + + for (Node const &n : ns) { + result[n]; + } + + return result; } std::unordered_set get_outgoing_edges(DiGraphView const &g, Node const &n) { @@ -34,10 +40,16 @@ std::unordered_set get_outgoing_edges(DiGraphView const &g, Node c } std::unordered_map> get_outgoing_edges(DiGraphView const &g, std::unordered_set const &ns) { - return group_by(g.query_edges(DirectedEdgeQuery{ + std::unordered_map> result = group_by(g.query_edges(DirectedEdgeQuery{ query_set::matchall(), query_set{ns}, }), [](DirectedEdge const &e) { return e.src; }); + + for (Node const &n : ns) { + result[n]; + } + + return result; } std::unordered_set get_sources(DiGraphView const &g) { diff --git a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc index 4fe36aecaf..1c4e6eb4bf 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc @@ -1,5 +1,7 @@ #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/containers/enumerate_vector.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -14,15 +16,11 @@ UnorderedSetDataflowGraph::UnorderedSetDataflowGraph(NodeSource const &node_sour NodeAddedResult UnorderedSetDataflowGraph::add_node(std::vector const &inputs, int num_outputs) { Node new_node = this->node_source.new_node(); - this->nodes.insert(new_node); - - for (auto const &[input_idx, input_src] : enumerate_vector(inputs)) { - this->edges.insert(DataflowEdge{input_src, DataflowInput{new_node, input_idx}}); - } std::vector new_outputs = transform(count(num_outputs), [&](int output_idx) { return DataflowOutput{new_node, output_idx}; }); - extend(this->outputs, new_outputs); + + this->add_node_unsafe(new_node, inputs, new_outputs); return NodeAddedResult{new_node, new_outputs}; } @@ -47,6 +45,31 @@ std::unordered_set UnorderedSetDataflowGraph::query_outputs(Data }); } +void UnorderedSetDataflowGraph::add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs) { + assert (!contains(this->nodes, node)); + assert (are_disjoint(this->outputs, without_order(outputs))); + + this->nodes.insert(node); + + for (auto const &[input_idx, input_src] : enumerate_vector(inputs)) { + this->edges.insert(DataflowEdge{input_src, DataflowInput{node, input_idx}}); + } + + extend(this->outputs, outputs); +} + +void UnorderedSetDataflowGraph::inplace_materialize_from(DataflowGraphView const &view) { + std::unordered_set nodes = get_nodes(view); + std::unordered_set edges = get_edges(view); + std::unordered_set outputs = get_all_dataflow_outputs(view); + + this->nodes = nodes; + this->edges = edges; + this->outputs = outputs; +} + UnorderedSetDataflowGraph *UnorderedSetDataflowGraph::clone() const { return new UnorderedSetDataflowGraph{ this->node_source, diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc index 2e9c4ac6e4..3d9fef6c6b 100644 --- a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.cc @@ -5,8 +5,8 @@ namespace FlexFlow { // TODO(@lockshaw) eventually move this over to tests struct Visitor { - std::string operator()(Node const &, int); - float operator()(OpenDataflowValue const &, int); + std::string operator()(Node const &, int) { NOT_IMPLEMENTED(); } + float operator()(OpenDataflowValue const &, int) { NOT_IMPLEMENTED(); } }; template LabelledOpenDataflowGraphView rewrite_labels(LabelledOpenDataflowGraphView const &, Visitor); diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc index 23fea301f6..1cd224dc1c 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc @@ -9,7 +9,7 @@ std::unordered_set get_edges(OpenDataflowGraphView const &g) { return g.query_edges(open_dataflow_edge_query_all()); } -std::vector get_inputs(OpenDataflowGraphView const &g) { +std::unordered_set get_inputs(OpenDataflowGraphView const &g) { return g.get_inputs(); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc index c582117166..2234f4d283 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -1,5 +1,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/containers/enumerate_vector.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" #include "utils/overload.h" #include "utils/graph/node/algorithms.h" @@ -9,15 +11,12 @@ namespace FlexFlow { struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { OpenDataflowSubgraph(OpenDataflowGraphView const &full_graph, std::unordered_set const &subgraph_nodes, - bidict const &full_graph_values_to_subgraph_inputs, - std::vector const &subgraph_inputs) + bidict const &full_graph_values_to_subgraph_inputs) : full_graph(full_graph), subgraph_nodes(subgraph_nodes), - full_graph_values_to_subgraph_inputs(full_graph_values_to_subgraph_inputs), - subgraph_inputs(subgraph_inputs) + full_graph_values_to_subgraph_inputs(full_graph_values_to_subgraph_inputs) { assert(is_subseteq_of(this->subgraph_nodes, get_nodes(full_graph))); - assert(without_order(subgraph_inputs) == without_order(values(full_graph_values_to_subgraph_inputs))); } std::unordered_set query_nodes(NodeQuery const &q) const override { @@ -27,7 +26,7 @@ struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { std::unordered_set query_edges(OpenDataflowEdgeQuery const &q) const override { std::unordered_set result; for (OpenDataflowEdge const &open_e : this->full_graph.query_edges(q)) { - open_e.visit(overload { + open_e.visit(overload { [&](DataflowEdge const &e) { bool contains_src = contains(this->subgraph_nodes, e.src.node); bool contains_dst = contains(this->subgraph_nodes, e.dst.node); @@ -36,11 +35,13 @@ struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { } else if (contains_dst && !contains_src) { result.insert(OpenDataflowEdge{DataflowInputEdge{this->full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{e.src}), e.dst}}); } + return std::nullopt; }, [&](DataflowInputEdge const &e) { if (contains(this->subgraph_nodes, e.dst.node)) { result.insert(OpenDataflowEdge{DataflowInputEdge{this->full_graph_values_to_subgraph_inputs.at_l(OpenDataflowValue{e.src}), e.dst}}); } + return std::nullopt; } }); } @@ -54,8 +55,8 @@ struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { }); } - std::vector get_inputs() const override { - return this->subgraph_inputs; + std::unordered_set get_inputs() const override { + return without_order(values(this->full_graph_values_to_subgraph_inputs)); }; OpenDataflowSubgraph *clone() const override { @@ -63,34 +64,28 @@ struct OpenDataflowSubgraph final : public IOpenDataflowGraphView { this->full_graph, this->subgraph_nodes, this->full_graph_values_to_subgraph_inputs, - this->subgraph_inputs, }; } private: OpenDataflowGraphView full_graph; std::unordered_set subgraph_nodes; bidict full_graph_values_to_subgraph_inputs; - std::vector subgraph_inputs; }; OpenDataflowSubgraphResult get_subgraph(OpenDataflowGraphView const &g, - std::unordered_set const &subgraph_nodes, - std::vector const &input_ordering) { - std::vector subgraph_inputs; - bidict full_graph_values_to_subgraph_inputs; - for (auto const &[idx, full_graph_value] : enumerate_vector(input_ordering)) { - DataflowGraphInput subgraph_input = DataflowGraphInput{idx}; - subgraph_inputs.push_back(subgraph_input); - full_graph_values_to_subgraph_inputs.equate({full_graph_value, subgraph_input}); - } + std::unordered_set const &subgraph_nodes) { + DataflowGraphInputSource input_source; + bidict full_graph_values_to_subgraph_inputs = generate_bidict(get_subgraph_inputs(g, subgraph_nodes), + [&](OpenDataflowValue const &i) { + return input_source.new_dataflow_graph_input(); + }); return OpenDataflowSubgraphResult{ OpenDataflowGraphView::create( g, subgraph_nodes, - subgraph_inputs, - input_ordering), + full_graph_values_to_subgraph_inputs), full_graph_values_to_subgraph_inputs, }; } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input_source.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input_source.cc new file mode 100644 index 0000000000..42d2126068 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_graph_input_source.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" + +namespace FlexFlow { + +size_t DataflowGraphInputSource::next_available_uid = 0; + +DataflowGraphInputSource::DataflowGraphInputSource() { } + +DataflowGraphInput DataflowGraphInputSource::new_dataflow_graph_input() { + DataflowGraphInput result = DataflowGraphInput{DataflowGraphInputSource::next_available_uid}; + DataflowGraphInputSource::next_available_uid++; + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc index 256d16ea90..033eafaba0 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc @@ -3,6 +3,13 @@ namespace FlexFlow { +Node get_open_dataflow_edge_dst_node(OpenDataflowEdge const &e) { + return e.visit(overload { + [](DataflowEdge const &e) { return e.dst.node; }, + [](DataflowInputEdge const &e) { return e.dst.node; }, + }); +} + int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &e) { return e.visit(overload { [](DataflowEdge const &e) { return e.dst.idx; }, @@ -10,7 +17,7 @@ int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &e) { }); } -OpenDataflowValue get_open_dataflow_edge_source_value(OpenDataflowEdge const &open_e) { +OpenDataflowValue get_open_dataflow_edge_source(OpenDataflowEdge const &open_e) { return open_e.visit(overload { [](DataflowEdge const &e) { return OpenDataflowValue{e.src}; }, [](DataflowInputEdge const &e) { return OpenDataflowValue{e.src}; }, diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc index 8c031f68ec..795199e47b 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_graph_view.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -std::vector OpenDataflowGraphView::get_inputs() const { +std::unordered_set OpenDataflowGraphView::get_inputs() const { return this->get_interface().get_inputs(); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc index 66d416bdb8..7d5eea0081 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/unordered_set_open_dataflow_graph.cc @@ -5,12 +5,14 @@ namespace FlexFlow { UnorderedSetOpenDataflowGraph::UnorderedSetOpenDataflowGraph() {} UnorderedSetOpenDataflowGraph::UnorderedSetOpenDataflowGraph(NodeSource const &node_source, + DataflowGraphInputSource const &input_source, std::unordered_set const &nodes, std::unordered_set const &standard_edges, std::unordered_set const &input_edges, std::unordered_set const &outputs, - std::vector const &graph_inputs) + std::unordered_set const &graph_inputs) : node_source(node_source), + input_source(input_source), nodes(nodes), standard_edges(standard_edges), input_edges(input_edges), @@ -54,20 +56,20 @@ std::unordered_set UnorderedSetOpenDataflowGraph::query_outputs( }); } -std::vector UnorderedSetOpenDataflowGraph::get_inputs() const { +std::unordered_set UnorderedSetOpenDataflowGraph::get_inputs() const { return this->graph_inputs; } DataflowGraphInput UnorderedSetOpenDataflowGraph::add_input() { - int idx = this->graph_inputs.size(); - DataflowGraphInput result = DataflowGraphInput{idx}; - this->graph_inputs.push_back(result); - return result; + DataflowGraphInput new_input = this->input_source.new_dataflow_graph_input(); + this->graph_inputs.insert(new_input); + return new_input; } UnorderedSetOpenDataflowGraph *UnorderedSetOpenDataflowGraph::clone() const { return new UnorderedSetOpenDataflowGraph{ this->node_source, + this->input_source, this->nodes, this->standard_edges, this->input_edges, diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc new file mode 100644 index 0000000000..d9c3595e99 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc @@ -0,0 +1,48 @@ +#include "utils/graph/serial_parallel/graph_generation.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms.h" + +namespace FlexFlow { + +void parallel_extend_unsafe(DataflowGraph &g, + DataflowGraphView const &ext) { + for (Node const &node : get_nodes(ext)) { + g.add_node_unsafe(node, + get_inputs(ext, node), + get_outputs(ext, node)); + } +} + +void serial_extend_unsafe(DataflowGraph &g, + DataflowGraphView const &ext) { + // TODO(@lockshaw): This function signature is impossible to implement in general, + // as there is no guarantee that the graph view ext actually has source nodes with inputs + // Either the signature should be changed, or an implementation should be added that throws + // an error if this problematic case is found + + NOT_IMPLEMENTED(); +} + +DataflowGraph serial_composition(DataflowGraphView const &g1, + DataflowGraphView const &g2) { + DataflowGraph g = DataflowGraph::create_copy_of(g1); + serial_extend_unsafe(g, g2); + return g; +} + +DataflowGraph parallel_composition(DataflowGraphView const &g1, + DataflowGraphView const &g2) { + DataflowGraph g = DataflowGraph::create_copy_of(g1); + parallel_extend_unsafe(g, g2); + return g; +} + +DataflowGraph dataflow_graph_from_sp_decomposition( + SerialParallelDecomposition const &sp_decomposition) { + // TODO(@lockshaw): see existing concerns about serial_extend_unsafe + + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc new file mode 100644 index 0000000000..e31ae31c72 --- /dev/null +++ b/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc @@ -0,0 +1,109 @@ +#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/hash-utils.h" +#include "utils/hash/vector.h" +#include "utils/fmt/vector.h" +#include "utils/fmt/variant.h" + +namespace FlexFlow { + +SerialSplit::SerialSplit( + std::vector> const &children) + : children(children) +{ } + +bool SerialSplit::operator==(SerialSplit const &other) const { + return this->tie() == other.tie(); +} + +bool SerialSplit::operator!=(SerialSplit const &other) const { + return this->tie() != other.tie(); +} + +bool SerialSplit::operator<(SerialSplit const &other) const { + return this->tie() < other.tie(); +} + +bool SerialSplit::operator<=(SerialSplit const &other) const { + return this->tie() <= other.tie(); +} + +bool SerialSplit::operator>(SerialSplit const &other) const { + return this->tie() > other.tie(); +} + +bool SerialSplit::operator>=(SerialSplit const &other) const { + return this->tie() >= other.tie(); +} + +SerialSplit::Tie SerialSplit::tie() const { + return std::tie(this->children); +} + +std::string format_as(SerialSplit const &split) { + return fmt::format("", split.children); +} + +std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { + return s << fmt::to_string(split); +} + +ParallelSplit::ParallelSplit( + std::vector> const &children) + : children(children) +{ } + +bool ParallelSplit::operator==(ParallelSplit const &other) const { + return this->tie() == other.tie(); +} + +bool ParallelSplit::operator!=(ParallelSplit const &other) const { + return this->tie() != other.tie(); +} + +bool ParallelSplit::operator<(ParallelSplit const &other) const { + return this->tie() < other.tie(); +} + +bool ParallelSplit::operator<=(ParallelSplit const &other) const { + return this->tie() <= other.tie(); +} + +bool ParallelSplit::operator>(ParallelSplit const &other) const { + return this->tie() > other.tie(); +} + +bool ParallelSplit::operator>=(ParallelSplit const &other) const { + return this->tie() >= other.tie(); +} + +ParallelSplit::Tie ParallelSplit::tie() const { + return std::tie(this->children); +} + +std::string format_as(ParallelSplit const &split) { + return fmt::format("", split.children); +} + +std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { + return s << fmt::to_string(split); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::SerialSplit> + ::operator()(::FlexFlow::SerialSplit const &s) const { + size_t result = 0; + ::FlexFlow::hash_combine(result, s.children); + return result; +} + +size_t hash<::FlexFlow::ParallelSplit> + ::operator()(::FlexFlow::ParallelSplit const &s) const { + size_t result = 0; + ::FlexFlow::hash_combine(result, s.children); + return result; +} + +} // namespace std diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc index 76aa3b2c00..918b96cbd6 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel.cc @@ -7,40 +7,6 @@ namespace FlexFlow { -Node find_source_node(DiGraphView const &g) { - std::unordered_set srcs = get_sources(g); - return get_only(srcs); -} - -Node find_sink_node(DiGraphView const &g) { - std::unordered_set sinks = get_sinks(g); - return get_only(sinks); -} - -std::optional find_bottleneck_node(DiGraphView const &g) { - std::unordered_set sources = get_sources(g); - std::unordered_set sinks = get_sinks(g); - - std::optional maybe_bottleneck = get_imm_post_dominator(g, sources); - if (maybe_bottleneck.has_value()) { - assert(contains(get_dominators(g, sinks), maybe_bottleneck.value())); - } - return maybe_bottleneck; -} - -std::unordered_set from_source_to_sink(DiGraphView const &g, - Node const &src, - Node const &sink) { - assert(contains(get_dominators(g, sink), src)); - - std::vector bfs = get_bfs_ordering(g, {src}); - auto end = find(bfs, sink); - assert(end != bfs.end()); - - std::unordered_set result(bfs.cbegin(), ++end); - return result; -} - SerialParallelDecomposition get_serial_parallel_decomposition(DiGraphView const &g) { std::variant ast = sp_decomposition(g); @@ -51,17 +17,17 @@ std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { return sp.visit>([](auto &&t) { return get_nodes(t); }); } -std::unordered_set get_nodes(Serial const &serial) { +std::unordered_set get_nodes(SerialSplit const &serial) { return set_union(transform( serial.children, - [](std::variant const &child) -> std::unordered_set { + [](std::variant const &child) -> std::unordered_set { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(Parallel const ¶llel) { +std::unordered_set get_nodes(ParallelSplit const ¶llel) { return set_union( - transform(parallel.children, [](std::variant const &child) { + transform(parallel.children, [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } @@ -70,94 +36,5 @@ std::unordered_set get_nodes(Node const &node) { return {node}; } -// std::unordered_map parallel_extend(MultiDiGraph &g, -// MultiDiGraph const &ext) { -// std::unordered_map node_map; -// std::unordered_map node_port_map; -// for (Node const &node : get_nodes(MultiDiGraphView(ext))) { -// node_map.emplace(node, g.add_node()); -// } -// for (NodePort const &node_port : get_present_node_ports(ext)) { -// node_port_map.emplace(node_port, g.add_node_port()); -// } -// for (MultiDiEdge const &edge : get_edges(ext)) { -// g.add_edge(MultiDiEdge{node_map.at(edge.dst), -// node_port_map.at(edge.dst_idx), -// node_map.at(edge.src), -// node_port_map.at(edge.src_idx)}); -// } -// return node_map; -// } - -// std::unordered_map serial_extend(MultiDiGraph &g, -// MultiDiGraph const &ext) { -// std::unordered_set original_sinks = get_sinks(g); -// std::unordered_map node_map = parallel_extend(g, ext); -// for (Node const &node1 : original_sinks) { -// for (Node const &node2 : get_sources(ext)) { -// g.add_edge(MultiDiEdge{ -// node_map.at(node2), g.add_node_port(), node1, g.add_node_port()}); -// } -// } -// return node_map; -// } - -// MultiDiGraph serial_composition(MultiDiGraph const &g1, -// MultiDiGraph const &g2) { -// MultiDiGraph g = g1; -// serial_extend(g, g2); -// return g; -// } - -// MultiDiGraph parallel_composition(MultiDiGraph const &g1, -// MultiDiGraph const &g2) { -// MultiDiGraph g = g1; -// parallel_extend(g, g2); -// return g; -// } - -// struct MultiDiGraphFromSPDecompositionFunctor { -// template -// MultiDiGraph operator()(T const &t) { -// return multidigraph_from_sp_decomposition(t); -// } -// }; - -// MultiDiGraph multidigraph_from_sp_decomposition( -// SerialParallelDecomposition const &sp_decomposition) { -// return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); -// } - -// MultiDiGraph multidigraph_from_sp_decomposition( -// std::variant const &sp_decomposition) { -// return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); -// } - -// MultiDiGraph multidigraph_from_sp_decomposition( -// std::variant const &sp_decomposition) { -// return visit(MultiDiGraphFromSPDecompositionFunctor{}, sp_decomposition); -// } - -// MultiDiGraph multidigraph_from_sp_decomposition(Serial const &serial) { -// MultiDiGraph g = MultiDiGraph::create(); -// for (std::variant const &child : serial.children) { -// serial_extend(g, multidigraph_from_sp_decomposition(child)); -// } -// return g; -// } - -// MultiDiGraph multidigraph_from_sp_decomposition(Parallel const ¶llel) { -// MultiDiGraph g = MultiDiGraph::create(); -// for (std::variant const &child : parallel.children) { -// parallel_extend(g, multidigraph_from_sp_decomposition(child)); -// } -// return g; -// } - -// MultiDiGraph multidigraph_from_sp_decomposition(Node const &Node) { -// MultiDiGraph g = MultiDiGraph::create(); -// g.add_node(); -// return g; -// } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc index 7945ee6273..5a81dc0dfb 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.cc @@ -7,6 +7,40 @@ namespace FlexFlow { +Node find_source_node(DiGraphView const &g) { + std::unordered_set srcs = get_sources(g); + return get_only(srcs); +} + +Node find_sink_node(DiGraphView const &g) { + std::unordered_set sinks = get_sinks(g); + return get_only(sinks); +} + +std::optional find_bottleneck_node(DiGraphView const &g) { + std::unordered_set sources = get_sources(g); + std::unordered_set sinks = get_sinks(g); + + std::optional maybe_bottleneck = get_imm_post_dominator(g, sources); + if (maybe_bottleneck.has_value()) { + assert(contains(get_dominators(g, sinks), maybe_bottleneck.value())); + } + return maybe_bottleneck; +} + +std::unordered_set from_source_to_sink(DiGraphView const &g, + Node const &src, + Node const &sink) { + assert(contains(get_dominators(g, sink), src)); + + std::vector bfs = get_bfs_ordering(g, {src}); + auto end = find(bfs, sink); + assert(end != bfs.end()); + + std::unordered_set result(bfs.cbegin(), ++end); + return result; +} + std::unordered_set from_source_to_sink(DiGraphView const &g, std::unordered_set const &srcs, @@ -128,25 +162,25 @@ std::variant flatten_ast(std::variant operator()(IntermediateSpDecompositionTree const &node) { + std::variant operator()(IntermediateSpDecompositionTree const &node) { if (node.type == SplitType::SERIAL) { - return Serial{transform(node.children, [](std::variant const &s) { - return narrow>(internal_to_final_ast(s)).value(); + return SerialSplit{transform(node.children, [](std::variant const &s) { + return narrow>(internal_to_final_ast(s)).value(); })}; } else { - return Parallel{transform(node.children, [](std::variant const &s) { - return narrow>(internal_to_final_ast(s)).value(); + return ParallelSplit{transform(node.children, [](std::variant const &s) { + return narrow>(internal_to_final_ast(s)).value(); })}; } } - std::variant operator()(Node const &node) { + std::variant operator()(Node const &node) { return node; } }; -std::variant internal_to_final_ast(std::variant const &ast) { - return visit(ToFinalAST{}, ast); +std::variant internal_to_final_ast(std::variant const &ast) { + return std::visit(ToFinalAST{}, ast); } SerialParallelDecomposition to_final_ast(std::variant const &ast) { diff --git a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h index 6b7671d5ef..4ffe537228 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h +++ b/lib/utils/src/utils/graph/serial_parallel/serialparallel_internal.h @@ -7,18 +7,33 @@ #include "utils/visitable.h" #include #include +#include "./source_settings.dtg.h" +#include "./sink_settings.dtg.h" namespace FlexFlow { -struct ParallelInternal; +Node find_source_node(DiGraphView const &); +Node find_sink_node(DiGraphView const &); +std::optional find_bottleneck_node(DiGraphView const &); std::variant sp_decomposition(DiGraphView const &g); IntermediateSpDecompositionTree parallel_decomposition(DiGraphView const &g); std::unordered_set from_source_to_sink(DiGraphView const &, Node const &src, Node const &sink); - -std::variant internal_to_final_ast(std::variant const &); +std::unordered_set + from_source_to_sink(DiGraphView const &g, + std::unordered_set const &srcs, + std::unordered_set const &sinks, + SourceSettings include_src, + SinkSettings include_sink); +DiGraphView source_to_sink_subgraph(DiGraphView const &g, + std::unordered_set const &srcs, + std::unordered_set const &sinks, + SourceSettings include_src, + SinkSettings include_sink); + +std::variant internal_to_final_ast(std::variant const &ast); SerialParallelDecomposition to_final_ast(std::variant const &); std::variant flatten_ast(std::variant const &ast); From 5fd666db1ffa81ed7b2da94ccb19104bd1e7c9b2 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 12 Jul 2024 21:45:03 -0700 Subject: [PATCH 16/71] Get substitutions building --- flake.lock | 6 +- .../op-attrs/ops/conv_2d_attrs.struct.toml | 4 + .../ops/element_unary_attrs.struct.toml | 8 +- .../op-attrs/ops/embedding_attrs.struct.toml | 4 + .../op-attrs/ops/linear_attrs.struct.toml | 4 + lib/op-attrs/src/op-attrs/datatype.cc | 2 + .../constant_initializer_attrs.struct.toml | 4 + lib/pcg/include/pcg/layer_attrs.struct.toml | 4 + .../parallel_layer_attrs.struct.toml | 4 + .../parallel_tensor_attrs.struct.toml | 4 + lib/pcg/include/pcg/tensor_attrs.struct.toml | 5 +- .../operator_attribute_value.variant.toml | 5 + .../unlabelled/input_pattern_edge.struct.toml | 5 +- .../substitutions/unlabelled/match_split.h | 16 +- ...truct.toml => match_split.struct.toml.old} | 4 +- .../unlabelled/output_pattern_edge.h | 13 -- .../output_pattern_edge.struct.toml | 15 -- .../substitutions/unlabelled/pattern_edge.h | 13 +- .../unlabelled/pattern_matching.h | 1 - .../unlabelled/pattern_node.struct.toml | 3 +- .../substitutions/unlabelled/pattern_split.h | 10 +- .../unlabelled/pattern_split.struct.toml | 5 +- .../substitutions/unlabelled/pattern_value.h | 3 +- .../unlabelled_dataflow_graph_pattern_match.h | 2 + ...d_dataflow_graph_pattern_match.struct.toml | 8 +- .../src/substitutions/pcg_pattern.cc | 2 +- .../unlabelled/find_pattern_matches.cc | 3 +- .../unlabelled/input_pattern_edge.cc | 6 +- .../substitutions/unlabelled/match_split.cc | 138 +++++++++--------- .../unlabelled/output_pattern_edge.cc | 9 -- .../substitutions/unlabelled/pattern_edge.cc | 47 +++--- .../unlabelled/pattern_matching.cc | 21 ++- .../substitutions/unlabelled/pattern_split.cc | 55 ++++--- .../substitutions/unlabelled/pattern_value.cc | 13 +- .../unlabelled/standard_pattern_edge.cc | 16 +- ...unlabelled_dataflow_graph_pattern_match.cc | 42 ++++-- .../unlabelled/unlabelled_graph_pattern.cc | 24 +-- .../test/src/test_pattern_matches.cc | 7 +- lib/utils/include/utils/bidict.h | 8 + lib/utils/include/utils/fmt.decl.h | 11 -- lib/utils/include/utils/fmt.h | 21 --- lib/utils/include/utils/fmt/optional.h | 19 +++ ...ordered_set_labelled_open_dataflow_graph.h | 1 + .../include/utils/graph/node/i_graph_view.h | 2 + lib/utils/include/utils/graph/query_set.h | 2 +- lib/utils/include/utils/hash/map.h | 1 + lib/utils/include/utils/hash/unordered_map.h | 1 + lib/utils/src/utils/fmt/optional.cc | 1 + .../utils/graph/dataflow_graph/algorithms.cc | 1 + .../src/utils/graph/digraph/algorithms.cc | 1 + .../graph/instances/adjacency_digraph.cc | 1 + .../instances/unordered_set_dataflow_graph.cc | 1 + .../graph/open_dataflow_graph/algorithms.cc | 1 + .../algorithms/get_subgraph.cc | 1 + .../i_open_dataflow_graph_view.cc | 1 + .../test/common/include/test/utils/doctest.h | 1 - 56 files changed, 340 insertions(+), 270 deletions(-) rename lib/substitutions/include/substitutions/unlabelled/{match_split.struct.toml => match_split.struct.toml.old} (69%) delete mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h delete mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml delete mode 100644 lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc create mode 100644 lib/utils/include/utils/fmt/optional.h create mode 100644 lib/utils/src/utils/fmt/optional.cc diff --git a/flake.lock b/flake.lock index 2d1157ba40..26b7339f5c 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1719373251, - "narHash": "sha256-n1Rm8vOflScty0XRzkjUvJEFHfDhyqTriZZ8AFZJbT0=", + "lastModified": 1720843291, + "narHash": "sha256-RurMfG9Enp29u3L1/Yj+IDn0aWsW6hEg8JiM9D5aSkM=", "owner": "lockshaw", "repo": "proj", - "rev": "a01d0b1e60a05703c5e42f9e924a183a65032de8", + "rev": "faea146e54ba092e3a26feee60387090b97783bc", "type": "github" }, "original": { diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index 353ef93004..2fb385b64d 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -15,6 +15,10 @@ includes = [ "utils/json.h", ] +src_includes = [ + "utils/fmt/optional.h", +] + fields = [ { name = "out_channels", type = "int" }, { name = "kernel_h", type = "int" }, diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index 0122255be2..4b9c8a9f45 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -12,7 +12,11 @@ features = [ includes = [ "utils/json.h", - "op-attrs/operator_type.h" + "op-attrs/operator_type.h", +] + +src_includes = [ + "utils/fmt/optional.h", ] [[fields]] @@ -21,4 +25,4 @@ type = "::FlexFlow::OperatorType" [[fields]] name = "scalar" -type = "std::optional" \ No newline at end of file +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index f0772c351e..38d5a4371e 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -15,6 +15,10 @@ includes = [ "op-attrs/datatype.dtg.h", ] +src_includes = [ + "utils/fmt/optional.h", +] + [[fields]] name = "num_entries" type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index 4ac8f83ec9..eaa34cc496 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -16,6 +16,10 @@ includes = [ "utils/json.h", ] +src_includes = [ + "utils/fmt/optional.h", +] + [[fields]] name = "out_channels" type = "int" diff --git a/lib/op-attrs/src/op-attrs/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc index 06d99db702..bbe8501ecc 100644 --- a/lib/op-attrs/src/op-attrs/datatype.cc +++ b/lib/op-attrs/src/op-attrs/datatype.cc @@ -1,4 +1,6 @@ #include "op-attrs/datatype.h" +#include "utils/exception.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml index 511ec057fa..12917d0989 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -14,6 +14,10 @@ includes = [ "utils/json.h", ] +src_includes = [ + "utils/fmt/variant.h", +] + [[fields]] name = "value" type = "::FlexFlow::DataTypeValue" diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml index 9f8aaa5ba3..d062f6cd78 100644 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -16,6 +16,10 @@ includes = [ "utils/json.h" ] +src_includes = [ + "utils/fmt/optional.h", +] + [[fields]] name = "attrs" type = "::FlexFlow::ComputationGraphOpAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 1ba9ac5487..60cfc426cc 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -15,6 +15,10 @@ includes = [ "", ] +src_includes = [ + "utils/fmt/optional.h", +] + [[fields]] name = "op_attrs" type = "::FlexFlow::PCGOperatorAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml index faf7159ad7..d9e6cf113b 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml @@ -17,6 +17,10 @@ includes = [ "", ] +src_includes = [ + "utils/fmt/optional.h", +] + [[fields]] name = "shape" type = "::FlexFlow::ParallelTensorShape" diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml index 260cb9e68f..c0b89cfc99 100644 --- a/lib/pcg/include/pcg/tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -17,6 +17,10 @@ includes = [ "", ] +src_includes = [ + "utils/fmt/optional.h", +] + [[fields]] name = "shape" type = "::FlexFlow::TensorShape" @@ -29,7 +33,6 @@ type = "std::optional<::FlexFlow::InitializerAttrs>" name = "sync_type" type = "std::optional<::FlexFlow::ParamSync>" - [[fields]] name = "create_gradients" type = "::FlexFlow::CreateGrad" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 29ae87afb7..da2feb1903 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -21,6 +21,11 @@ includes = [ "op-attrs/tensor_shape.dtg.h", "op-attrs/datatype.dtg.h", "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml index abe9c6f768..e4203cf495 100644 --- a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml @@ -4,12 +4,13 @@ features = [ "eq", "ord", "hash", + "fmt", ] includes = [ - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" + "utils/graph/open_dataflow_graph/dataflow_input_edge.dtg.h", ] [[fields]] name = "raw_edge" -type = "::FlexFlow::OpenDataflowEdge" +type = "::FlexFlow::DataflowInputEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.h b/lib/substitutions/include/substitutions/unlabelled/match_split.h index 957ce6eaa0..b07209fdb8 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_split.h +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.h @@ -1,17 +1,17 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H -#include "substitutions/unlabelled/match_split.dtg.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" -#include "substitutions/unlabelled/pattern_split.dtg.h" -#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +// #include "substitutions/unlabelled/match_split.dtg.h" +// #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" +// #include "substitutions/unlabelled/pattern_split.dtg.h" +// #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { -MatchSplit empty_match_split(); -MatchSplit apply_split(UnlabelledGraphPattern const &pattern, - UnlabelledDataflowGraphPatternMatch const &match, - PatternSplit const &split); +// MatchSplit empty_match_split(); +// MatchSplit apply_split(UnlabelledGraphPattern const &pattern, +// UnlabelledDataflowGraphPatternMatch const &match, +// PatternSplit const &split); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml.old similarity index 69% rename from lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml rename to lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml.old index 05c7451351..4439e687e1 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml.old @@ -13,8 +13,8 @@ includes = [ [[fields]] name = "prefix_submatch" -type = "UnlabelledDataflowGraphPatternMatch" +type = "::FlexFlow::UnlabelledDataflowGraphPatternMatch" [[fields]] name = "postfix_submatch" -type = "UnlabelledDataflowGraphPatternMatch" +type = "::FlexFlow::UnlabelledDataflowGraphPatternMatch" diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h deleted file mode 100644 index 72e8ff02cf..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H -#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H - -#include "substitutions/unlabelled/output_pattern_edge.dtg.h" -#include "substitutions/unlabelled/pattern_node.dtg.h" - -namespace FlexFlow { - -PatternNode get_src_node(OutputPatternEdge const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml deleted file mode 100644 index 362cbc3265..0000000000 --- a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml +++ /dev/null @@ -1,15 +0,0 @@ -namespace = "FlexFlow" -name = "OutputPatternEdge" -features = [ - "eq", - "ord", - "hash", -] - -includes = [ - "utils/graph.h", -] - -[[fields]] -name = "raw_edge" -type = "::FlexFlow::OutputMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h index 12405af184..f6050ea8c1 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -1,28 +1,25 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H -#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" #include "substitutions/unlabelled/input_pattern_edge.dtg.h" -#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include "substitutions/unlabelled/standard_pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_node.dtg.h" +#include namespace FlexFlow { PatternNode get_dst_node(PatternEdge const &); std::unordered_set get_nodes(PatternEdge const &); -bool is_closed_edge(PatternEdge const &); bool is_input_edge(PatternEdge const &); -bool is_output_edge(PatternEdge const &); +bool is_standard_edge(PatternEdge const &); -ClosedPatternEdge require_closed_edge(PatternEdge const &); +StandardPatternEdge require_standard_edge(PatternEdge const &); InputPatternEdge require_input_edge(PatternEdge const &); -OutputPatternEdge require_output_edge(PatternEdge const &); PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &); -PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &); -PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &); +PatternEdge pattern_edge_from_standard_edge(StandardPatternEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index 2af7bbf138..381a4933e3 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H #include "substitutions/unlabelled/match_additional_criterion.dtg.h" -#include "substitutions/unlabelled/match_split.dtg.h" #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" #include "utils/graph.h" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml index ecd0253516..a3bcc83249 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml @@ -4,10 +4,11 @@ features = [ "eq", "ord", "hash", + "fmt", ] includes = [ - "utils/graph.h", + "utils/graph/node/node.dtg.h", ] [[fields]] diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h index ff67c882df..058ebe0b56 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h @@ -10,12 +10,12 @@ namespace FlexFlow { PatternSplit find_even_split(UnlabelledGraphPattern const &); -GraphSplit get_raw_split(PatternSplit const &); - -UnlabelledPatternEdgeSplits - get_edge_splits(UnlabelledGraphPattern const &pattern, - PatternSplit const &split); +// GraphSplit get_raw_split(PatternSplit const &); +// UnlabelledPatternEdgeSplits +// get_edge_splits(UnlabelledGraphPattern const &pattern, +// PatternSplit const &split); +// PatternSplitResult apply_split(UnlabelledGraphPattern const &, PatternSplit const &); diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml index 04d1080ff7..7496a798ee 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml @@ -3,13 +3,16 @@ name = "PatternSplit" features = [ "eq", # "ord", - "json", + "hash", + # "json", "fmt", ] includes = [ "utils/graph.h", "", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", "substitutions/unlabelled/pattern_node.dtg.h", ] diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_value.h b/lib/substitutions/include/substitutions/unlabelled/pattern_value.h index 1ae391f080..48613bc78e 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_value.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_value.h @@ -6,7 +6,8 @@ namespace FlexFlow { -OpenDataflowValue raw_dataflow_value_from_pattern_value(PatternValue const &); +OpenDataflowValue raw_open_dataflow_value_from_pattern_value(PatternValue const &); +PatternValue pattern_value_from_raw_open_dataflow_value(OpenDataflowValue const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h index 3a459e69b4..66c23ba135 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h @@ -3,6 +3,8 @@ #include "substitutions/unlabelled/pattern_value.dtg.h" #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" +#include +#include namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml index 064bc85d2a..53c3fe4829 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.struct.toml @@ -10,11 +10,15 @@ features = [ includes = [ "utils/bidict.h", "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", "substitutions/unlabelled/pattern_input.dtg.h", "substitutions/unlabelled/pattern_node.dtg.h", "", +] + +src_includes = [ "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", ] [[fields]] @@ -23,4 +27,4 @@ type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" [[fields]] name = "input_assignment" -type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::DataflowGraphInput>" +type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::OpenDataflowValue>" diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index fe78ebd266..a5fe879696 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -12,7 +12,7 @@ UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { TensorAttributePattern get_tensor_pattern(PCGPattern const &p, PatternValue const &v) { - return p.raw_graph.at(raw_dataflow_value_from_pattern_value(v)); + return p.raw_graph.at(raw_open_dataflow_value_from_pattern_value(v)); } OperatorAttributePattern get_operator_pattern(PCGPattern const &p, diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 84d0fae324..009e955350 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -51,7 +51,7 @@ static std::optional match.input_assignment.insert({ pattern_node_input.get(), - graph_node_input.get(), + graph_node_input, }); } @@ -83,7 +83,6 @@ std::vector std::vector postfix_matches = find_pattern_matches(subpatterns.subpattern_2, graph, additional_criterion); - auto edge_splits = get_edge_splits(pattern, split); for (UnlabelledDataflowGraphPatternMatch const &prefix_match : prefix_matches) { for (UnlabelledDataflowGraphPatternMatch const &postfix_match : postfix_matches) { std::optional unsplit = diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc index a45154cf05..7e993345b5 100644 --- a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc @@ -3,12 +3,12 @@ namespace FlexFlow { -PatternInput get_src_input(InputPatternEdge const &) { - NOT_IMPLEMENTED(); +PatternInput get_src_input(InputPatternEdge const &e) { + return PatternInput{e.raw_edge.src}; } PatternNode get_dst_node(InputPatternEdge const &e) { - return PatternNode{get_open_dataflow_edge_dst_node(e.raw_edge)}; + return PatternNode{e.raw_edge.dst.node}; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.cc index ef0397d6a8..f6c961c93f 100644 --- a/lib/substitutions/src/substitutions/unlabelled/match_split.cc +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.cc @@ -1,69 +1,69 @@ -#include "substitutions/unlabelled/match_split.h" -#include "substitutions/unlabelled/edge_splits.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.h" -#include "substitutions/unlabelled/pattern_edge.h" -#include "substitutions/unlabelled/pattern_split.h" - -namespace FlexFlow { - -MatchSplit empty_match_split() { - return MatchSplit{empty_multidigraph_pattern_match(), - empty_multidigraph_pattern_match()}; -} - -MatchSplit apply_split(UnlabelledGraphPattern const &pattern, - MultiDiGraphPatternMatch const &match, - PatternSplit const &split) { - std::unordered_set prefix = split.first; - std::unordered_set postfix = split.second; - - MatchSplit result = empty_match_split(); - - for (auto const &[pattern_node, match_node] : match.node_assignment) { - if (contains(split.first, pattern_node)) { - result.prefix_submatch.node_assignment.equate(pattern_node, match_node); - } else { - assert(contains(split.second, pattern_node)); - result.postfix_submatch.node_assignment.equate(pattern_node, match_node); - } - } - - UnlabelledPatternEdgeSplits edge_splits = get_edge_splits(pattern, split); - - std::function - handle_edge = [&](PatternEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) -> void { - std::unordered_set edge_nodes = get_nodes(pattern_edge); - - if (is_subseteq_of(edge_nodes, prefix)) { - result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else if (is_subseteq_of(edge_nodes, postfix)) { - result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else { - assert(is_standard_edge(graph_edge)); - - ClosedPatternEdge closed_edge = require_closed_edge(pattern_edge); - - auto split = get_split_edges(edge_splits, closed_edge); - OutputPatternEdge output_edge = split.first; - InputPatternEdge input_edge = split.second; - - auto split_graph_edge = split_edge(std::get(graph_edge)); - OutputMultiDiEdge output_graph_edge = split_graph_edge.first; - InputMultiDiEdge input_graph_edge = split_graph_edge.second; - - handle_edge(pattern_edge_from_input_edge(input_edge), - OpenMultiDiEdge{input_graph_edge}); - handle_edge(pattern_edge_from_output_edge(output_edge), - OpenMultiDiEdge{output_graph_edge}); - } - }; - - for (auto const &[pattern_edge, match_edge] : match.edge_assignment) { - handle_edge(pattern_edge, match_edge); - } - - return result; -} - -} // namespace FlexFlow +// #include "substitutions/unlabelled/match_split.h" +// #include "substitutions/unlabelled/edge_splits.h" +// #include "substitutions/unlabelled/pattern_edge.h" +// #include "substitutions/unlabelled/pattern_split.h" +// #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" +// +// namespace FlexFlow { +// +// MatchSplit empty_match_split() { +// return MatchSplit{empty_unlabelled_pattern_match(), +// empty_unlabelled_pattern_match()}; +// } +// +// MatchSplit apply_split(UnlabelledGraphPattern const &pattern, +// UnlabelledDataflowGraphPatternMatch const &match, +// PatternSplit const &split) { +// std::unordered_set prefix = split.first; +// std::unordered_set postfix = split.second; +// +// MatchSplit result = empty_match_split(); +// +// for (auto const &[pattern_node, match_node] : match.node_assignment) { +// if (contains(split.first, pattern_node)) { +// result.prefix_submatch.node_assignment.equate(pattern_node, match_node); +// } else { +// assert(contains(split.second, pattern_node)); +// result.postfix_submatch.node_assignment.equate(pattern_node, match_node); +// } +// } +// +// UnlabelledPatternEdgeSplits edge_splits = get_edge_splits(pattern, split); +// +// std::function +// handle_edge = [&](PatternEdge const &pattern_edge, +// OpenMultiDiEdge const &graph_edge) -> void { +// std::unordered_set edge_nodes = get_nodes(pattern_edge); +// +// if (is_subseteq_of(edge_nodes, prefix)) { +// result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); +// } else if (is_subseteq_of(edge_nodes, postfix)) { +// result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); +// } else { +// assert(is_standard_edge(graph_edge)); +// +// ClosedPatternEdge closed_edge = require_closed_edge(pattern_edge); +// +// auto split = get_split_edges(edge_splits, closed_edge); +// OutputPatternEdge output_edge = split.first; +// InputPatternEdge input_edge = split.second; +// +// auto split_graph_edge = split_edge(std::get(graph_edge)); +// OutputMultiDiEdge output_graph_edge = split_graph_edge.first; +// InputMultiDiEdge input_graph_edge = split_graph_edge.second; +// +// handle_edge(pattern_edge_from_input_edge(input_edge), +// OpenMultiDiEdge{input_graph_edge}); +// handle_edge(pattern_edge_from_output_edge(output_edge), +// OpenMultiDiEdge{output_graph_edge}); +// } +// }; +// +// for (auto const &[pattern_edge, match_edge] : match.edge_assignment) { +// handle_edge(pattern_edge, match_edge); +// } +// +// return result; +// } +// +// } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc deleted file mode 100644 index 6e70fc8df6..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "substitutions/unlabelled/output_pattern_edge.h" - -namespace FlexFlow { - -PatternNode get_src_node(OutputPatternEdge const &e) { - return PatternNode{e.raw_edge.src}; -} - -} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc index 3dd4987705..5a6ca41281 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -1,50 +1,49 @@ #include "substitutions/unlabelled/pattern_edge.h" +#include "substitutions/unlabelled/standard_pattern_edge.h" +#include "substitutions/unlabelled/input_pattern_edge.h" #include "utils/containers.h" +#include "utils/overload.h" namespace FlexFlow { std::unordered_set get_nodes(PatternEdge const &e) { - return transform(get_nodes(e.raw_edge), - [](Node const &n) { return PatternNode{n}; }); + return e.visit>(overload { + [](InputPatternEdge const &ee) { + return std::unordered_set{get_dst_node(ee)}; + }, + [](StandardPatternEdge const &ee) { + return std::unordered_set{ + get_src_node(ee), + get_dst_node(ee), + }; + }, + }); } bool is_standard_edge(PatternEdge const &e) { - return is_standard_edge(e.raw_edge); + return e.has(); } bool is_input_edge(PatternEdge const &e) { - return is_input_edge(e.raw_edge); + return e.has(); } -bool is_output_edge(PatternEdge const &e) { - return is_output_edge(e.raw_edge); -} - -ClosedPatternEdge require_closed_edge(PatternEdge const &e) { - assert(is_closed_edge(e)); - return ClosedPatternEdge{std::get(e.raw_edge)}; +StandardPatternEdge require_standard_edge(PatternEdge const &e) { + assert(is_standard_edge(e)); + return e.get(); } InputPatternEdge require_input_edge(PatternEdge const &e) { assert(is_input_edge(e)); - return InputPatternEdge{std::get(e.raw_edge)}; -} - -OutputPatternEdge require_output_edge(PatternEdge const &e) { - assert(is_output_edge(e)); - return OutputPatternEdge{std::get(e.raw_edge)}; + return e.get(); } PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &e) { - return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; -} - -PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &e) { - return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; + return PatternEdge{e}; } -PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &e) { - return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +PatternEdge pattern_edge_from_standard_edge(StandardPatternEdge const &e) { + return PatternEdge{e}; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 89888a7c2f..9081f90f08 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -9,6 +9,7 @@ #include "utils/graph/node/algorithms.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/overload.h" #include "substitutions/unlabelled/pattern_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" @@ -67,27 +68,25 @@ struct ConcreteFromPattern { UnlabelledDataflowGraphPatternMatch const &match; - - Node operator()(PatternNode const &n) const { return match.node_assignment.at_l(n); } - DataflowGraphInput operator()(PatternInput const &i) { + OpenDataflowValue operator()(PatternInput const &i) const { return match.input_assignment.at(i); } - DataflowInputEdge operator()(InputPatternEdge const &e) { - return DataflowInputEdge{ + OpenDataflowEdge operator()(InputPatternEdge const &e) const { + return open_dataflow_edge_from_src_and_dst( this->operator()(get_src_input(e)), DataflowInput{ this->operator()(get_dst_node(e)), get_dst_idx(e), - }, - }; + } + ); } - DataflowEdge operator()(StandardPatternEdge const &e) { + DataflowEdge operator()(StandardPatternEdge const &e) const { return DataflowEdge{ DataflowOutput{ this->operator()(get_src_node(e)), @@ -100,15 +99,15 @@ struct ConcreteFromPattern { }; } - OpenDataflowEdge operator()(PatternEdge const &pattern_e) { + OpenDataflowEdge operator()(PatternEdge const &pattern_e) const { return pattern_e.visit([&](auto const &e) { return OpenDataflowEdge{this->operator()(e)}; }); } - OpenDataflowValue operator()(PatternValue const &pattern_v) { + OpenDataflowValue operator()(PatternValue const &pattern_v) const { return pattern_v.visit([&](auto const &v) { return OpenDataflowValue{this->operator()(v)}; }); } - DataflowOutput operator()(PatternNodeOutput const &o) { + DataflowOutput operator()(PatternNodeOutput const &o) const { return DataflowOutput{ this->operator()(get_src_node(o)), get_idx(o), diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc index e116c062df..b0be4617dd 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc @@ -1,10 +1,14 @@ #include "substitutions/unlabelled/pattern_split.h" +#include "substitutions/unlabelled/pattern_value.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" namespace FlexFlow { -PatternSplit find_even_split(UnlabelledGraphPattern const &p) { +PatternSplit find_even_split(UnlabelledGraphPattern const &pattern) { std::vector topological_ordering = - get_topological_ordering(pattern.raw_graph); + transform(get_topological_ordering(pattern.raw_graph), + [](Node const &raw_node) { return PatternNode{raw_node}; }); assert(topological_ordering.size() >= 2); int split_point = topological_ordering.size() / 2; @@ -13,29 +17,40 @@ PatternSplit find_even_split(UnlabelledGraphPattern const &p) { split.first.end()); std::unordered_set postfix(split.second.begin(), split.second.end()); - return {prefix, postfix}; + return PatternSplit{prefix, postfix}; } -GraphSplit get_raw_split(PatternSplit const &s) { - return std::pair{ - transform(s.first, [](PatternNode const &n) { return n.raw_node; }), - transform(s.second, [](PatternNode const &n) { return n.raw_node; }), - }; -} +// GraphSplit get_raw_split(PatternSplit const &s) { +// return std::pair{ +// transform(s.first, [](PatternNode const &n) { return n.raw_node; }), +// transform(s.second, [](PatternNode const &n) { return n.raw_node; }), +// }; +// } -UnlabelledPatternEdgeSplits - get_edge_splits(UnlabelledGraphPattern const &pattern, - PatternSplit const &split) { - bidict> - raw_result = get_edge_splits(pattern.raw_graph, get_raw_split(split), ); - return UnlabelledPatternEdgeSplits{raw_result}; -} +// UnlabelledPatternEdgeSplits +// get_edge_splits(UnlabelledGraphPattern const &pattern, +// PatternSplit const &split) { +// bidict> +// raw_result = get_edge_splits(pattern.raw_graph, get_raw_split(split), ); +// return UnlabelledPatternEdgeSplits{raw_result}; +// } -std::pair +PatternSplitResult apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) { - return { - get_subgraph(p, s.left); - get_subgraph(p, s.right); + OpenDataflowSubgraphResult raw_second_subgraph_result = get_subgraph(p.raw_graph, transform(s.second, [](PatternNode const &pn) { return pn.raw_node; })); + + bidict subpattern_1_outputs_to_subpattern_2_inputs; + for (auto const &kv : raw_second_subgraph_result.full_graph_values_to_subgraph_inputs) { + OpenDataflowValue open_dataflow_value = kv.first; + DataflowGraphInput dataflow_graph_input = kv.second; + subpattern_1_outputs_to_subpattern_2_inputs.equate( + pattern_value_from_raw_open_dataflow_value(open_dataflow_value), PatternInput{dataflow_graph_input}); + } + + return PatternSplitResult{ + get_subgraph(p, s.first), + UnlabelledGraphPattern{raw_second_subgraph_result.graph}, + subpattern_1_outputs_to_subpattern_2_inputs, }; } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc index 287139dc30..8ad7d3496f 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_value.cc @@ -3,10 +3,17 @@ namespace FlexFlow { -OpenDataflowValue raw_dataflow_value_from_pattern_value(PatternValue const &v) { +OpenDataflowValue raw_open_dataflow_value_from_pattern_value(PatternValue const &v) { return v.visit(overload { - [](PatternNodeOutput const &o) { return o.raw_dataflow_output; }, - [](PatternInput const &i) { return i.raw_dataflow_graph_input; }, + [](PatternNodeOutput const &o) { return OpenDataflowValue{o.raw_dataflow_output}; }, + [](PatternInput const &i) { return OpenDataflowValue{i.raw_dataflow_graph_input}; }, + }); +} + +PatternValue pattern_value_from_raw_open_dataflow_value(OpenDataflowValue const &v) { + return v.visit(overload { + [](DataflowOutput const &o) { return PatternValue{PatternNodeOutput{o}}; }, + [](DataflowGraphInput const &i) { return PatternValue{PatternInput{i}}; }, }); } diff --git a/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc index 2e1a47f9e5..dea3e5f500 100644 --- a/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/standard_pattern_edge.cc @@ -2,12 +2,20 @@ namespace FlexFlow { -PatternNode get_src_node(StandardPatternEdge const &) { - NOT_IMPLEMENTED(); +PatternNode get_src_node(StandardPatternEdge const &e) { + return PatternNode{e.raw_edge.src.node}; } -PatternNode get_dst_node(StandardPatternEdge const &) { - NOT_IMPLEMENTED(); +PatternNode get_dst_node(StandardPatternEdge const &e) { + return PatternNode{e.raw_edge.dst.node}; +} + +int get_src_idx(StandardPatternEdge const &e) { + return e.raw_edge.src.idx; +} + +int get_dst_idx(StandardPatternEdge const &e) { + return e.raw_edge.dst.idx; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc index 3f163aedb6..96d3d40cc1 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.cc @@ -1,20 +1,36 @@ #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h" +#include "utils/containers.h" namespace FlexFlow { UnlabelledDataflowGraphPatternMatch empty_unlabelled_pattern_match() { return UnlabelledDataflowGraphPatternMatch{ bidict{}, - bidict{}, + bidict{}, }; } template std::optional> try_merge_nondisjoint_bidicts(bidict const &d1, bidict const &d2) { - for (L const &l : intersection(keys(d1), keys(d2))) { - return + bidict result; + for (L const &l : set_union(keys(d1), keys(d2))) { + if (d1.contains_l(l) && d2.contains_l(l)) { + if (d1.at_l(l) == d2.at_l(l)) { + result.equate(l, d1.at_l(l)); + } else { + return std::nullopt; + } + } else if (d1.contains_l(l)) { + result.equate(l, d1.at_l(l)); + } else { + assert (d2.contains_l(l)); + + result.equate(l, d2.at_l(l)); + } } + + return result; } @@ -22,15 +38,21 @@ std::optional merge_unlabelled_dataflow_graph_pattern_matches(UnlabelledDataflowGraphPatternMatch const &subpattern_1, UnlabelledDataflowGraphPatternMatch const &subpattern_2, bidict const &outputs_of_1_to_inputs_of_2) { - if (!are_disjoint(matched_nodes(subpattern_1), matched_nodes(subpattern_2))) { - return std::nullopt; - } + bidict merged_node_assignment = ({ + std::optional> result = try_merge_nondisjoint_bidicts( + subpattern_1.node_assignment, subpattern_2.node_assignment); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); - bidict merged_node_assignment = merge_maps(subpattern_1.node_assignment, subpattern_2.node_assignment); + assert (all_of(keys(subpattern_2.input_assignment), [&](PatternInput const &i) { return outputs_of_1_to_inputs_of_2.contains_r(i); })); - // if (!are_disjoint(matched_values(subpattern_1), matched_values(subpattern_2))) { - // - // } + return UnlabelledDataflowGraphPatternMatch{ + merged_node_assignment, + subpattern_1.input_assignment, + }; } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 5f9d5b7a1e..328ce5aba1 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,9 +1,12 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/pattern_value.h" #include "utils/containers.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" namespace FlexFlow { @@ -21,8 +24,7 @@ std::unordered_set get_nodes(UnlabelledGraphPattern const &p) { } std::unordered_set get_values(UnlabelledGraphPattern const &p) { - return transform(get_open_dataflow_values(p.raw_graph), - [](OpenDataflowValue const &v) { return PatternValue{v}; }); + return transform(get_open_dataflow_values(p.raw_graph), pattern_value_from_raw_open_dataflow_value); } std::vector get_topological_ordering(UnlabelledGraphPattern const &p) { @@ -32,24 +34,24 @@ std::vector get_topological_ordering(UnlabelledGraphPattern const & std::vector get_inputs_to_pattern_node(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_inputs(p.raw_graph, n.raw_node), - [](OpenDataflowValue const &v) { return PatternValue{v}; }); + return transform(get_inputs(p.raw_graph, n.raw_node), pattern_value_from_raw_open_dataflow_value); } std::vector get_outputs_from_pattern_node(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_outputs(p.raw_graph, n.raw_node), - [](DataflowOutput const &o) { return PatternValue{OpenDataflowValue{o}}; }); + return transform(get_outputs(p.raw_graph, n.raw_node), + [](DataflowOutput const &o) { return pattern_value_from_raw_open_dataflow_value(OpenDataflowValue{o}); }); } UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, std::unordered_set const &n) { - NOT_IMPLEMENTED(); - // return UnlabelledGraphPattern{ - // get_subgraph(p.raw_graph, - // transform(n, [](PatternNode const &n) { return n.raw_node; })); - // }; + OpenDataflowGraphView raw_subgraph = + get_subgraph(p.raw_graph, transform(n, [](PatternNode const &pn) { return pn.raw_node; })).graph; + return UnlabelledGraphPattern{ + raw_subgraph, + }; } + } // namespace FlexFlow diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index e130d0f5d6..4edd016784 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -1,7 +1,8 @@ #include "doctest/doctest.h" #include "rapidcheck.h" -#include "substitutions/graph_pattern_match.h" +#include "substitutions/unlabelled/find_pattern_matches.h" #include "test/utils/all.h" +#include "substitutions/unlabelled/match_additional_criterion.h" using namespace FlexFlow; @@ -92,9 +93,7 @@ TEST_SUITE(FF_TEST_SUITE) { sg0.add_edge(e0); } - MatchAdditionalCriterion always_true{ - [](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; + MatchAdditionalCriterion always_true = match_additional_crition_always_true(); std::vector matches = find_pattern_matches( as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 11e4ba8b05..d568600465 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -187,6 +187,14 @@ std::unordered_map format_as(bidict const &b) { return b; } +template +std::ostream &operator<<(std::ostream &s, bidict const &b) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return s << fmt::to_string(b); +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 5b8d474025..26193ae416 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -24,15 +24,4 @@ typename std::enable_if>::value, } // namespace FlexFlow -namespace fmt { - -template -struct formatter<::std::variant> : formatter<::std::string> { - template - auto format(::std::variant const &m, FormatContext &ctx) - -> decltype(ctx.out()); -}; - -} // namespace fmt - #endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 72fca552d8..f1d4a9f2d9 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "utils/containers.h" #include "utils/fmt.decl.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" @@ -10,28 +9,8 @@ #include #include -namespace fmt { - -template -template -auto formatter<::std::variant>::format(::std::variant const &m, - FormatContext &ctx) - -> decltype(ctx.out()) { - - std::string result = - std::visit([](auto &&x) { return fmt::to_string(x); }, m); - return formatter::format(result, ctx); -} -} // namespace fmt - namespace FlexFlow { -template -struct delegate_ostream_operator> : std::true_type {}; - -template -struct delegate_ostream_operator> : std::true_type {}; - template typename std::enable_if>::value, std::ostream &>::type diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h new file mode 100644 index 0000000000..aef77f7d0f --- /dev/null +++ b/lib/utils/include/utils/fmt/optional.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H + +#include +#include +#include "utils/check_fmtable.h" + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::optional const &t) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(t); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index 1812024f95..8213534e7c 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -11,6 +11,7 @@ #include "utils/containers/zip_vectors.h" #include "utils/containers/without_nullopts.h" #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/node/i_graph_view.h b/lib/utils/include/utils/graph/node/i_graph_view.h index 7d395bca2c..be5b07a685 100644 --- a/lib/utils/include/utils/graph/node/i_graph_view.h +++ b/lib/utils/include/utils/graph/node/i_graph_view.h @@ -2,6 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_I_GRAPH_VIEW_H #include "utils/graph/node/node_query.dtg.h" +#include "utils/type_traits.h" + namespace FlexFlow { struct IGraphView { diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index cbbd0a092d..b1c22a6e94 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_QUERY_SET_H #include "utils/bidict.h" -#include "utils/containers.decl.h" +#include "utils/containers.h" #include "utils/exception.h" #include #include diff --git a/lib/utils/include/utils/hash/map.h b/lib/utils/include/utils/hash/map.h index 48e9cdeac0..e8648dfa20 100644 --- a/lib/utils/include/utils/hash/map.h +++ b/lib/utils/include/utils/hash/map.h @@ -3,6 +3,7 @@ #include "utils/hash-utils.h" #include +#include "utils/hash/pair.h" namespace std { diff --git a/lib/utils/include/utils/hash/unordered_map.h b/lib/utils/include/utils/hash/unordered_map.h index 1435784249..13eb8ba195 100644 --- a/lib/utils/include/utils/hash/unordered_map.h +++ b/lib/utils/include/utils/hash/unordered_map.h @@ -3,6 +3,7 @@ #include "utils/hash-utils.h" #include +#include "utils/hash/pair.h" namespace std { diff --git a/lib/utils/src/utils/fmt/optional.cc b/lib/utils/src/utils/fmt/optional.cc new file mode 100644 index 0000000000..e21b32eaa9 --- /dev/null +++ b/lib/utils/src/utils/fmt/optional.cc @@ -0,0 +1 @@ +#include "utils/fmt/optional.h" diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc index e878b4deee..ba57ea0053 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc @@ -1,6 +1,7 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc index 4f4f7e8bd1..9e27ed3882 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -5,6 +5,7 @@ #include "utils/graph/node/algorithms.h" #include "utils/graph/traversal.h" #include "utils/graph/views/views.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc index 758d7e299f..90ce080938 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc @@ -1,4 +1,5 @@ #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/containers.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc index 1c4e6eb4bf..589d778d4f 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc @@ -2,6 +2,7 @@ #include "utils/containers/enumerate_vector.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc index 1cd224dc1c..3c91bebbf1 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc @@ -2,6 +2,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc index 2234f4d283..b7a11c22ea 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph.cc @@ -5,6 +5,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" #include "utils/overload.h" #include "utils/graph/node/algorithms.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc b/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc index 0fba80d612..f3e9701deb 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.cc @@ -1,5 +1,6 @@ #include "utils/graph/open_dataflow_graph/i_open_dataflow_graph_view.h" #include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" +#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index ff7683dbcd..cb386a5507 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -1,5 +1,4 @@ #include "doctest/doctest.h" -#include "utils/containers.decl.h" #include "utils/fmt/expected.h" #include #include From 5f0c88a873cfa4f56af184f6e7ae0aba13856baf Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 12 Jul 2024 22:47:46 -0700 Subject: [PATCH 17/71] substitutions-tests now builds --- .proj.toml | 2 +- .../substitutions/unlabelled/pattern_edge.h | 3 + .../unlabelled/input_pattern_edge.cc | 4 + .../substitutions/unlabelled/pattern_edge.cc | 7 + .../unlabelled/unlabelled_graph_pattern.cc | 9 + .../test/src/test_pattern_matches.cc | 133 ++++----- .../test/src/test_substitution.cc | 254 +++++++++--------- lib/utils/include/utils/containers.h | 2 +- .../graph/dataflow_graph/dataflow_graph.h | 2 +- .../instances/unordered_set_dataflow_graph.h | 25 +- .../instances/unordered_set_dataflow_graph.cc | 56 +++- .../graph/open_dataflow_graph/algorithms.cc | 5 +- 12 files changed, 292 insertions(+), 210 deletions(-) diff --git a/.proj.toml b/.proj.toml index f6e3cd2308..a593cb23e5 100644 --- a/.proj.toml +++ b/.proj.toml @@ -17,7 +17,7 @@ test_targets = [ "utils-tests", "op-attrs-tests", "pcg-tests", - # "substitutions-tests", + "substitutions-tests", # "compiler-tests", # "substitution-generator-tests", ] diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h index f6050ea8c1..cd5b24fcb3 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -5,6 +5,7 @@ #include "substitutions/unlabelled/standard_pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" #include namespace FlexFlow { @@ -21,6 +22,8 @@ InputPatternEdge require_input_edge(PatternEdge const &); PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &); PatternEdge pattern_edge_from_standard_edge(StandardPatternEdge const &); +PatternEdge pattern_edge_from_raw_open_dataflow_edge(OpenDataflowEdge const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc index 7e993345b5..e8deacebec 100644 --- a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc @@ -11,4 +11,8 @@ PatternNode get_dst_node(InputPatternEdge const &e) { return PatternNode{e.raw_edge.dst.node}; } +int get_dst_idx(InputPatternEdge const &e) { + return e.raw_edge.dst.idx; +} + } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc index 5a6ca41281..98f5aa0d5e 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -46,4 +46,11 @@ PatternEdge pattern_edge_from_standard_edge(StandardPatternEdge const &e) { return PatternEdge{e}; } +PatternEdge pattern_edge_from_raw_open_dataflow_edge(OpenDataflowEdge const &e) { + return e.visit(overload { + [](DataflowInputEdge const &ee) { return PatternEdge{InputPatternEdge{ee}}; }, + [](DataflowEdge const &ee) { return PatternEdge{StandardPatternEdge{ee}}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 328ce5aba1..f67b8780d9 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,4 +1,5 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/pattern_edge.h" #include "substitutions/unlabelled/pattern_value.h" #include "utils/containers.h" #include "utils/graph/node/algorithms.h" @@ -27,6 +28,14 @@ std::unordered_set get_values(UnlabelledGraphPattern const &p) { return transform(get_open_dataflow_values(p.raw_graph), pattern_value_from_raw_open_dataflow_value); } +std::unordered_set get_inputs(UnlabelledGraphPattern const &p) { + return transform(get_inputs(p.raw_graph), [](DataflowGraphInput const &i) { return PatternInput{i}; }); +} + +std::unordered_set get_edges(UnlabelledGraphPattern const &p) { + return transform(get_edges(p.raw_graph), pattern_edge_from_raw_open_dataflow_edge); +} + std::vector get_topological_ordering(UnlabelledGraphPattern const &p) { return transform(get_topological_ordering(p.raw_graph), [](Node const &n) { return PatternNode{n}; }); diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 4edd016784..d558ffeefa 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -3,44 +3,47 @@ #include "substitutions/unlabelled/find_pattern_matches.h" #include "test/utils/all.h" #include "substitutions/unlabelled/match_additional_criterion.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "substitutions/unlabelled/pattern_matching.h" using namespace FlexFlow; namespace rc { -template <> -struct Arbitrary { - static int const MAX_GRAPH_SIZE = 200; - static int const MAX_EDGE_SIZE = 1000; - - static Gen arbitrary() { - return gen::exec([&] { - int num_nodes = *gen::inRange(1, MAX_GRAPH_SIZE + 1); - MultiDiGraph g = MultiDiGraph::template create(); - - std::vector nodes; - for (int i = 0; i < num_nodes; ++i) { - nodes.push_back(g.add_node()); - } - - int num_edges = *gen::inRange(1, MAX_GRAPH_SIZE + 1); - for (int i = 0; i < num_edges; ++i) { - int src_id = *gen::inRange(0, num_nodes); - int dst_id = *gen::inRange(0, num_nodes); - if (src_id > dst_id) { - std::swap(src_id, dst_id); - } - - g.add_edge(MultiDiEdge{nodes[dst_id], - g.add_node_port(), - nodes[src_id], - g.add_node_port()}); - } - - return g; - }); - } -}; +// template <> +// struct Arbitrary { +// static int const MAX_GRAPH_SIZE = 200; +// static int const MAX_EDGE_SIZE = 1000; +// +// static Gen arbitrary() { +// return gen::exec([&] { +// int num_nodes = *gen::inRange(1, MAX_GRAPH_SIZE + 1); +// MultiDiGraph g = MultiDiGraph::template create(); +// +// std::vector nodes; +// for (int i = 0; i < num_nodes; ++i) { +// nodes.push_back(g.add_node()); +// } +// +// int num_edges = *gen::inRange(1, MAX_GRAPH_SIZE + 1); +// for (int i = 0; i < num_edges; ++i) { +// int src_id = *gen::inRange(0, num_nodes); +// int dst_id = *gen::inRange(0, num_nodes); +// if (src_id > dst_id) { +// std::swap(src_id, dst_id); +// } +// +// g.add_edge(MultiDiEdge{nodes[dst_id], +// g.add_node_port(), +// nodes[src_id], +// g.add_node_port()}); +// } +// +// return g; +// }); +// } +// }; } // namespace rc @@ -65,46 +68,52 @@ struct Arbitrary { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find_pattern_matches_small") { - MultiDiGraph g = MultiDiGraph::template create(); + UnlabelledGraphPattern pattern = [] { + OpenDataflowGraph g = OpenDataflowGraph::create(); - { - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); + NodeAddedResult n0_added = g.add_node({}, 1); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; - MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; - MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; - MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; + NodeAddedResult n1_added = g.add_node({v0}, 1); + Node n1 = n1_added.node; + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - } + return UnlabelledGraphPattern{g}; + }(); - MultiDiGraph sg0 = MultiDiGraph::template create(); + OpenDataflowGraph graph = [] { + OpenDataflowGraph g = OpenDataflowGraph::create(); + + NodeAddedResult n0_added = g.add_node({}, 1); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; - { - Node n0 = sg0.add_node(); - Node n1 = sg0.add_node(); + NodeAddedResult n1_added = g.add_node({v0}, 1); + Node n1 = n1_added.node; + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; - MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + NodeAddedResult n2_added = g.add_node({v1}, 1); + Node n2 = n2_added.node; + OpenDataflowValue v2 = OpenDataflowValue{get_only(n2_added.outputs)}; - sg0.add_edge(e0); - } + NodeAddedResult n3_added = g.add_node({v2}, 1); + Node n3 = n3_added.node; + OpenDataflowValue v3 = OpenDataflowValue{get_only(n3_added.outputs)}; - MatchAdditionalCriterion always_true = match_additional_crition_always_true(); + return g; + }(); - std::vector matches = find_pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); + std::vector matches = find_pattern_matches( + pattern, graph, match_additional_crition_always_true()); - RC_ASSERT(matches.size() == 3); + CHECK(matches.size() == 3); - for (MultiDiGraphPatternMatch const &match : matches) { - RC_ASSERT(pattern_matches(as_openmultidigraph(sg0), - as_openmultidigraph(g), - match, - always_true)); + for (UnlabelledDataflowGraphPatternMatch const &match : matches) { + CHECK(unlabelled_pattern_does_match(pattern, + graph, + match, + match_additional_crition_always_true())); } } } diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 2d9320275d..156d573ab8 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -5,130 +5,130 @@ using namespace FlexFlow; -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("apply_substitution") { - OperatorPattern operator_pattern_n0{ - std::vector{ - OperatorAttributeConstraint{ConstraintType::EQUAL, - OperatorAttributeKey::OP_TYPE, - OperatorType::LINEAR}}}; - - ParallelTensorPattern tensor_pattern_e0{ - std::vector{ - TensorAttributeConstraint{ConstraintType::EQUAL, - ListIndexAccess{ - TensorAttributeKey::DIM_SIZES, 0}, - 2}}}; - - ParallelTensorPattern tensor_pattern_empty{ - std::vector{}}; - - auto ig = - OutputLabelledOpenMultiDiGraph:: - create>(); - Node n0 = ig.add_node(operator_pattern_n0); - NodePort p0 = ig.add_node_port(); - InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; - ig.add_edge(e0); - ig.add_label(e0, tensor_pattern_e0); - - RC_ASSERT(get_nodes(ig).size() == 1); - RC_ASSERT(get_edges(ig).size() == 1); - - GraphPattern input_graph{ig}; - - OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, - AttrConstant{OperatorType::REPARTITION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::LINEAR}}, - {OperatorAttributeKey::OUT_CHANNELS, - OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, - {OperatorAttributeKey::USE_BIAS, - OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, - {OperatorAttributeKey::DATA_TYPE, - OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, - {OperatorAttributeKey::ACTIVATION, - OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, - {OperatorAttributeKey::REGULARIZER, - OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; - - OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REDUCTION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - auto og = NodeLabelledOpenMultiDiGraph::create< - UnorderedNodeLabelledOpenMultiDiGraph>(); - Node n1 = og.add_node(op_ass_n1); - Node n2 = og.add_node(op_ass_n2); - Node n3 = og.add_node(op_ass_n3); - NodePort p1 = og.add_node_port(); - NodePort p2 = og.add_node_port(); - NodePort p3 = og.add_node_port(); - InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; - MultiDiEdge e2{n2, p2, n1, p1}; - MultiDiEdge e3{n3, p3, n2, p2}; - og.add_edge(e1); - og.add_edge(e2); - og.add_edge(e3); - OutputGraphExpr output_graph_expr{og}; - - RC_ASSERT(get_nodes(og).size() == 3); - RC_ASSERT(get_edges(og).size() == 3); - - bidict input_mapping; - input_mapping.equate(e0, e1); - bidict output_mapping; - - Substitution substitution{ - input_graph, output_graph_expr, input_mapping, output_mapping}; - - SubParallelComputationGraph pcg = - OutputLabelledOpenMultiDiGraph::create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - - Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); - Node n5 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, - "linear"}); - NodePort p4 = pcg.add_node_port(); - NodePort p5 = pcg.add_node_port(); - - MultiDiEdge e4{n5, p5, n4, p4}; - pcg.add_edge(e4); - ParallelDim dim = {2, 1, false}; - ParallelTensorDims dims = {FFOrdered{dim}}; - pcg.add_label(e4, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); - - MatchAdditionalCriterion criterion{ - [&](Node const &pattern_node, Node const &graph_node) { - return operator_satisfies(pcg.at(graph_node), - input_graph.value().at(pattern_node)); - }, - [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) { - return parallel_tensor_satisfies( - pcg.at(graph_edge), input_graph.value().at(pattern_edge)); - }}; - - RC_ASSERT(criterion.node_criterion(n0, n5)); - - std::vector matches = - find_pattern_matches(input_graph, pcg, criterion); - - RC_ASSERT(matches.size() == 1); - - SubParallelComputationGraph new_pcg = - apply_substitution(pcg, substitution, matches[0]); - - RC_ASSERT(get_nodes(new_pcg).size() == 4); - RC_ASSERT(get_edges(new_pcg).size() == 3); - } -} +// TEST_SUITE(FF_TEST_SUITE) { +// TEST_CASE("apply_substitution") { +// OperatorPattern operator_pattern_n0{ +// std::vector{ +// OperatorAttributeConstraint{ConstraintType::EQUAL, +// OperatorAttributeKey::OP_TYPE, +// OperatorType::LINEAR}}}; +// +// ParallelTensorPattern tensor_pattern_e0{ +// std::vector{ +// TensorAttributeConstraint{ConstraintType::EQUAL, +// ListIndexAccess{ +// TensorAttributeKey::DIM_SIZES, 0}, +// 2}}}; +// +// ParallelTensorPattern tensor_pattern_empty{ +// std::vector{}}; +// +// auto ig = +// OutputLabelledOpenMultiDiGraph:: +// create>(); +// Node n0 = ig.add_node(operator_pattern_n0); +// NodePort p0 = ig.add_node_port(); +// InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; +// ig.add_edge(e0); +// ig.add_label(e0, tensor_pattern_e0); +// +// RC_ASSERT(get_nodes(ig).size() == 1); +// RC_ASSERT(get_edges(ig).size() == 1); +// +// GraphPattern input_graph{ig}; +// +// OperatorAttrAssignment op_ass_n1{ +// {{OperatorAttributeKey::OP_TYPE, +// AttrConstant{OperatorType::REPARTITION}}, +// {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, +// {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; +// +// OperatorAttrAssignment op_ass_n2{ +// {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::LINEAR}}, +// {OperatorAttributeKey::OUT_CHANNELS, +// OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, +// {OperatorAttributeKey::USE_BIAS, +// OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, +// {OperatorAttributeKey::DATA_TYPE, +// OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, +// {OperatorAttributeKey::ACTIVATION, +// OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, +// {OperatorAttributeKey::REGULARIZER, +// OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; +// +// OperatorAttrAssignment op_ass_n3{ +// {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REDUCTION}}, +// {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, +// {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; +// +// auto og = NodeLabelledOpenMultiDiGraph::create< +// UnorderedNodeLabelledOpenMultiDiGraph>(); +// Node n1 = og.add_node(op_ass_n1); +// Node n2 = og.add_node(op_ass_n2); +// Node n3 = og.add_node(op_ass_n3); +// NodePort p1 = og.add_node_port(); +// NodePort p2 = og.add_node_port(); +// NodePort p3 = og.add_node_port(); +// InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; +// MultiDiEdge e2{n2, p2, n1, p1}; +// MultiDiEdge e3{n3, p3, n2, p2}; +// og.add_edge(e1); +// og.add_edge(e2); +// og.add_edge(e3); +// OutputGraphExpr output_graph_expr{og}; +// +// RC_ASSERT(get_nodes(og).size() == 3); +// RC_ASSERT(get_edges(og).size() == 3); +// +// bidict input_mapping; +// input_mapping.equate(e0, e1); +// bidict output_mapping; +// +// Substitution substitution{ +// input_graph, output_graph_expr, input_mapping, output_mapping}; +// +// SubParallelComputationGraph pcg = +// OutputLabelledOpenMultiDiGraph::create< +// UnorderedOutputLabelledOpenMultiDiGraph>(); +// +// Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); +// Node n5 = pcg.add_node(Operator{ +// LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, +// "linear"}); +// NodePort p4 = pcg.add_node_port(); +// NodePort p5 = pcg.add_node_port(); +// +// MultiDiEdge e4{n5, p5, n4, p4}; +// pcg.add_edge(e4); +// ParallelDim dim = {2, 1, false}; +// ParallelTensorDims dims = {FFOrdered{dim}}; +// pcg.add_label(e4, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); +// +// MatchAdditionalCriterion criterion{ +// [&](Node const &pattern_node, Node const &graph_node) { +// return operator_satisfies(pcg.at(graph_node), +// input_graph.value().at(pattern_node)); +// }, +// [&](OpenMultiDiEdge const &pattern_edge, +// OpenMultiDiEdge const &graph_edge) { +// return parallel_tensor_satisfies( +// pcg.at(graph_edge), input_graph.value().at(pattern_edge)); +// }}; +// +// RC_ASSERT(criterion.node_criterion(n0, n5)); +// +// std::vector matches = +// find_pattern_matches(input_graph, pcg, criterion); +// +// RC_ASSERT(matches.size() == 1); +// +// SubParallelComputationGraph new_pcg = +// apply_substitution(pcg, substitution, matches[0]); +// +// RC_ASSERT(get_nodes(new_pcg).size() == 4); +// RC_ASSERT(get_edges(new_pcg).size() == 3); +// } +// } diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index b3d0db8822..0ad0e6b756 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -298,7 +298,7 @@ std::unordered_map generate_map(C const &c, F const &f) { static_assert(is_hashable::value, "Key type should be hashable (but is not)"); - auto transformed = transform(c, [&](K const &k) -> std::pair { + auto transformed = transform(as_vector(c), [&](K const &k) -> std::pair { return {k, f(k)}; }); return {transformed.cbegin(), transformed.cend()}; diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index e04f1a92e1..ccb4a0e0a5 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -15,7 +15,7 @@ struct DataflowGraph : virtual DataflowGraphView { void add_node_unsafe(Node const &node, std::vector const &inputs, std::vector const &outputs); - + std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(DataflowEdgeQuery const &) const; std::unordered_set query_outputs(DataflowOutputQuery const &) const; diff --git a/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h index 5f552f6d66..90cce9bf91 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_dataflow_graph.h @@ -2,20 +2,27 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_UNORDERED_SET_DATAFLOW_GRAPH_H #include "utils/graph/dataflow_graph/i_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" +#include "utils/graph/open_dataflow_graph/i_open_dataflow_graph.h" #include "utils/graph/node/node_source.h" namespace FlexFlow { -struct UnorderedSetDataflowGraph final : public IDataflowGraph { +struct UnorderedSetDataflowGraph final : virtual public IDataflowGraph, + virtual public IOpenDataflowGraph { public: UnorderedSetDataflowGraph(); NodeAddedResult add_node(std::vector const &inputs, int num_outputs) override; + NodeAddedResult add_node(std::vector const &inputs, + int num_outputs) override; + DataflowGraphInput add_input() override; std::unordered_set query_nodes(NodeQuery const &) const override; - std::unordered_set query_edges(DataflowEdgeQuery const &) const override; + std::unordered_set query_edges(OpenDataflowEdgeQuery const &) const override; std::unordered_set query_outputs(DataflowOutputQuery const &) const override; + std::unordered_set get_inputs() const override; void add_node_unsafe(Node const &node, std::vector const &inputs, @@ -25,16 +32,24 @@ struct UnorderedSetDataflowGraph final : public IDataflowGraph { UnorderedSetDataflowGraph *clone() const override; private: + void add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs); + UnorderedSetDataflowGraph(NodeSource const &node_source, + DataflowGraphInputSource const &graph_input_source, std::unordered_set const &nodes, - std::unordered_set const &edges, - std::unordered_set const &outputs); + std::unordered_set const &edges, + std::unordered_set const &outputs, + std::unordered_set const &graph_inputs); private: NodeSource node_source; + DataflowGraphInputSource graph_input_source; std::unordered_set nodes; - std::unordered_set edges; + std::unordered_set edges; std::unordered_set outputs; + std::unordered_set graph_inputs; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(UnorderedSetDataflowGraph); diff --git a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc index 589d778d4f..c24bb4f9a7 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_dataflow_graph.cc @@ -3,19 +3,34 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" #include "utils/containers.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" namespace FlexFlow { UnorderedSetDataflowGraph::UnorderedSetDataflowGraph() {} UnorderedSetDataflowGraph::UnorderedSetDataflowGraph(NodeSource const &node_source, + DataflowGraphInputSource const &graph_input_source, std::unordered_set const &nodes, - std::unordered_set const &edges, - std::unordered_set const &outputs) - : node_source(node_source), nodes(nodes), edges(edges), outputs(outputs) + std::unordered_set const &edges, + std::unordered_set const &outputs, + std::unordered_set const &graph_inputs) + : node_source(node_source), + graph_input_source(graph_input_source), + nodes(nodes), + edges(edges), + outputs(outputs), + graph_inputs(graph_inputs) {} NodeAddedResult UnorderedSetDataflowGraph::add_node(std::vector const &inputs, int num_outputs) { + std::vector open_inputs = transform(inputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }); + return this->add_node(open_inputs, num_outputs); +} + +NodeAddedResult UnorderedSetDataflowGraph::add_node(std::vector const &inputs, + int num_outputs) { Node new_node = this->node_source.new_node(); std::vector new_outputs = transform(count(num_outputs), @@ -26,16 +41,21 @@ NodeAddedResult UnorderedSetDataflowGraph::add_node(std::vector return NodeAddedResult{new_node, new_outputs}; } +DataflowGraphInput UnorderedSetDataflowGraph::add_input() { + DataflowGraphInput new_graph_input = this->graph_input_source.new_dataflow_graph_input(); + + this->graph_inputs.insert(new_graph_input); + + return new_graph_input; +} + std::unordered_set UnorderedSetDataflowGraph::query_nodes(NodeQuery const &q) const { return apply_query(q.nodes, this->nodes); } -std::unordered_set UnorderedSetDataflowGraph::query_edges(DataflowEdgeQuery const &q) const { - return filter(this->edges, [&](DataflowEdge const &e) { - return includes(q.src_nodes, e.src.node) - && includes(q.dst_nodes, e.dst.node) - && includes(q.src_idxs, e.src.idx) - && includes(q.dst_idxs, e.dst.idx); +std::unordered_set UnorderedSetDataflowGraph::query_edges(OpenDataflowEdgeQuery const &q) const { + return filter(this->edges, [&](OpenDataflowEdge const &e) { + return open_dataflow_edge_query_includes(q, e); }); } @@ -46,16 +66,28 @@ std::unordered_set UnorderedSetDataflowGraph::query_outputs(Data }); } +std::unordered_set UnorderedSetDataflowGraph::get_inputs() const { + return this->graph_inputs; +} + void UnorderedSetDataflowGraph::add_node_unsafe(Node const &node, std::vector const &inputs, std::vector const &outputs) { + std::vector open_inputs = transform(inputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }); + this->add_node_unsafe(node, open_inputs, outputs); +} + + +void UnorderedSetDataflowGraph::add_node_unsafe(Node const &node, + std::vector const &inputs, + std::vector const &outputs) { assert (!contains(this->nodes, node)); assert (are_disjoint(this->outputs, without_order(outputs))); this->nodes.insert(node); for (auto const &[input_idx, input_src] : enumerate_vector(inputs)) { - this->edges.insert(DataflowEdge{input_src, DataflowInput{node, input_idx}}); + this->edges.insert(open_dataflow_edge_from_src_and_dst(input_src, DataflowInput{node, input_idx})); } extend(this->outputs, outputs); @@ -67,16 +99,18 @@ void UnorderedSetDataflowGraph::inplace_materialize_from(DataflowGraphView const std::unordered_set outputs = get_all_dataflow_outputs(view); this->nodes = nodes; - this->edges = edges; + this->edges = transform(edges, [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); this->outputs = outputs; } UnorderedSetDataflowGraph *UnorderedSetDataflowGraph::clone() const { return new UnorderedSetDataflowGraph{ this->node_source, + this->graph_input_source, this->nodes, this->edges, this->outputs, + this->graph_inputs, }; } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc index 3c91bebbf1..a4113195e0 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc @@ -1,4 +1,5 @@ #include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/containers/group_by.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -35,8 +36,8 @@ std::vector get_incoming_edges(OpenDataflowGraphView const &g, }), [](OpenDataflowEdge const &l, OpenDataflowEdge const &r) { return get_open_dataflow_edge_dst_idx(l) < get_open_dataflow_edge_dst_idx(r); }); } -std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &, std::unordered_set const &) { - NOT_IMPLEMENTED(); +std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &g, std::unordered_set const &ns) { + return generate_map(ns, [&](Node const &n) { return get_incoming_edges(g, n); }); } std::unordered_set get_open_dataflow_values(OpenDataflowGraphView const &g) { From 3228f2d5d4c9b2d70399e9cb23c65ae539cb3133 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 13 Jul 2024 23:09:58 -0700 Subject: [PATCH 18/71] Fix bug in filter, pass some initial substitution tests --- .../unlabelled/pattern_matching.h | 9 +- .../unlabelled/pattern_matching.cc | 4 +- .../test/src/test_pattern_matches.cc | 182 ++++++++++++++---- lib/utils/include/utils/bidict.h | 4 + lib/utils/include/utils/containers.decl.h | 9 - lib/utils/include/utils/containers.h | 24 +-- lib/utils/include/utils/containers/filter.h | 49 +++++ .../include/utils/containers/inplace_filter.h | 38 ++++ .../dataflow_graph/dataflow_graph_view.h | 2 +- .../open_dataflow_graph_view.h | 3 +- lib/utils/src/utils/containers/filter.cc | 1 + .../src/utils/containers/inplace_filter.cc | 1 + .../algorithms/get_subgraph_inputs.cc | 12 +- lib/utils/test/CMakeLists.txt | 2 + lib/utils/test/src/utils/containers/filter.cc | 89 +++++++++ .../src/utils/containers/inplace_filter.cc | 90 +++++++++ 16 files changed, 438 insertions(+), 81 deletions(-) create mode 100644 lib/utils/include/utils/containers/filter.h create mode 100644 lib/utils/include/utils/containers/inplace_filter.h create mode 100644 lib/utils/src/utils/containers/filter.cc create mode 100644 lib/utils/src/utils/containers/inplace_filter.cc create mode 100644 lib/utils/test/src/utils/containers/filter.cc create mode 100644 lib/utils/test/src/utils/containers/inplace_filter.cc diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index 381a4933e3..e91bfbb905 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -5,6 +5,7 @@ #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" #include "utils/graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_subgraph_result.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" namespace FlexFlow { @@ -12,8 +13,12 @@ namespace FlexFlow { // OpenDataflowGraphView apply_match(UnlabelledGraphPattern const &pattern, // UnlabelledDataflowGraphPatternMatch const &match); -OpenDataflowGraphView subgraph_matched(UnlabelledGraphPattern const &pattern, - UnlabelledDataflowGraphPatternMatch const &match); +OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &graph, + UnlabelledDataflowGraphPatternMatch const &match); +bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern, + OpenDataflowGraphView const &subgraph, + UnlabelledDataflowGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion); bool unlabelled_pattern_does_match( UnlabelledGraphPattern const &pattern, diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 9081f90f08..bf2b37e0f6 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -115,7 +115,7 @@ struct ConcreteFromPattern { } }; -static bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern, +bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &subgraph, UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { @@ -163,7 +163,7 @@ bool unlabelled_pattern_does_match( UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { - OpenDataflowGraphView matched_subgraph = subgraph_matched(matched_subgraph, match).graph; + OpenDataflowGraphView matched_subgraph = subgraph_matched(graph, match).graph; assert (keys(match.node_assignment) == get_nodes(pattern)); assert (keys(match.node_assignment.reversed()) == get_nodes(matched_subgraph)); diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index d558ffeefa..33903ce0a0 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -3,9 +3,15 @@ #include "substitutions/unlabelled/find_pattern_matches.h" #include "test/utils/all.h" #include "substitutions/unlabelled/match_additional_criterion.h" +#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "substitutions/unlabelled/pattern_matching.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" +#include "utils/overload.h" +#include "utils/containers.h" using namespace FlexFlow; @@ -68,52 +74,144 @@ namespace rc { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find_pattern_matches_small") { - UnlabelledGraphPattern pattern = [] { - OpenDataflowGraph g = OpenDataflowGraph::create(); - - NodeAddedResult n0_added = g.add_node({}, 1); - Node n0 = n0_added.node; - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; - - NodeAddedResult n1_added = g.add_node({v0}, 1); - Node n1 = n1_added.node; - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; - - return UnlabelledGraphPattern{g}; - }(); - - OpenDataflowGraph graph = [] { - OpenDataflowGraph g = OpenDataflowGraph::create(); - - NodeAddedResult n0_added = g.add_node({}, 1); - Node n0 = n0_added.node; - OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; - - NodeAddedResult n1_added = g.add_node({v0}, 1); - Node n1 = n1_added.node; - OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; - - NodeAddedResult n2_added = g.add_node({v1}, 1); - Node n2 = n2_added.node; - OpenDataflowValue v2 = OpenDataflowValue{get_only(n2_added.outputs)}; - - NodeAddedResult n3_added = g.add_node({v2}, 1); - Node n3 = n3_added.node; - OpenDataflowValue v3 = OpenDataflowValue{get_only(n3_added.outputs)}; + OpenDataflowGraph pattern_graph = OpenDataflowGraph::create(); + + NodeAddedResult pattern_n0_added = pattern_graph.add_node({}, 1); + Node pattern_n0 = pattern_n0_added.node; + OpenDataflowValue pattern_v0 = OpenDataflowValue{get_only(pattern_n0_added.outputs)}; + + NodeAddedResult pattern_n1_added = pattern_graph.add_node({pattern_v0}, 1); + Node pattern_n1 = pattern_n1_added.node; + OpenDataflowValue pattern_v1 = OpenDataflowValue{get_only(pattern_n1_added.outputs)}; + + UnlabelledGraphPattern pattern = UnlabelledGraphPattern{pattern_graph}; + PatternNode p0 = PatternNode{pattern_n0}; + PatternNode p1 = PatternNode{pattern_n1}; + + OpenDataflowGraph graph = OpenDataflowGraph::create(); + + NodeAddedResult n0_added = graph.add_node({}, 1); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + // CHECK(v0 == OpenDataflowValue{DataflowOutput{n0, 0}}); + + NodeAddedResult n1_added = graph.add_node({v0}, 1); + Node n1 = n1_added.node; + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + + NodeAddedResult n2_added = graph.add_node({v1}, 1); + Node n2 = n2_added.node; + OpenDataflowValue v2 = OpenDataflowValue{get_only(n2_added.outputs)}; + + NodeAddedResult n3_added = graph.add_node({v2}, 1); + Node n3 = n3_added.node; + OpenDataflowValue v3 = OpenDataflowValue{get_only(n3_added.outputs)}; + + UnlabelledDataflowGraphPatternMatch match = UnlabelledDataflowGraphPatternMatch{ + bidict{ + {p0, n0}, + {p1, n1}, + }, + bidict{} + }; + + std::vector n1_incoming = {OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{n0, 0}, + DataflowInput{n1, 0}, + }, + }}; + + SUBCASE("get_incoming_edges") { + SUBCASE("n0") { + std::vector result = get_incoming_edges(graph, n0); + std::vector correct = {}; + CHECK(result == correct); + } + SUBCASE("n1") { + std::vector result = get_incoming_edges(graph, n1); + std::vector correct = n1_incoming; + CHECK(result == correct); + } + SUBCASE("both") { + std::unordered_map> result = get_incoming_edges(graph, {n0, n1}); + std::unordered_map> correct = { + { + n0, + {} + }, + { + n1, + n1_incoming + } + }; + CHECK(result == correct); + } + } - return g; - }(); + // { + // std::unordered_set xs = {n0, n1}; + // REQUIRE(contains(xs, n0)); + // REQUIRE(contains(xs, n1)); + // std::vector es = {OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0}, DataflowInput{n1, 0}}}}; + // auto myfilter = [&](OpenDataflowEdge const &e) { + // return e.visit(overload { + // [](DataflowInputEdge const &) { return true; }, + // [&](DataflowEdge const &ee) { return !contains(xs, ee.src.node); }, + // }); + // }; + // std::vector result = filter(es, myfilter); + // auto myrealfilter = [&](OpenDataflowEdge const &e) { return true; }; + // // REQUIRE(myfilter(es.at(0)) == false); + // REQUIRE(result.size() == 0); + // } + + + SUBCASE("get_subgraph_inputs") { + std::unordered_set result = get_subgraph_inputs(graph, {n0, n1}); + std::unordered_set correct = {}; + CHECK(result == correct); + } - std::vector matches = find_pattern_matches( - pattern, graph, match_additional_crition_always_true()); + SUBCASE("get_subgraph") { + OpenDataflowGraphView g = get_subgraph(graph, {n0, n1}).graph; + SUBCASE("nodes") { + std::unordered_set result = get_nodes(g); + std::unordered_set correct = {n0, n1}; + CHECK(result == correct); + } + SUBCASE("inputs") { + std::unordered_set result = g.get_inputs(); + std::unordered_set correct = {}; + CHECK(result == correct); + } + SUBCASE("get_open_dataflow_values") { + std::unordered_set values = get_open_dataflow_values(g); + CHECK(values.size() == 2); + } + } - CHECK(matches.size() == 3); + SUBCASE("subgraph_matched") { + OpenDataflowGraphView result = subgraph_matched(graph, match).graph; + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = {n0, n1}; + CHECK(result_nodes == correct_nodes); + } - for (UnlabelledDataflowGraphPatternMatch const &match : matches) { - CHECK(unlabelled_pattern_does_match(pattern, - graph, - match, - match_additional_crition_always_true())); + SUBCASE("unlabelled_pattern_does_match") { + CHECK(unlabelled_pattern_does_match(pattern, graph, match, match_additional_crition_always_true())); } + + // std::vector matches = find_pattern_matches( + // pattern, graph, match_additional_crition_always_true()); + + // CHECK(matches.size() == 3); + // + // for (UnlabelledDataflowGraphPatternMatch const &match : matches) { + // CHECK(unlabelled_pattern_does_match(pattern, + // graph, + // match, + // match_additional_crition_always_true())); + // } } } diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index d568600465..eed5ee2ae3 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -11,6 +11,10 @@ template struct bidict { bidict() : fwd_map{}, bwd_map{} {} + bidict(std::initializer_list> init) + : bidict(init.begin(), init.end()) + { } + template bidict(InputIt first, InputIt last) { for (auto it = first; it != last; it++) { diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 46979f4945..51fb160453 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -262,15 +262,6 @@ std::vector sorted_by(C const &c, F const &f); template std::function compare_by(F const &f); -template -C filter(C const &v, F const &f); - -template -std::unordered_set filter(std::unordered_set const &v, F const &f); - -template -void inplace_filter(C &v, F const &f); - template std::pair, std::vector> vector_split(std::vector const &v, std::size_t idx); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 0ad0e6b756..73a3f45f26 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -21,6 +21,7 @@ #include #include #include "utils/hash/pair.h" +#include "utils/containers/filter.h" namespace FlexFlow { @@ -615,29 +616,6 @@ std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; } -template -C filter(C const &v, F const &f) { - C result(v); - inplace_filter(result, f); - return result; -} - -template -std::unordered_set filter(std::unordered_set const &v, F const &f) { - std::unordered_set result; - for (T const &t : v) { - if (f(t)) { - result.insert(t); - } - } - return result; -} - -template -void inplace_filter(C &v, F const &f) { - std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); }); -} - template std::pair, std::vector> vector_split(std::vector const &v, std::size_t idx) { diff --git a/lib/utils/include/utils/containers/filter.h b/lib/utils/include/utils/containers/filter.h new file mode 100644 index 0000000000..1a7e8ecf35 --- /dev/null +++ b/lib/utils/include/utils/containers/filter.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTER_H + +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +std::vector filter(std::vector const &v, F const &f) { + std::vector result; + std::copy_if(v.cbegin(), v.cend(), std::back_inserter(result), f); + return result; +} + +template +std::unordered_set filter(std::unordered_set const &s, F const &f) { + std::unordered_set result; + std::copy_if(s.cbegin(), s.cend(), std::inserter(result, result.begin()), f); + return result; +} + +template +std::unordered_map filter(std::unordered_map const &m, F const &f) { + std::unordered_map result; + std::copy_if(m.cbegin(), m.cend(), std::inserter(result, result.begin()), f); + return result; +} + +template +std::set filter(std::set const &s, F const &f) { + std::set result; + std::copy_if(s.cbegin(), s.cend(), std::inserter(result, result.begin()), f); + return result; +} + +template +std::map filter(std::map const &m, F const &f) { + std::map result; + std::copy_if(m.cbegin(), m.cend(), std::inserter(result, result.begin()), f); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/inplace_filter.h b/lib/utils/include/utils/containers/inplace_filter.h new file mode 100644 index 0000000000..949323c452 --- /dev/null +++ b/lib/utils/include/utils/containers/inplace_filter.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INPLACE_FILTER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INPLACE_FILTER_H + +#include +#include +#include +#include "utils/containers/filter.h" + +namespace FlexFlow { + +template +void inplace_filter(std::vector &v, F const &f) { + v.erase(std::remove_if(v.begin(), v.end(), [&](Elem const &e) { return !f(e); }), v.end()); +} + +template +void inplace_filter(std::unordered_set &s, F const &f) { + s = filter(s, f); +} + +template +void inplace_filter(std::set &s, F const &f) { + s = filter(s, f); +} + +template +void inplace_filter(std::unordered_map &s, F const &f) { + s = filter(s, f); +} + +template +void inplace_filter(std::map &s, F const &f) { + s = filter(s, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h index 54708cff31..76a1f3dcf5 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph_view.h @@ -8,7 +8,7 @@ namespace FlexFlow { -struct DataflowGraphView : virtual DiGraphView { +struct DataflowGraphView : virtual public DiGraphView { DataflowGraphView(DataflowGraphView const &) = default; DataflowGraphView &operator=(DataflowGraphView const &) = default; diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h index cfd6d7f5dd..ba85f31348 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_graph_view.h @@ -6,7 +6,7 @@ namespace FlexFlow { -struct OpenDataflowGraphView : virtual DataflowGraphView { +struct OpenDataflowGraphView : virtual public DataflowGraphView { public: OpenDataflowGraphView(OpenDataflowGraphView const &) = default; OpenDataflowGraphView &operator=(OpenDataflowGraphView const &) = default; @@ -27,6 +27,7 @@ struct OpenDataflowGraphView : virtual DataflowGraphView { private: IOpenDataflowGraphView const &get_interface() const; }; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OpenDataflowGraphView); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/filter.cc b/lib/utils/src/utils/containers/filter.cc new file mode 100644 index 0000000000..dc11d0dffa --- /dev/null +++ b/lib/utils/src/utils/containers/filter.cc @@ -0,0 +1 @@ +#include "utils/containers/filter.h" diff --git a/lib/utils/src/utils/containers/inplace_filter.cc b/lib/utils/src/utils/containers/inplace_filter.cc new file mode 100644 index 0000000000..6123b2f56a --- /dev/null +++ b/lib/utils/src/utils/containers/inplace_filter.cc @@ -0,0 +1 @@ +#include "utils/containers/inplace_filter.h" diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index c0d537925a..60ddcfb4e8 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -9,7 +9,17 @@ std::unordered_set get_subgraph_inputs(OpenDataflowGraphView std::unordered_set const &subgraph_nodes) { std::unordered_set relevant_edges; for (std::vector const &incoming : values(get_incoming_edges(g, subgraph_nodes))) { - extend(relevant_edges, incoming); + auto comes_from_outside_subgraph = [&](OpenDataflowEdge const &e) -> bool { + return e.visit(overload { + [](DataflowInputEdge const &) { return true; }, + [&](DataflowEdge const &ee) { + assert (contains(subgraph_nodes, ee.dst.node)); + return !contains(subgraph_nodes, ee.src.node); + }, + }); + }; + + extend(relevant_edges, filter(incoming, comes_from_outside_subgraph)); } return transform(relevant_edges, get_open_dataflow_edge_source); diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index 7df8c3395d..e1e7a477b0 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -5,6 +5,8 @@ ff_add_test_executable( src/test_cow_ptr.cc src/utils/graph/dataflow_graph/unordered_set_dataflow_graph.cc src/test_optional.cc + src/utils/containers/filter.cc + src/utils/containers/inplace_filter.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc new file mode 100644 index 0000000000..4dc70af3fe --- /dev/null +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -0,0 +1,89 @@ +#include "test/utils/all.h" +#include "utils/containers/filter.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("filter(T, F)", T, std::vector, + std::unordered_set, + std::set, + std::unordered_map, + std::map) { + RC_SUBCASE("filter returns empty for predicate always_false", [](T const &t) { + auto always_false = [](auto const &) { return false; }; + T result = filter(t, always_false); + return result.size() == 0; + }); + + RC_SUBCASE("filter returns input for predicate always_true", [](T const &t) { + auto always_true = [](auto const &) { return true; }; + T result = filter(t, always_true); + return result == t; + }); + } + + TEST_CASE("filter(std::vector, F)") { + std::vector input = {1, 2, 3, 4, 5}; + auto predicate = [](int x) { return x % 2 == 0; }; + + std::vector result = filter(input, predicate); + std::vector correct = { 2, 4 }; + CHECK(result == correct); + } + + TEST_CASE("filter(std::unordered_set, F)") { + std::unordered_set input = {1, 2, 3, 4, 5, 6, 7, 8}; + auto predicate = [](int x) { return x % 2 == 0; }; + + std::unordered_set result = filter(input, predicate); + std::unordered_set correct = { 2, 4, 6, 8 }; + CHECK(result == correct); + } + + TEST_CASE("filter(std::set, F)") { + std::set input = { 3, 2, 5, 8}; + auto predicate = [](int x) { return x % 2 == 0; }; + + std::set result = filter(input, predicate); + std::set correct = { 2, 8 }; + CHECK(result == correct); + } + + TEST_CASE("filter(std::unordered_map, F)") { + std::unordered_map input = { + {3, "4"}, + {1, "1"}, + {2, "9"}, + {4, "4"}, + }; + auto predicate = [](std::pair const &x) { + return std::to_string(x.first) == x.second; + }; + + std::unordered_map result = filter(input, predicate); + std::unordered_map correct = { + {1, "1"}, + {4, "4"}, + }; + CHECK(result == correct); + } + + TEST_CASE("filter(std::map, F)") { + std::map input = { + {3, "4"}, + {1, "1"}, + {2, "9"}, + {4, "4"}, + }; + auto predicate = [](std::pair const &x) { + return std::to_string(x.first) != x.second; + }; + + std::map result = filter(input, predicate); + std::map correct = { + {3, "4"}, + {2, "9"}, + }; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/inplace_filter.cc b/lib/utils/test/src/utils/containers/inplace_filter.cc new file mode 100644 index 0000000000..d66af05e18 --- /dev/null +++ b/lib/utils/test/src/utils/containers/inplace_filter.cc @@ -0,0 +1,90 @@ +#include "test/utils/all.h" +#include "utils/containers/inplace_filter.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("inplace_filter(T, F)", T, std::vector, + std::unordered_set, + std::set, + std::unordered_map, + std::map) { + RC_SUBCASE("inplace_filter returns empty for predicate always_false", [](T t) { + auto always_false = [](auto const &) { return false; }; + inplace_filter(t, always_false); + return t.size() == 0; + }); + + RC_SUBCASE("inplace_filter returns input for predicate always_true", [](T t) { + T input = t; + auto always_true = [](auto const &) { return true; }; + inplace_filter(t, always_true); + return t == input; + }); + } + + TEST_CASE("inplace_filter(std::vector &, F)") { + std::vector input = {1, 2, 3, 4, 5}; + auto predicate = [](int x) { return x % 2 == 0; }; + + inplace_filter(input, predicate); + std::vector correct = { 2, 4 }; + CHECK(input == correct); + } + + TEST_CASE("inplace_filter(std::unordered_set &, F)") { + std::unordered_set input = {1, 2, 3, 4, 5, 6, 7, 8}; + auto predicate = [](int x) { return x % 2 == 0; }; + + inplace_filter(input, predicate); + std::unordered_set correct = { 2, 4, 6, 8 }; + CHECK(input == correct); + } + + TEST_CASE("inplace_filter(std::set &, F)") { + std::set input = { 3, 2, 5, 8}; + auto predicate = [](int x) { return x % 2 == 0; }; + + inplace_filter(input, predicate); + std::set correct = { 2, 8 }; + CHECK(input == correct); + } + + TEST_CASE("inplace_filter(std::unordered_map &, F)") { + std::unordered_map input = { + {3, "4"}, + {1, "1"}, + {2, "9"}, + {4, "4"}, + }; + auto predicate = [](std::pair const &x) { + return std::to_string(x.first) == x.second; + }; + + inplace_filter(input, predicate); + std::unordered_map correct = { + {1, "1"}, + {4, "4"}, + }; + CHECK(input == correct); + } + + TEST_CASE("inplace_filter(std::map &, F)") { + std::map input = { + {3, "4"}, + {1, "1"}, + {2, "9"}, + {4, "4"}, + }; + auto predicate = [](std::pair const &x) { + return std::to_string(x.first) != x.second; + }; + + inplace_filter(input, predicate); + std::map correct = { + {3, "4"}, + {2, "9"}, + }; + CHECK(input == correct); + } +} From 5f4cc01c7dd8575c13ccc1ec03f237630e1e976f Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 14 Jul 2024 19:00:40 -0700 Subject: [PATCH 19/71] Add tests for fmt::to_string, fix some substitutions bugs --- .../unlabelled/pattern_matching.h | 1 + .../unlabelled/find_pattern_matches.cc | 3 +- .../unlabelled/pattern_matching.cc | 18 ++++-- .../substitutions/unlabelled/pattern_split.cc | 61 +++++++++++++++++++ .../unlabelled/unlabelled_graph_pattern.cc | 36 +++++++++++ .../test/src/test_pattern_matches.cc | 60 +++++++++++------- lib/utils/include/utils/containers.decl.h | 10 +-- lib/utils/include/utils/containers.h | 24 -------- lib/utils/include/utils/containers/sorted.h | 53 ++++++++++++++++ lib/utils/include/utils/fmt/map.h | 21 ++++--- lib/utils/include/utils/fmt/optional.h | 26 ++++++++ lib/utils/include/utils/fmt/pair.h | 24 +++++++- lib/utils/include/utils/fmt/set.h | 5 +- lib/utils/include/utils/fmt/unordered_map.h | 23 ++++--- lib/utils/include/utils/fmt/unordered_set.h | 4 +- lib/utils/src/utils/containers/sorted.cc | 1 + lib/utils/test/CMakeLists.txt | 1 + lib/utils/test/src/utils/fmt/expected.cc | 22 +++++++ lib/utils/test/src/utils/fmt/map.cc | 13 ++++ lib/utils/test/src/utils/fmt/optional.cc | 22 +++++++ lib/utils/test/src/utils/fmt/pair.cc | 13 ++++ lib/utils/test/src/utils/fmt/set.cc | 13 ++++ lib/utils/test/src/utils/fmt/unordered_map.cc | 13 ++++ lib/utils/test/src/utils/fmt/unordered_set.cc | 13 ++++ lib/utils/test/src/utils/fmt/variant.cc | 22 +++++++ lib/utils/test/src/utils/fmt/vector.cc | 13 ++++ 26 files changed, 432 insertions(+), 83 deletions(-) create mode 100644 lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc create mode 100644 lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc create mode 100644 lib/utils/include/utils/containers/sorted.h create mode 100644 lib/utils/src/utils/containers/sorted.cc create mode 100644 lib/utils/test/src/utils/fmt/expected.cc create mode 100644 lib/utils/test/src/utils/fmt/map.cc create mode 100644 lib/utils/test/src/utils/fmt/optional.cc create mode 100644 lib/utils/test/src/utils/fmt/pair.cc create mode 100644 lib/utils/test/src/utils/fmt/set.cc create mode 100644 lib/utils/test/src/utils/fmt/unordered_map.cc create mode 100644 lib/utils/test/src/utils/fmt/unordered_set.cc create mode 100644 lib/utils/test/src/utils/fmt/variant.cc create mode 100644 lib/utils/test/src/utils/fmt/vector.cc diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index e91bfbb905..0fdeaa455e 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -17,6 +17,7 @@ OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &graph, UnlabelledDataflowGraphPatternMatch const &match); bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &subgraph, + bidict const &full_graph_values_to_subgraph_inputs, UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion); diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 009e955350..17a3a7da10 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -47,7 +47,6 @@ static std::optional for (auto const &[pattern_node_input, graph_node_input] : zip(pattern_node_inputs, graph_node_inputs)) { assert (pattern_node_input.has()); - assert (graph_node_input.has()); match.input_assignment.insert({ pattern_node_input.get(), @@ -56,7 +55,7 @@ static std::optional } assert (unlabelled_pattern_does_match(pattern, graph, match, match_additional_crition_always_true())); - + return match; } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index bf2b37e0f6..fa86f33c3d 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -62,18 +62,22 @@ OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &g, // } struct ConcreteFromPattern { - ConcreteFromPattern(UnlabelledDataflowGraphPatternMatch const &match) - : match(match) + ConcreteFromPattern(UnlabelledDataflowGraphPatternMatch const &match, + bidict const &full_graph_values_to_subgraph_inputs) + : match(match), full_graph_values_to_subgraph_inputs(full_graph_values_to_subgraph_inputs) { } UnlabelledDataflowGraphPatternMatch const &match; + bidict const &full_graph_values_to_subgraph_inputs; Node operator()(PatternNode const &n) const { return match.node_assignment.at_l(n); } OpenDataflowValue operator()(PatternInput const &i) const { - return match.input_assignment.at(i); + return OpenDataflowValue{ + full_graph_values_to_subgraph_inputs.at_l(match.input_assignment.at(i)) + }; } OpenDataflowEdge operator()(InputPatternEdge const &e) const { @@ -117,9 +121,10 @@ struct ConcreteFromPattern { bool pattern_matches_subgraph_under(UnlabelledGraphPattern const &pattern, OpenDataflowGraphView const &subgraph, + bidict const &full_graph_values_to_subgraph_inputs, UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { - ConcreteFromPattern concrete_from_pattern{match}; + ConcreteFromPattern concrete_from_pattern{match, full_graph_values_to_subgraph_inputs}; std::unordered_set concrete_nodes = get_nodes(subgraph); std::unordered_set concrete_nodes_from_match = transform(get_nodes(pattern), concrete_from_pattern); @@ -163,12 +168,13 @@ bool unlabelled_pattern_does_match( UnlabelledDataflowGraphPatternMatch const &match, MatchAdditionalCriterion const &additional_criterion) { - OpenDataflowGraphView matched_subgraph = subgraph_matched(graph, match).graph; + OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match); + OpenDataflowGraphView matched_subgraph = subgraph_result.graph; assert (keys(match.node_assignment) == get_nodes(pattern)); assert (keys(match.node_assignment.reversed()) == get_nodes(matched_subgraph)); - return pattern_matches_subgraph_under(pattern, matched_subgraph, match, additional_criterion); + return pattern_matches_subgraph_under(pattern, matched_subgraph, subgraph_result.full_graph_values_to_subgraph_inputs, match, additional_criterion); } } // namespace FlexFlow diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc new file mode 100644 index 0000000000..0853be57d3 --- /dev/null +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -0,0 +1,61 @@ +#include "substitutions/unlabelled/pattern_value.h" +#include "test/utils/doctest.h" +#include "substitutions/unlabelled/pattern_split.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("pattern_split") { + OpenDataflowGraph g = OpenDataflowGraph::create(); + + NodeAddedResult n0_added = g.add_node({}, 1); + Node n0 = n0_added.node; + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + NodeAddedResult n1_added = g.add_node({v0}, 1); + Node n1 = n1_added.node; + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + + UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; + PatternNode p0 = PatternNode{n0}; + PatternNode p1 = PatternNode{n1}; + PatternValue pv0 = pattern_value_from_raw_open_dataflow_value(v0); + PatternValue pv1 = pattern_value_from_raw_open_dataflow_value(v1); + + PatternSplit even_split = PatternSplit{ + std::unordered_set{p0}, + std::unordered_set{p1}, + }; + + SUBCASE("find_even_split") { + PatternSplit result = find_even_split(pattern); + PatternSplit correct = even_split; + CHECK(result == correct); + } + + SUBCASE("apply_split") { + PatternSplitResult split_result = apply_split(pattern, even_split); + SUBCASE("subpattern_1") { + std::unordered_set result = get_nodes(split_result.subpattern_1); + std::unordered_set correct = even_split.first; + CHECK(result == correct); + } + SUBCASE("subpattern_2") { + std::unordered_set result = get_nodes(split_result.subpattern_2); + std::unordered_set correct = even_split.second; + CHECK(result == correct); + } + SUBCASE("subpattern_1_outputs_to_subpattern_2_inputs") { + bidict result = split_result.subpattern_1_outputs_to_subpattern_2_inputs; + PatternInput i0 = get_only(get_inputs(split_result.subpattern_2)); + bidict correct = { + {pv0, i0}, + }; + CHECK(result == correct); + } + } + } +} diff --git a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc new file mode 100644 index 0000000000..48207a769d --- /dev/null +++ b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -0,0 +1,36 @@ +#include "test/utils/doctest.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_singleton_pattern") { + OpenDataflowGraph g = OpenDataflowGraph::create(); + + SUBCASE("0 nodes") { + UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; + + CHECK_FALSE(is_singleton_pattern(pattern)); + } + + NodeAddedResult n0_added = g.add_node({}, 1); + OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; + + SUBCASE("1 node") { + UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; + + CHECK(is_singleton_pattern(pattern)); + } + + NodeAddedResult n1_added = g.add_node({v0}, 1); + OpenDataflowValue v1 = OpenDataflowValue{get_only(n1_added.outputs)}; + + SUBCASE("more than 1 node") { + UnlabelledGraphPattern pattern = UnlabelledGraphPattern{g}; + + CHECK_FALSE(is_singleton_pattern(pattern)); + } + } +} diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 33903ce0a0..73e49df616 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -115,6 +115,14 @@ TEST_SUITE(FF_TEST_SUITE) { bidict{} }; + UnlabelledDataflowGraphPatternMatch invalid_match = UnlabelledDataflowGraphPatternMatch{ + bidict{ + {p0, n1}, + {p1, n2}, + }, + bidict{} + }; + std::vector n1_incoming = {OpenDataflowEdge{ DataflowEdge{ DataflowOutput{n0, 0}, @@ -149,24 +157,6 @@ TEST_SUITE(FF_TEST_SUITE) { } } - // { - // std::unordered_set xs = {n0, n1}; - // REQUIRE(contains(xs, n0)); - // REQUIRE(contains(xs, n1)); - // std::vector es = {OpenDataflowEdge{DataflowEdge{DataflowOutput{n0, 0}, DataflowInput{n1, 0}}}}; - // auto myfilter = [&](OpenDataflowEdge const &e) { - // return e.visit(overload { - // [](DataflowInputEdge const &) { return true; }, - // [&](DataflowEdge const &ee) { return !contains(xs, ee.src.node); }, - // }); - // }; - // std::vector result = filter(es, myfilter); - // auto myrealfilter = [&](OpenDataflowEdge const &e) { return true; }; - // // REQUIRE(myfilter(es.at(0)) == false); - // REQUIRE(result.size() == 0); - // } - - SUBCASE("get_subgraph_inputs") { std::unordered_set result = get_subgraph_inputs(graph, {n0, n1}); std::unordered_set correct = {}; @@ -200,12 +190,40 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("unlabelled_pattern_does_match") { CHECK(unlabelled_pattern_does_match(pattern, graph, match, match_additional_crition_always_true())); + CHECK_FALSE(unlabelled_pattern_does_match(pattern, graph, invalid_match, match_additional_crition_always_true())); + } + + SUBCASE("unlabelled_pattern_does_match (open)") { + OpenDataflowGraph g = OpenDataflowGraph::create(); + DataflowGraphInput i0 = g.add_input(); + + NodeAddedResult g_n0_added = g.add_node({OpenDataflowValue{i0}}, 1); + Node g_n0 = g_n0_added.node; + OpenDataflowValue g_v0 = OpenDataflowValue{get_only(g_n0_added.outputs)}; + PatternNode g_p0 = PatternNode{g_n0}; + PatternInput g_pi0 = PatternInput{i0}; + + UnlabelledGraphPattern open_pattern = UnlabelledGraphPattern{g}; + + UnlabelledDataflowGraphPatternMatch open_match = UnlabelledDataflowGraphPatternMatch{ + bidict{ + {g_p0, n1}, + }, + bidict{ + {g_pi0, v0}, + } + }; + CHECK(unlabelled_pattern_does_match(open_pattern, graph, open_match, match_additional_crition_always_true())); } - // std::vector matches = find_pattern_matches( - // pattern, graph, match_additional_crition_always_true()); + SUBCASE("find_pattern_matches") { + std::vector matches = find_pattern_matches( + pattern, graph, match_additional_crition_always_true()); + std::vector correct = {match}; + + CHECK(matches == correct); + } - // CHECK(matches.size() == 3); // // for (UnlabelledDataflowGraphPatternMatch const &match : matches) { // CHECK(unlabelled_pattern_does_match(pattern, diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 51fb160453..24dd694acb 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -7,6 +7,7 @@ #include #include #include +#include "utils/containers/sorted.h" namespace FlexFlow { @@ -250,15 +251,6 @@ template std::unordered_set flatmap_v2(std::unordered_set const &v, std::unordered_set (*f)(In const &)); -template -void inplace_sorted_by(C &c, F const &f); - -template -std::vector sorted(C const &c); - -template -std::vector sorted_by(C const &c, F const &f); - template std::function compare_by(F const &f); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 73a3f45f26..34fb6b4dbb 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -587,30 +587,6 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } -template -void inplace_sorted_by(C &c, F const &f) { - CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C); - - auto custom_comparator = [&](Elem const &lhs, Elem const &rhs) -> bool { - return f(lhs, rhs); - }; - std::sort(c.begin(), c.end(), custom_comparator); -} - -template -std::vector sorted(C const &c) { - std::vector result(c.begin(), c.end()); - inplace_sorted_by(result, [](Elem const &l, Elem const &r) { return l < r; }); - return result; -} - -template -std::vector sorted_by(C const &c, F const &f) { - std::vector result(c.begin(), c.end()); - inplace_sorted_by(result, f); - return result; -} - template std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; diff --git a/lib/utils/include/utils/containers/sorted.h b/lib/utils/include/utils/containers/sorted.h new file mode 100644 index 0000000000..8f95e4c334 --- /dev/null +++ b/lib/utils/include/utils/containers/sorted.h @@ -0,0 +1,53 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SORTED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SORTED_H + +#include "utils/type_traits_core.h" +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct sort_value_type : + type_identity {}; + +template +struct sort_value_type> : + type_identity> {}; + +template +struct sort_value_type> : + type_identity> {}; + +template +using sort_value_type_t = typename sort_value_type::type; + +template > +void inplace_sorted_by(C &c, F const &f) { + CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C); + + auto custom_comparator = [&](Elem const &lhs, Elem const &rhs) -> bool { + return f(lhs, rhs); + }; + std::sort(c.begin(), c.end(), custom_comparator); +} + +template > +std::vector sorted(C const &c) { + std::vector result(c.begin(), c.end()); + inplace_sorted_by(result, [](Elem const &l, Elem const &r) { return l < r; }); + return result; +} + +template > +std::vector sorted_by(C const &c, F const &f) { + std::vector result(c.begin(), c.end()); + inplace_sorted_by(result, f); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 1744130134..892c92da22 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -1,10 +1,12 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MAP_H -#include "fmt/format.h" +#include #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include +#include "utils/containers/sorted.h" +#include "utils/fmt/pair.h" namespace fmt { @@ -17,14 +19,15 @@ struct formatter< template auto format(::std::map const &m, FormatContext &ctx) -> decltype(ctx.out()) { - /* CHECK_FMTABLE(K); */ - /* CHECK_FMTABLE(V); */ - - /* std::string result = ::FlexFlow::join_strings( */ - /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return - * fmt::to_string(p); }); */ - std::string result = ""; - return formatter::format(result, ctx); + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + std::vector> items = ::FlexFlow::sorted(m); + + std::string result = ::FlexFlow::join_strings( + items.cbegin(), items.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); + + return formatter::format("{" + result + "}", ctx); } }; diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index aef77f7d0f..7f00c504af 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -5,6 +5,32 @@ #include #include "utils/check_fmtable.h" +namespace fmt { + +template +struct formatter< + ::std::optional, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::optional const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result; + if (m.has_value()) { + result = fmt::to_string(m.value()); + } else { + result = "nullopt"; + } + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + namespace FlexFlow { template diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index eb1147ae3c..4ca4efa1a3 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -1,10 +1,32 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H -#include "fmt/format.h" +#include #include "utils/check_fmtable.h" #include +namespace fmt { + +template +struct formatter< + ::std::pair, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::pair const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + std::string result = fmt::format("{{{}, {}}}", m.first, m.second); + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + namespace FlexFlow { template diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index bc50757400..c04d4b0653 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -5,6 +5,8 @@ #include "utils/join_strings.h" #include #include +#include +#include "utils/containers/sorted.h" namespace fmt { @@ -19,8 +21,9 @@ struct formatter< -> decltype(ctx.out()) { CHECK_FMTABLE(T); + std::vector items = ::FlexFlow::sorted(m); std::string result = - ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + ::FlexFlow::join_strings(items.cbegin(), items.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format("{" + result + "}", ctx); diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 19701bfb0c..086cd4eef0 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H -#include "fmt/format.h" +#include +#include #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include +#include "utils/fmt/pair.h" +#include +#include "utils/containers/sorted.h" namespace fmt { @@ -17,14 +21,15 @@ struct formatter< template auto format(::std::unordered_map const &m, FormatContext &ctx) -> decltype(ctx.out()) { - /* CHECK_FMTABLE(K); */ - /* CHECK_FMTABLE(V); */ - - /* std::string result = ::FlexFlow::join_strings( */ - /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return - * fmt::to_string(p); }); */ - std::string result = ""; - return formatter::format(result, ctx); + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + std::vector> items = ::FlexFlow::sorted(m); + + std::string result = ::FlexFlow::join_strings( + items.cbegin(), items.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); + + return formatter::format("{" + result + "}", ctx); } }; diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 8954faf7c5..7005b9aa70 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -5,6 +5,7 @@ #include "utils/join_strings.h" #include #include +#include "utils/containers/sorted.h" namespace fmt { @@ -19,8 +20,9 @@ struct formatter< -> decltype(ctx.out()) { CHECK_FMTABLE(T); + std::vector in_order = ::FlexFlow::sorted(m); std::string result = - ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + ::FlexFlow::join_strings(in_order.cbegin(), in_order.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format("{" + result + "}", ctx); diff --git a/lib/utils/src/utils/containers/sorted.cc b/lib/utils/src/utils/containers/sorted.cc new file mode 100644 index 0000000000..737310cd4f --- /dev/null +++ b/lib/utils/src/utils/containers/sorted.cc @@ -0,0 +1 @@ +#include "utils/containers/sorted.h" diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index e1e7a477b0..8791dc75f3 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -7,6 +7,7 @@ ff_add_test_executable( src/test_optional.cc src/utils/containers/filter.cc src/utils/containers/inplace_filter.cc + src/utils/fmt/*.cc PRIVATE_INCLUDE src/ DEPS diff --git a/lib/utils/test/src/utils/fmt/expected.cc b/lib/utils/test/src/utils/fmt/expected.cc new file mode 100644 index 0000000000..47a4c82434 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/expected.cc @@ -0,0 +1,22 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/expected.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(tl::expected)") { + SUBCASE("expected") { + tl::expected input = 4; + std::string result = fmt::to_string(input); + std::string correct = "expected(4)"; + CHECK(result == correct); + } + + SUBCASE("unexpected") { + tl::expected input = tl::unexpected("hello world"); + std::string result = fmt::to_string(input); + std::string correct = "unexpected(hello world)"; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/fmt/map.cc b/lib/utils/test/src/utils/fmt/map.cc new file mode 100644 index 0000000000..7beb03fe1a --- /dev/null +++ b/lib/utils/test/src/utils/fmt/map.cc @@ -0,0 +1,13 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/map.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::map)") { + std::map input = {{0, 10}, {1, 1}, {3, 5}, {2, 8}}; + std::string result = fmt::to_string(input); + std::string correct = "{{0, 10}, {1, 1}, {2, 8}, {3, 5}}"; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/fmt/optional.cc b/lib/utils/test/src/utils/fmt/optional.cc new file mode 100644 index 0000000000..127efcba41 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/optional.cc @@ -0,0 +1,22 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/optional.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::optional)") { + SUBCASE("has value") { + std::optional input = 4; + std::string result = fmt::to_string(input); + std::string correct = "4"; + CHECK(result == correct); + } + + SUBCASE("does not have value") { + std::optional input = std::nullopt; + std::string result = fmt::to_string(input); + std::string correct = "nullopt"; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/fmt/pair.cc b/lib/utils/test/src/utils/fmt/pair.cc new file mode 100644 index 0000000000..f22e2b0c65 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/pair.cc @@ -0,0 +1,13 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::pair)") { + std::pair input = {3, 5}; + std::string result = fmt::to_string(input); + std::string correct = "{3, 5}"; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/fmt/set.cc b/lib/utils/test/src/utils/fmt/set.cc new file mode 100644 index 0000000000..f527c82f09 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/set.cc @@ -0,0 +1,13 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/set.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::set)") { + std::set input = {0, 1, 3, 2}; + std::string result = fmt::to_string(input); + std::string correct = "{0, 1, 2, 3}"; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/fmt/unordered_map.cc b/lib/utils/test/src/utils/fmt/unordered_map.cc new file mode 100644 index 0000000000..5235752c2a --- /dev/null +++ b/lib/utils/test/src/utils/fmt/unordered_map.cc @@ -0,0 +1,13 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/unordered_map.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::unordered_map)") { + std::unordered_map input = {{0, 10}, {1, 1}, {3, 5}, {2, 8}}; + std::string result = fmt::to_string(input); + std::string correct = "{{0, 10}, {1, 1}, {2, 8}, {3, 5}}"; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/fmt/unordered_set.cc b/lib/utils/test/src/utils/fmt/unordered_set.cc new file mode 100644 index 0000000000..4f083ed358 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/unordered_set.cc @@ -0,0 +1,13 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/unordered_set.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::unordered_set)") { + std::unordered_set input = {0, 1, 3, 2}; + std::string result = fmt::to_string(input); + std::string correct = "{0, 1, 2, 3}"; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/fmt/variant.cc b/lib/utils/test/src/utils/fmt/variant.cc new file mode 100644 index 0000000000..3ef8aada65 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/variant.cc @@ -0,0 +1,22 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/variant.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::variant)") { + SUBCASE("has int") { + std::variant input = 4; + std::string result = fmt::to_string(input); + std::string correct = "4"; + CHECK(result == correct); + } + + SUBCASE("has string") { + std::variant input = "hello world"; + std::string result = fmt::to_string(input); + std::string correct = "hello world"; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/fmt/vector.cc b/lib/utils/test/src/utils/fmt/vector.cc new file mode 100644 index 0000000000..2baef980d8 --- /dev/null +++ b/lib/utils/test/src/utils/fmt/vector.cc @@ -0,0 +1,13 @@ +#include "test/utils/doctest.h" +#include "utils/fmt/vector.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::vector)") { + std::vector input = {0, 1, 3, 2}; + std::string result = fmt::to_string(input); + std::string correct = "[0, 1, 3, 2]"; + CHECK(result == correct); + } +} From ad60be0ce09a22bd6b79ccd877b3d2b967848a66 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 14 Jul 2024 19:18:58 -0700 Subject: [PATCH 20/71] Pass initial unit tests for find_pattern_matches --- .../unlabelled/find_pattern_matches.cc | 2 +- .../test/src/test_pattern_matches.cc | 9 ------ lib/utils/include/utils/optional.decl | 16 ----------- lib/utils/include/utils/optional.h | 28 ++----------------- 4 files changed, 3 insertions(+), 52 deletions(-) delete mode 100644 lib/utils/include/utils/optional.decl diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 17a3a7da10..827e8e77dc 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -88,7 +88,7 @@ std::vector merge_unlabelled_dataflow_graph_pattern_matches(prefix_match, postfix_match, subpatterns.subpattern_1_outputs_to_subpattern_2_inputs); - if (unsplit.has_value()) { + if (unsplit.has_value() && unlabelled_pattern_does_match(pattern, graph, unsplit.value(), additional_criterion)) { matches.push_back(unsplit.value()); } } diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 73e49df616..73c5b98b08 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -93,7 +93,6 @@ TEST_SUITE(FF_TEST_SUITE) { NodeAddedResult n0_added = graph.add_node({}, 1); Node n0 = n0_added.node; OpenDataflowValue v0 = OpenDataflowValue{get_only(n0_added.outputs)}; - // CHECK(v0 == OpenDataflowValue{DataflowOutput{n0, 0}}); NodeAddedResult n1_added = graph.add_node({v0}, 1); Node n1 = n1_added.node; @@ -223,13 +222,5 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(matches == correct); } - - // - // for (UnlabelledDataflowGraphPatternMatch const &match : matches) { - // CHECK(unlabelled_pattern_does_match(pattern, - // graph, - // match, - // match_additional_crition_always_true())); - // } } } diff --git a/lib/utils/include/utils/optional.decl b/lib/utils/include/utils/optional.decl deleted file mode 100644 index 82f4bd984d..0000000000 --- a/lib/utils/include/utils/optional.decl +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_OPTIONAL_H -#define _FLEXFLOW_UTILS_OPTIONAL_H - -#include - -namespace FlexFlow { - -template -T const &unwrap(std::optional const &o, F const &f); - -template -T const &assert_unwrap(std::optional const &o); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 2594a96c8e..f7efa69e83 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -1,10 +1,9 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H -#include "fmt.h" -#include "rapidcheck.h" +#include #include "utils/exception.h" -#include "utils/optional.decl" +#include "utils/fmt/optional.h" namespace FlexFlow { @@ -38,29 +37,6 @@ std::optional> transform(std::optional const &o, } // namespace FlexFlow -namespace fmt { - -template -struct formatter< - ::std::optional, - Char, - std::enable_if_t>::value>> - : formatter { - template - auto format(::std::optional const &q, FormatContext &ctx) - -> decltype(ctx.out()) { - std::string result; - if (q.has_value()) { - result = fmt::to_string(q.value()); - } else { - result = "nullopt"; - } - return formatter::format(result, ctx); - } -}; - -} // namespace fmt - namespace rc { template From a972da26ebd79f45eb50e7d2014a51e53f946c1a Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 15 Jul 2024 00:15:50 -0700 Subject: [PATCH 21/71] Start on unit tests for pcg pattern --- .../parallel_computation_graph.h | 5 + .../parallel_computation_graph.cc | 8 ++ .../include/substitutions/pcg_pattern.h | 4 + .../sub_parallel_computation_graph.h | 9 ++ ...sub_parallel_computation_graph.struct.toml | 2 +- .../src/substitutions/pcg_pattern.cc | 35 +++-- .../sub_parallel_computation_graph.cc | 31 ++++ .../test/src/substitutions/pcg_pattern.cc | 133 ++++++++++++++++++ .../test/src/test_substitution.cc | 9 ++ lib/utils/include/utils/containers/sorted.h | 7 + lib/utils/include/utils/fmt/unordered_set.h | 19 ++- ...ordered_set_labelled_open_dataflow_graph.h | 29 +++- ...opy_of_labelled_open_dataflow_graph_view.h | 100 +++++++++++++ .../view_as_labelled_open_dataflow_graph.h | 54 +++++++ .../i_labelled_dataflow_graph.h | 3 + .../labelled_dataflow_graph.h | 14 +- .../i_labelled_open_dataflow_graph.h | 12 +- .../labelled_open_dataflow_graph.h | 7 + lib/utils/include/utils/type_traits.h | 9 -- lib/utils/include/utils/type_traits_core.h | 10 ++ .../view_as_labelled_open_dataflow_graph.cc | 1 + 21 files changed, 466 insertions(+), 35 deletions(-) create mode 100644 lib/substitutions/test/src/substitutions/pcg_pattern.cc create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.cc diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index a320a4bbc1..46cf775b9d 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -12,6 +12,8 @@ ParallelComputationGraph empty_parallel_computation_graph(); std::unordered_set get_parallel_layers(ParallelComputationGraph const &); +std::unordered_set + get_parallel_tensors(ParallelComputationGraph const &); ParallelLayerAddedResult add_parallel_layer(ParallelComputationGraph &pcg, @@ -37,6 +39,9 @@ ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, std::vector topological_ordering(ParallelComputationGraph const &); +parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, + std::string const &name); + } // namespace FlexFlow #endif diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 4cc152d7b3..0c28aefe1c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -72,4 +72,12 @@ std::vector topological_ordering(ParallelComputationGraph [](Node const &n) { return parallel_layer_guid_t{n}; }); } +parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, + std::string const &name) { + std::unordered_set found = filter(get_parallel_layers(pcg), [&](parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(pcg, l).name == name; + }); + return get_only(found); +} + } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index 0d99818860..c91b7e9364 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -10,6 +10,10 @@ namespace FlexFlow { +std::vector + find_pattern_matches(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg); + UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &); TensorAttributePattern get_tensor_pattern(PCGPattern const &, diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 9f45887206..9ed1e738df 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -1,10 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" namespace FlexFlow { + +std::unordered_set get_parallel_layers(SubParallelComputationGraph const &sub_pcg); ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, Node const &); PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, @@ -12,6 +16,11 @@ PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &, OpenDataflowValue const &); +SubParallelComputationGraph sub_pcg_from_full_pcg(ParallelComputationGraph const &); +ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs(SubParallelComputationGraph const &); + +parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, + std::string const &name); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml index bcd5e42fc0..38ce364b49 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml @@ -10,4 +10,4 @@ includes = [ [[fields]] name = "raw_graph" -type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" +type = "::FlexFlow::LabelledOpenDataflowGraphView<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index a5fe879696..f07b5789dd 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -6,6 +6,29 @@ namespace FlexFlow { +static MatchAdditionalCriterion pcg_pattern_criteria(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg) { + return MatchAdditionalCriterion{ + [&](PatternNode const &patternNode, Node const &pcgNode) { + return operator_satisfies_pattern( + get_operator_attrs(pcg, pcgNode), + get_operator_pattern(pattern, patternNode)); + }, + [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { + return parallel_tensor_satisfies_pattern( + get_parallel_tensor_attrs(pcg, pcgValue), + get_tensor_pattern(pattern, patternValue)); + }}; +} + +std::vector + find_pattern_matches(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg) { + return find_pattern_matches(get_unlabelled_pattern(pattern), + pcg.raw_graph, + pcg_pattern_criteria(pattern, pcg)); +} + UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { return UnlabelledGraphPattern{p.raw_graph}; } @@ -27,17 +50,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, get_unlabelled_pattern(pattern), pcg.raw_graph, patternMatch, - MatchAdditionalCriterion{ - [&](PatternNode const &patternNode, Node const &pcgNode) { - return operator_satisfies_pattern( - get_operator_attrs(pcg, pcgNode), - get_operator_pattern(pattern, patternNode)); - }, - [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { - return parallel_tensor_satisfies_pattern( - get_parallel_tensor_attrs(pcg, pcgValue), - get_tensor_pattern(pattern, patternValue)); - }}); + pcg_pattern_criteria(pattern, pcg)); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 965d77f3d1..e597f8b5f5 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -1,7 +1,17 @@ #include "substitutions/sub_parallel_computation_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/node/algorithms.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" namespace FlexFlow { +std::unordered_set + get_parallel_layers(SubParallelComputationGraph const &sub_pcg) { + return get_parallel_layers(pcg_from_sub_pcg_by_dropping_inputs(sub_pcg)); +} + ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, Node const &n) { @@ -19,4 +29,25 @@ ParallelTensorAttrs return spcg.raw_graph.at(v); } +SubParallelComputationGraph sub_pcg_from_full_pcg(ParallelComputationGraph const &pcg) { + return SubParallelComputationGraph{view_as_labelled_open_dataflow_graph(pcg.raw_graph)}; +} + +ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs(SubParallelComputationGraph const &sub_pcg) { + return ParallelComputationGraph{ + LabelledDataflowGraph::create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>(sub_pcg.raw_graph) + }; + // return ParallelComputationGraph{ + // make_lazy_copy_of< + // UnorderedSetLabelledOpenDataflowGraph + // >(sub_pcg.raw_graph) + // }; +} + +parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, + std::string const &name) { + return get_parallel_layer_by_name(pcg_from_sub_pcg_by_dropping_inputs(pcg), name); +} + } // namespace FlexFlow diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc new file mode 100644 index 0000000000..0da2d87527 --- /dev/null +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -0,0 +1,133 @@ +#define DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "test/utils/doctest.h" +#include "substitutions/pcg_pattern.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_pattern_matches(PCGPattern, SubParallelComputationGraph)") { + ParallelComputationGraphBuilder builder; + + size_t batch_size = 16; + int batch_degree = 2; + size_t num_channels = 24; + + ParallelTensorShape a_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{batch_size, batch_degree}, + ShardParallelDim{num_channels, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + std::string a_name = "a"; + + parallel_tensor_guid_t a_tensor = builder.create_input_tensor(a_shape, /*create_grad=*/true, a_name); + + int outDim = 16; + std::string x_matmul_name = "x_matmul"; + std::string y_matmul_name = "y_matmul"; + parallel_tensor_guid_t t0 = builder.dense(a_tensor, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + x_matmul_name); + parallel_tensor_guid_t t1 = builder.dense(a_tensor, + outDim, + /*activation=*/std::nullopt, + /*use_bias=*/false, + DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + y_matmul_name); + parallel_tensor_guid_t t2 = builder.add(t0, t1); + + ParallelComputationGraph pcg = builder.pcg; + parallel_layer_guid_t x_matmul = get_parallel_layer_by_name(pcg, x_matmul_name); + parallel_layer_guid_t y_matmul = get_parallel_layer_by_name(pcg, y_matmul_name); + std::vector x_inputs = get_layer_inputs(pcg, x_matmul); + REQUIRE(x_inputs.size() == 2); + parallel_tensor_guid_t x_weights = x_inputs.at(1); + std::vector y_inputs = get_layer_inputs(pcg, y_matmul); + REQUIRE(y_inputs.size() == 2); + parallel_tensor_guid_t y_weights = y_inputs.at(1); + + LabelledOpenDataflowGraph g = + LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + TensorAttributePattern pattern_tensor_a = TensorAttributePattern{{}}; + TensorAttributePattern pattern_tensor_b = TensorAttributePattern{{}}; + TensorAttributePattern pattern_tensor_c = TensorAttributePattern{{}}; + TensorAttributePattern pattern_tensor_x = TensorAttributePattern{{}}; + TensorAttributePattern pattern_tensor_y = TensorAttributePattern{{}}; + + OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{ + { + OperatorAttributeConstraint{ + ConstraintType::EQUAL, + OperatorAttributeExpr{OperatorAttributeKey::OP_TYPE}, + OperatorAttributeValue{OperatorType::LINEAR}, + } + } + }; + + OperatorAttributePattern op_pattern_2 = op_pattern_1; + + DataflowGraphInput pt_a = g.add_input(pattern_tensor_a); + DataflowGraphInput pt_b = g.add_input(pattern_tensor_b); + DataflowGraphInput pt_c = g.add_input(pattern_tensor_c); + + NodeAddedResult op_pattern_1_added = g.add_node(op_pattern_1, {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_b}}, {pattern_tensor_x}); + PatternNode op_pattern_1_node = PatternNode{op_pattern_1_added.node}; + OpenDataflowValue pt_x = OpenDataflowValue{get_only(op_pattern_1_added.outputs)}; + + NodeAddedResult op_pattern_2_added = g.add_node(op_pattern_2, {OpenDataflowValue{pt_a}, OpenDataflowValue{pt_c}}, {pattern_tensor_y}); + PatternNode op_pattern_2_node = PatternNode{op_pattern_2_added.node}; + OpenDataflowValue pt_y = OpenDataflowValue{get_only(op_pattern_2_added.outputs)}; + + PCGPattern pattern = PCGPattern{g}; + + std::unordered_set result = without_order(find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); + // + // UnlabelledDataflowGraphPatternMatch match1 = UnlabelledDataflowGraphPatternMatch{ + // bidict{ + // {op_pattern_1_node, x_matmul.raw_graph_node}, + // {op_pattern_2_node, y_matmul.raw_graph_node}, + // }, + // bidict{ + // {PatternInput{pt_a}, OpenDataflowValue{a_tensor.raw_graph_output}}, + // {PatternInput{pt_b}, OpenDataflowValue{x_weights.raw_graph_output}}, + // {PatternInput{pt_c}, OpenDataflowValue{y_weights.raw_graph_output}}, + // } + // }; + // + // UnlabelledDataflowGraphPatternMatch match2 = UnlabelledDataflowGraphPatternMatch{ + // bidict{ + // {op_pattern_1_node, y_matmul.raw_graph_node}, + // {op_pattern_2_node, x_matmul.raw_graph_node}, + // }, + // bidict{ + // {PatternInput{pt_a}, OpenDataflowValue{a_tensor.raw_graph_output}}, + // {PatternInput{pt_b}, OpenDataflowValue{y_weights.raw_graph_output}}, + // {PatternInput{pt_c}, OpenDataflowValue{x_weights.raw_graph_output}}, + // } + // }; + // + // std::unordered_set correct = {match1, match2}; + // + // CHECK(result == correct); + } +} diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 156d573ab8..cc1333959d 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -5,6 +5,15 @@ using namespace FlexFlow; +// TEST_SUITE(FF_TEST_SUITE) { +// TEST_CASE("substitution") { + // PCGPattern pattern; + // OutputGraphExpr output_expr; + // bidict> : template using sort_value_type_t = typename sort_value_type::type; +template +struct is_sortable : + is_lt_comparable> {}; + +template +inline constexpr bool is_sortable_v = is_sortable::value; + template > void inplace_sorted_by(C &c, F const &f) { CHECK_SUPPORTS_ITERATOR_TAG(std::random_access_iterator_tag, C); diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 7005b9aa70..4687da43ff 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -6,6 +6,7 @@ #include #include #include "utils/containers/sorted.h" +#include "utils/type_traits_core.h" namespace fmt { @@ -20,11 +21,19 @@ struct formatter< -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::vector in_order = ::FlexFlow::sorted(m); - std::string result = - ::FlexFlow::join_strings(in_order.cbegin(), in_order.cend(), ", ", [](T const &t) { - return fmt::to_string(t); - }); + std::string result; + if constexpr (::FlexFlow::is_sortable_v>) { + std::vector in_order = ::FlexFlow::sorted(m); + result = + ::FlexFlow::join_strings(in_order.cbegin(), in_order.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + } else { + result = + ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + } return formatter::format("{" + result + "}", ctx); } }; diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index 8213534e7c..3793bee4d5 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -12,14 +12,25 @@ #include "utils/containers/without_nullopts.h" #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/containers.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { template -struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflowGraph { +struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflowGraph, + public ILabelledDataflowGraph { public: UnorderedSetLabelledOpenDataflowGraph() = default; + NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) override { + return this->add_node(node_label, + transform(inputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }), + output_labels); + } + NodeAddedResult add_node(NodeLabel const &node_label, std::vector const &inputs, std::vector const &output_labels) override { @@ -65,6 +76,10 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { return without_nullopts(transform(keys(this->values), [&](OpenDataflowValue const &v) -> std::optional { + if (!v.has()) { + return std::nullopt; + } + DataflowOutput o = v.get(); if (dataflow_output_query_includes_dataflow_output(q, o)) { return o; @@ -86,6 +101,18 @@ struct UnorderedSetLabelledOpenDataflowGraph final : public ILabelledOpenDataflo return this->values.at(v); } + virtual void inplace_materialize_from(LabelledDataflowGraphView const &view) override { + std::unordered_set nodes = get_nodes(view); + std::unordered_set outputs = get_all_dataflow_outputs(view); + std::unordered_set edges = get_edges(view); + std::unordered_map labelled_outputs = generate_map(outputs, [&](DataflowOutput const &o) { return view.at(o); }); + + this->inputs.clear(); + this->nodes = generate_map(nodes, [&](Node const &n) { return view.at(n); }); + this->edges = transform(edges, [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); + this->values = map_keys(labelled_outputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }); + } + UnorderedSetLabelledOpenDataflowGraph *clone() const override { return new UnorderedSetLabelledOpenDataflowGraph{ this->node_source, diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h new file mode 100644 index 0000000000..f2fcbac281 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h @@ -0,0 +1,100 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H + +#include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include + +namespace FlexFlow { + +template +struct LazyLabelledDataflowGraph final : public ILabelledDataflowGraph { +public: + LazyLabelledDataflowGraph() = delete; + LazyLabelledDataflowGraph(LabelledDataflowGraphView const &view, + std::function(LabelledDataflowGraphView const &)> const &make_copy_func) + : g(view), make_copy_func(make_copy_func) {} + + NodeAddedResult add_node(NodeLabel const &node_label, + std::vector const &inputs, + std::vector const &output_labels) override { + return this->get_mutable_graph().add_node(node_label, inputs, output_labels); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->get_view().query_nodes(q); + } + + std::unordered_set query_edges(DataflowEdgeQuery const &q) const override { + return this->get_view().query_edges(q); + } + + std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { + return this->get_view().query_outputs(q); + } + + NodeLabel const &at(Node const &n) const override { + return this->get_view().at(n); + } + + ValueLabel const &at(DataflowOutput const &v) const override { + return this->get_view().at(v); + } + + LazyLabelledDataflowGraph *clone() const override { + return new LazyLabelledDataflowGraph(this->g, this->make_copy_func); + } + + void inplace_materialize_from(LabelledDataflowGraphView const &view) override { + this->g = view; + } +private: + std::variant< + LabelledDataflowGraphView, + LabelledDataflowGraph + > g; + std::function(LabelledDataflowGraphView const &)> make_copy_func; +private: + LazyLabelledDataflowGraph(decltype(g) const &g, + decltype(make_copy_func) const &make_copy_func) + : g(g), make_copy_func(make_copy_func) {} + + LabelledDataflowGraphView const &get_view() const { + if (g.index() == 0) { + return std::get<0>(this->g); + } else { + assert (g.index() == 1); + return std::get<1>(this->g); + } + } + + LabelledDataflowGraph &get_mutable_graph() { + if (g.index() == 0) { + this->g = this->make_copy_func(std::get<0>(g)); + } + assert (g.index() == 1); + + return std::get<1>(g); + } +}; + +template + static typename std::enable_if, T>::value, + LabelledDataflowGraph>::type + make_lazy_copy_of(LabelledDataflowGraphView const &view) { + std::function< + LabelledDataflowGraph( + LabelledDataflowGraphView const & + ) + > make_copy_func = [](LabelledDataflowGraphView const &v) { + return LabelledDataflowGraph::template create_copy_of(v); + }; + return LabelledDataflowGraph::template create< + LazyLabelledDataflowGraph>(view, make_copy_func); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h new file mode 100644 index 0000000000..9ae0f6ab7c --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h @@ -0,0 +1,54 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_AS_OPEN_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_AS_OPEN_GRAPH_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" + +namespace FlexFlow { + +template +struct LabelledDataflowGraphAsOpenView final : public ILabelledOpenDataflowGraphView { +public: + LabelledDataflowGraphAsOpenView() = delete; + LabelledDataflowGraphAsOpenView(LabelledDataflowGraphView const &g) + : g(g) {} + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } + + std::unordered_set query_edges(OpenDataflowEdgeQuery const &q) const override { + return transform(this->g.query_edges(q.standard_edge_query), [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); + } + + std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { + return this->g.query_outputs(q); + } + + std::unordered_set get_inputs() const override { + return {}; + } + + NodeLabel const &at(Node const &n) const override { + return this->g.at(n); + } + + ValueLabel const &at(OpenDataflowValue const &v) const override { + return this->g.at(v.get()); + } + + LabelledDataflowGraphAsOpenView *clone() const override { + return new LabelledDataflowGraphAsOpenView{this->g}; + } +private: + LabelledDataflowGraphView g; +}; + +template +LabelledOpenDataflowGraphView view_as_labelled_open_dataflow_graph(LabelledDataflowGraphView const &g) { + return LabelledOpenDataflowGraphView::template create< + LabelledDataflowGraphAsOpenView>(g); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h index 34ae475ab4..b8d139b228 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h @@ -3,6 +3,7 @@ #include "utils/graph/dataflow_graph/node_added_result.dtg.h" #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" namespace FlexFlow { @@ -13,6 +14,8 @@ struct ILabelledDataflowGraph : virtual public ILabelledDataflowGraphView const &inputs, std::vector const &output_labels) = 0; + virtual void inplace_materialize_from(LabelledDataflowGraphView const &) = 0; + virtual ~ILabelledDataflowGraph() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledDataflowGraph); diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h index ea9b463790..9f9df9d79c 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h @@ -20,12 +20,22 @@ struct LabelledDataflowGraph : virtual LabelledDataflowGraphViewget_interface().add_node(node_label, inputs, output_labels); } + template + static typename std::enable_if::value, + LabelledDataflowGraph>::type + create(Args && ...args) { + return LabelledDataflowGraph(make_cow_ptr(std::forward(args)...)); + } + template static typename std::enable_if::value, LabelledDataflowGraph>::type - create() { - return LabelledDataflowGraph(make_cow_ptr()); + create_copy_of(LabelledDataflowGraphView const &view) { + cow_ptr_t impl = make_cow_ptr(); + impl.get_mutable()->inplace_materialize_from(view); + return LabelledDataflowGraph(std::move(impl)); } + protected: using LabelledDataflowGraphView::LabelledDataflowGraphView; diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h index 1fe84179c2..b3039880bf 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h @@ -9,18 +9,18 @@ namespace FlexFlow { template struct ILabelledOpenDataflowGraph : virtual public ILabelledOpenDataflowGraphView - , virtual public ILabelledDataflowGraph { + , virtual public ILabelledDataflowGraphView { virtual NodeAddedResult add_node(NodeLabel const &node_label, std::vector const &inputs, std::vector const &output_labels) = 0; virtual DataflowGraphInput add_input(ValueLabel const &value_label) = 0; - NodeAddedResult add_node(NodeLabel const &node_label, - std::vector const &inputs, - std::vector const &output_labels) override final { - return this->add_node(node_label, transform(inputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }), output_labels); - } + // NodeAddedResult add_node(NodeLabel const &node_label, + // std::vector const &inputs, + // std::vector const &output_labels) override final { + // return this->add_node(node_label, transform(inputs, [](DataflowOutput const &o) { return OpenDataflowValue{o}; }), output_labels); + // } virtual ~ILabelledOpenDataflowGraph() = default; }; diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h index f0d6b6bd8f..2f2c78390b 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h @@ -24,6 +24,13 @@ struct LabelledOpenDataflowGraph : virtual public LabelledOpenDataflowGraphView< DataflowGraphInput add_input(ValueLabel const &value_label) { return this->get_interface().add_input(value_label); } + + template + static typename std::enable_if::value, + LabelledOpenDataflowGraph>::type + create() { + return LabelledOpenDataflowGraph(make_cow_ptr()); + } protected: using LabelledOpenDataflowGraphView::LabelledOpenDataflowGraphView; private: diff --git a/lib/utils/include/utils/type_traits.h b/lib/utils/include/utils/type_traits.h index 0c0408723d..7abb3ffd5b 100644 --- a/lib/utils/include/utils/type_traits.h +++ b/lib/utils/include/utils/type_traits.h @@ -64,15 +64,6 @@ template struct is_streamable())>> : std::true_type {}; -template -struct is_lt_comparable : std::false_type {}; - -template -struct is_lt_comparable< - T, - void_t() < std::declval()))>> - : std::true_type {}; - template