Skip to content

Commit

Permalink
Re-enable substitutions (#1471)
Browse files Browse the repository at this point in the history
* Start on pcg builder

* Add tests and some implementation for pcg builder

* Add pcg tests, make dtgen constructors explicit to fix bug

* Add remainder of PCG tests

* Fix build issues in local-execution

* Format

* Address Reyna comments, add topological_order function for PCG

* Pre multidigraph refactor

* Removing visitable from sp code

* Add open dataflow graph, start to replace pcg dataflow graph

* Start refactoring substitutions

* Add utility functions to support pattern matching

* Pre-refactor inputs

* Fix proj url

* Get back to substitutions, now with unordered graph inputs

* Get substitutions building

* substitutions-tests now builds

* Fix bug in filter, pass some initial substitution tests

* Add tests for fmt::to_string, fix some substitutions bugs

* Pass initial unit tests for find_pattern_matches

* Start on unit tests for pcg pattern

* Pass initial test for find_pattern_matches

* Fix small build issue in tests

* Format

* Sync tests in CI with tests in proj

* Fix minor build errors in kernels and local-execution

* Format

* Remove outdated code

* More outdated code removal

* More cleanup, add test for sp decomposition

* Pull apart containers.h

* More sp testing and fixes

* Break up graph algorithms.h

* Pre- full SP algo commit

* Add initial implementation and tests for cbc decomposition and inverse line graph

* Pass test for get_inverse_line_graph

* Add new multidigraph

* Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph

* Add tests for parallel and series reduction finding

* Add really rough implementation of valdez sp decomposition

* Fix local-execution build

* Add implementations and tests for applying series/parallel reductions

* Format

* Clean up sp decomposition interface and tests

* Format

* Add comments for top-level substitutions functions, add proj doxygen support

* Start sketching out substitutions code

* Fix build errors

* Add ability to permute node ids

* Cleanup and start to test new substitutions code

* Add test case for evaluate_substitution_output

* Add naive isomorphism detection code

* Add graph inputs to open dataflow graph isomorphism

* Add input permutation to evaluate_substitution_output

* Fix permute_node_ids

* Add test for permute_input_ids

* Migrate over to mutable implementation of apply_substitution

* Add fast isomorphism checking and an initial implementation of full substitution logic

* Pass initial full substitutions test

* Cleanup old isomorphism checking code

* Fix post-merge bugs

* Fix broken pcg builder test

* Format

* Reorganize code and remove some outdated code pre-code-review

* Format

* Address review comments

* Address missed comment

* Remove latex dependency to avoid CI out-of-disk-space

* Format

* Fix build issues

* Fix incorrect test case
  • Loading branch information
lockshaw authored Sep 6, 2024
1 parent 1cfb07e commit 2b4106f
Show file tree
Hide file tree
Showing 332 changed files with 7,939 additions and 1,396 deletions.
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lib/compiler/src/machine_mapping.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
#include "utils/containers/contains_key.h"
#include "utils/containers/get_only.h"
#include "utils/containers/keys.h"
#include "utils/containers/merge_maps.h"
#include "utils/exception.h"
#include "utils/graph/graph_split.dtg.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h"
#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h"
#include "utils/graph/serial_parallel/serial_parallel_decomposition.h"
Expand Down
File renamed without changes.
1 change: 0 additions & 1 deletion lib/local-execution/src/ops/pool_2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "op-attrs/get_output_shapes.h"
#include "op-attrs/ops/pool_2d.h"
#include "utils/exception.decl.h"
#include "utils/exception.h"
#include "utils/hash-utils.h"

Expand Down
1 change: 0 additions & 1 deletion lib/local-execution/src/ops/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "kernels/transpose_kernels.h"
#include "op-attrs/get_output_shapes.h"
#include "op-attrs/ops/transpose.h"
#include "utils/exception.decl.h"

using namespace FlexFlow::Kernels::Transpose;

Expand Down
15 changes: 0 additions & 15 deletions lib/op-attrs/include/op-attrs/as_dot.h

This file was deleted.

2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H

#include "op-attrs/computation_graph_op_attrs.dtg.h"
#include "utils/record_formatter.h"

namespace FlexFlow {

OperatorType get_op_type(ComputationGraphOpAttrs const &);
RecordFormatter as_dot(ComputationGraphOpAttrs const &);

} // namespace FlexFlow

Expand Down
30 changes: 30 additions & 0 deletions lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H

#include "op-attrs/dim_ordered.h"
#include "utils/bidict/bidict.h"
#include "utils/containers/count.h"

namespace FlexFlow {

/**
* @brief Generate a map from indices to elements of \p c.
*
* @note We return a <tt>std::map</tt> to prevent mixups of \ref ff_dim_t and
* \ref legion_dim_t. Note that <tt>std::map</tt> provides ordered iteration in
* increasing order, so iterating through the result of this function should
* function as expected.
*/
template <typename T>
std::map<ff_dim_t, T> enumerate(FFOrdered<T> const &ff_ordered) {
std::map<ff_dim_t, T> result;
for (int raw_ff_dim : count(ff_ordered.size())) {
ff_dim_t ff_dim = ff_dim_t{raw_ff_dim};
result.insert({ff_dim, ff_ordered.at(ff_dim)});
}
return result;
}

} // namespace FlexFlow

#endif
223 changes: 5 additions & 218 deletions lib/op-attrs/include/op-attrs/get_output_shapes.h
Original file line number Diff line number Diff line change
@@ -1,228 +1,15 @@
#ifndef _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H
#define _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H

#include "op-attrs/operator_attrs.h"
#include "op-attrs/parallel_tensor_shape.h"
#include "ops/reverse.h"
#include "tensor_shape.h"
#include "utils/containers/get_only.h"
#include "utils/optional.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "op-attrs/pcg_operator_attrs.dtg.h"
#include <vector>

namespace FlexFlow {

template <typename T, typename Enable = void>
struct has_unary_output_t : std::false_type {};
template <typename T, typename Enable = void>
struct has_unary_input_t : std::false_type {};
template <typename T, typename Enable = void>
struct has_binary_input_t : std::false_type {};

template <typename T, typename Enable = void>
struct has_multi_output_t : std::true_type {};
template <typename T, typename Enable = void>
struct has_multi_input_t : std::true_type {};

template <typename T>
struct has_multi_output_t<
T,
typename std::enable_if<has_unary_output_t<T>::value>::type>
: std::false_type {};

template <typename T>
struct has_multi_input_t<
T,
typename std::enable_if<(has_unary_input_t<T>::value ||
has_binary_input_t<T>::value)>::type>
: std::false_type {};

/* template <typename T, typename Enable = void> struct output_type_t { using
* type = std::vector<ParallelTensorShape>; }; */

template <typename T>
typename std::enable_if<has_unary_input_t<T>::value, bool>::type
is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
if (shapes.size() != 1) {
return false;
}

return is_valid(t, get_only(shapes));
}

template <typename T>
typename std::enable_if<has_binary_input_t<T>::value, bool>::type
is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
if (shapes.size() != 2) {
return false;
}

return is_valid(t, shapes.at(0), shapes.at(1));
}

template <typename T>
typename std::enable_if<(has_unary_input_t<T>::value &&
has_unary_output_t<T>::value),
ParallelTensorShape>::type
output_shapes(T const &t, std::vector<ParallelTensorShape> const &shapes) {
return output_shape(t, get_only(shapes));
}

template <typename T>
typename std::enable_if<(has_binary_input_t<T>::value &&
has_unary_output_t<T>::value),
std::vector<ParallelTensorShape>>::type
output_shapes(T const &t, std::vector<ParallelTensorShape> const &shapes) {
assert(shapes.size() == 2);

return {output_shape(t, shapes.at(0), shapes.at(1))};
}

TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &);
std::vector<TensorShape>
get_tensor_shapes_unsafe(std::vector<ParallelTensorShape> const &);

template <typename Attrs>
TensorShape get_output_shape(Attrs const &attrs, TensorShape const &shape) {
NOT_IMPLEMENTED();
}

template <typename Attrs>
TensorShape get_output_shape(Attrs const &attrs,
TensorShape const &,
TensorShape const &) {
NOT_IMPLEMENTED();
}

template <typename Attrs>
TensorShape get_output_shape(Attrs const &attrs,
std::vector<TensorShape> const &) {
NOT_IMPLEMENTED();
}
template <typename Attrs>
std::vector<TensorShape> get_output_shapes(Attrs const &attrs,
TensorShape const &);
template <typename Attrs>
std::vector<TensorShape> get_output_shapes(Attrs const &attrs,
TensorShape const &,
TensorShape const &) {
NOT_IMPLEMENTED();
}
template <typename Attrs>
std::vector<TensorShape> get_output_shapes(Attrs const &attrs,
std::vector<TensorShape> const &);

ParallelTensorShape get_output_shape(ConcatAttrs const &,
std::vector<ParallelTensorShape> const &);
ParallelTensorShape get_output_shape(FlatAttrs const &,
ParallelTensorShape const &);
std::vector<ParallelTensorShape> get_output_shapes(GatherAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(Pool2DAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReduceAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReverseAttrs const &,
ParallelTensorShape const &);
std::vector<ParallelTensorShape> get_output_shapes(SplitAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(TopKAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(TransposeAttrs const &,
std::vector<ParallelTensorShape> const &);

struct GetOutputShapesFunctor {
GetOutputShapesFunctor(std::vector<ParallelTensorShape> const &s) : s(s) {}

std::vector<ParallelTensorShape> const &s;

template <typename T>
std::vector<ParallelTensorShape> operator()(T const &t) {
return get_output_shapes(t, s);
}
};

template <typename... Ts>
std::vector<ParallelTensorShape>
get_output_shapes(std::variant<Ts...> const &t,
std::vector<ParallelTensorShape> const &s) {
return get_output_shape(GetOutputShapesFunctor{s}, t);
}

template <typename T>
typename std::enable_if<!has_unary_output_t<T>::value, std::optional<int>>::type
get_num_outputs(T const &) {
return std::nullopt;
}

template <typename T>
typename std::enable_if<has_unary_output_t<T>::value, std::optional<int>>::type
get_num_outputs(T const &) {
return 1;
}

int get_num_outputs(SplitAttrs const &attrs);

template <typename T>
bool is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
auto num_outputs = get_num_outputs(t);
if (num_outputs.has_value() && shapes.size() != num_outputs.value()) {
return false;
}

for (ParallelTensorShape const &shape : shapes) {
if (!is_valid(shape)) {
return false;
}
}

return is_valid_internal(t, shapes);
}

template <typename T>
typename std::enable_if<has_unary_input_t<T>::value, bool>::type
is_valid_internal(T const &t,
std::vector<ParallelTensorShape> const &shapes) {
return is_valid_internal(t, get_only(shapes));
}

template <typename T>
typename std::enable_if<has_binary_input_t<T>::value, bool>::type
is_valid_internal(T const &t,
std::vector<ParallelTensorShape> const &shapes) {
return is_valid_internal(t, shapes.at(0), shapes.at(1));
}

bool is_valid_internal(MultiHeadAttentionAttrs const &,
std::vector<ParallelTensorShape> const &);
bool is_valid_internal(BatchMatmulAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ConcatAttrs const &,
std::vector<ParallelTensorShape> const &);
bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ElementBinaryAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(GatherAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &);
get_output_shapes(PCGOperatorAttrs const &,
std::vector<ParallelTensorShape> const &);

} // namespace FlexFlow

Expand Down
59 changes: 59 additions & 0 deletions lib/op-attrs/include/op-attrs/is_valid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H

#include "op-attrs/parallel_tensor_shape.h"
#include "op-attrs/pcg_operator_attrs.dtg.h"

namespace FlexFlow {

template <typename T>
bool is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
auto num_outputs = get_num_outputs(t);
if (num_outputs.has_value() && shapes.size() != num_outputs.value()) {
return false;
}

for (ParallelTensorShape const &shape : shapes) {
if (!is_valid(shape)) {
return false;
}
}

return is_valid_internal(t, shapes);
}

bool is_valid_internal(MultiHeadAttentionAttrs const &,
std::vector<ParallelTensorShape> const &);
bool is_valid_internal(BatchMatmulAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ConcatAttrs const &,
std::vector<ParallelTensorShape> const &);
bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ElementBinaryAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(GatherAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &);

} // namespace FlexFlow

#endif
Loading

0 comments on commit 2b4106f

Please sign in to comment.