Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for utils/graph library #1224

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
51332bd
fix the stack_vector
lambda7xx Oct 29, 2023
4efe602
fix the stack_string
lambda7xx Oct 29, 2023
43d7a6e
fix the build error for utils test
lambda7xx Oct 29, 2023
dde8733
start to implement the label
lambda7xx Nov 4, 2023
7a8e18c
Merge branch 'test-substitution' into lambda-utils-implement
lambda7xx Nov 5, 2023
c6c24a0
add test_openmultidigraph
lambda7xx Nov 6, 2023
0a50c47
start to implement the LabelledOpenMultiDiGraph
lambda7xx Nov 6, 2023
417552a
implement the first version to test LabelledOpenMultiDiGraph
lambda7xx Nov 6, 2023
06df2a4
add todo
lambda7xx Nov 6, 2023
950e11d
add test for class NodeLabelledOpenMultiDiGraph
lambda7xx Nov 7, 2023
143116a
add test for class NodeLabelledOpenMultiDiGraph
lambda7xx Nov 7, 2023
9f5ba9e
add test for node_labelled.h
lambda7xx Nov 7, 2023
b449853
add test for the OutputLabelledOpenMultiDiGraph
lambda7xx Nov 7, 2023
d7b1f91
add test for OutputLabelledMultiDiGraph
lambda7xx Nov 8, 2023
dd36780
add test for the LabelledMultiDiGraph
lambda7xx Nov 8, 2023
adc4b91
add test and implement the OpenMultiDiGraph
lambda7xx Nov 8, 2023
609bf2c
fix the open_graph.cc
lambda7xx Nov 9, 2023
089e07e
fix the ViewMultiDiGraphAsOpenMultiDiGraph::clone
lambda7xx Nov 9, 2023
a534fa2
fix the LabelledOpenMultiDiGraph<N, E, I, O>::operator LabelledOpenMu…
lambda7xx Nov 9, 2023
570f42f
fix the LabelledOpenMultiDiGraph<N, E, I, O>::operator LabelledOpenMu…
lambda7xx Nov 9, 2023
f91a04c
fix one bug of lib/utils/test/src/test_node_labelled_open.cc
lambda7xx Nov 9, 2023
a186681
fix some test
lambda7xx Nov 9, 2023
265bd03
fix the test of lib/utils/test/src/test_node_labelled_open.cc
lambda7xx Nov 10, 2023
e925ad0
leave const bug
lambda7xx Nov 10, 2023
2fdba1e
fix the bug of AdjacencyOpenMultiDiGraph
lambda7xx Nov 10, 2023
5dcef1c
build the lib/utils/test/src/test_node_labelled_open.cc
lambda7xx Nov 10, 2023
36739da
fix the bug of AdjacencyMultiDiGraph::add_edge
lambda7xx Nov 11, 2023
b5d3be8
has some bug on test_algorithm for get_incoming_edges(g, {n[1], n[3]}
lambda7xx Nov 11, 2023
2a78e11
debug the code
lambda7xx Nov 13, 2023
b17278a
merge the latest test-substitution
lambda7xx Nov 13, 2023
3712b21
fix the cow_ptr_t
lambda7xx Nov 13, 2023
0e26c55
fix the bug of IOpenMultiDiGraphView::query_edges
lambda7xx Nov 13, 2023
99e0087
fix the test_algorithms
lambda7xx Nov 13, 2023
e7c7241
weired bug multiple definition of doctest::toString(doctest::Approx …
lambda7xx Nov 13, 2023
2c58398
fix the weird bug by comment #define DOCTEST_CONFIG_IMPLEMENT_WITH_M…
lambda7xx Nov 14, 2023
d5c0cfa
add fmt for std::unorded_map
lambda7xx Nov 14, 2023
add01bf
refine the algorithms
lambda7xx Nov 14, 2023
0cb9249
refine the algorithm
lambda7xx Nov 14, 2023
f997e68
leave get_bfs_ordering
lambda7xx Nov 14, 2023
b4ab5af
the first version for lib/utils/test/src/test_node_labelled_open.cc
lambda7xx Nov 14, 2023
cff6b07
add more test for lib/utils/test/src/test_node_labelled_open.cc
lambda7xx Nov 14, 2023
b7f710c
fix the lib/utils/test/src/test_node_labelled_open.cc
lambda7xx Nov 18, 2023
8474f9b
fix the NodeLabelledMultiDiGraph in node_labelled.h
lambda7xx Nov 18, 2023
308826a
fix the constructor for NodeLabelledMultiDiGraph
lambda7xx Nov 18, 2023
b70ae72
pass the test_node_labelled
lambda7xx Nov 18, 2023
f9e8299
has some problem for OpenMultiDiGraph::create
lambda7xx Nov 18, 2023
ccb3245
fix some bug
lambda7xx Nov 18, 2023
0eefc39
pass the test_output_labelled_open.cc
lambda7xx Dec 1, 2023
4471ae1
format the code and there has some bug about the class OutputLabell…
lambda7xx Dec 1, 2023
4f35b7f
leave the lib/utils/test/src/test_openmultidigraph.cc to implment
lambda7xx Dec 1, 2023
cf4948c
LabelledMultiDiGraph has bug
lambda7xx Dec 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
add_subdirectory(pcg)
add_subdirectory(compiler)
add_subdirectory(runtime)
add_subdirectory(op-attrs)
add_subdirectory(kernels)
#add_subdirectory(pcg)
#add_subdirectory(compiler)
#add_subdirectory(runtime)
#add_subdirectory(op-attrs)
#add_subdirectory(kernels)
add_subdirectory(utils)
add_subdirectory(ffi)
add_subdirectory(substitutions)
#add_subdirectory(ffi)
#add_subdirectory(substitutions)
11 changes: 11 additions & 0 deletions lib/utils/include/utils/containers.h
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,17 @@ std::vector<Out> repeat(int n, F const &f) {
return result;
}

template <typename F, typename Out>
std::vector<Out> repeat2(int n, F const &f, Out type_holder = nullptr) {
assert(n >= 0);

std::vector<Out> result;
for (int i = 0; i < n; i++) {
result.push_back(f(i));
}
return result;
}

template <typename T>
bidict<size_t, T> enumerate(std::unordered_set<T> const &c) {
bidict<size_t, T> m;
Expand Down
21 changes: 21 additions & 0 deletions lib/utils/include/utils/fmt.decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H

#include "fmt/format.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace FlexFlow {
Expand Down Expand Up @@ -36,6 +38,25 @@ struct formatter<::std::vector<T>> : formatter<::std::string> {
-> decltype(ctx.out());
};

template <typename Key,
typename T,
typename Hash,
typename KeyEqual,
typename Allocator>
struct formatter<::std::unordered_map<Key, T, Hash, KeyEqual, Allocator>>
: formatter<::std::string> {
template <typename FormatContext>
auto format(::std::unordered_map<Key, T, Hash, KeyEqual, Allocator> const &m,
FormatContext &ctx) -> decltype(ctx.out());
};

template <typename T, typename U>
struct formatter<::std::pair<T, U>> : formatter<std::string> {
template <typename FormatContext>
auto format(std::pair<T, U> const &p, FormatContext &ctx)
-> decltype(ctx.out());
};

} // namespace fmt

#endif
31 changes: 31 additions & 0 deletions lib/utils/include/utils/fmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,37 @@ auto formatter<::std::vector<T>>::format(::std::vector<T> const &m,
return formatter<std::string>::format(result, ctx);
}

template <typename Key,
typename T,
typename Hash,
typename KeyEqual,
typename Allocator>
template <typename FormatContext>
auto formatter<::std::unordered_map<Key, T, Hash, KeyEqual, Allocator>>::format(
::std::unordered_map<Key, T, Hash, KeyEqual, Allocator> const &m,
FormatContext &ctx) -> decltype(ctx.out()) {
std::string result = "1";
join_strings(
m.begin(),
m.end(),
", ",
[](const typename std::unordered_map<Key, T, Hash, KeyEqual, Allocator>::
value_type &entry) {
// Format each entry as "key: value"
return fmt::to_string(entry.first);
});

return formatter<std::string>::format(result, ctx);
}

template <typename T, typename U>
template <typename FormatContext>
auto formatter<::std::pair<T, U>>::format(std::pair<T, U> const &p,
FormatContext &ctx)
-> decltype(ctx.out()) {
return formatter<std::string>::format(fmt::to_string(p.first), ctx);
}

// CHECK_FMTABLE(std::vector<int>);
// CHECK_FMTABLE(std::unordered_set<int>);

Expand Down
6 changes: 3 additions & 3 deletions lib/utils/include/utils/graph/adjacency_openmultidigraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class AdjacencyOpenMultiDiGraph : virtual public IOpenMultiDiGraph {
AdjacencyOpenMultiDiGraph() = default;
std::unordered_set<Node> query_nodes(NodeQuery const &) const override;

// std::unordered_set<MultiDiEdge> query_edges(MultiDiEdgeQuery const &) const
// override;
std::unordered_set<MultiDiEdge>
query_edges(MultiDiEdgeQuery const &) const override;

std::unordered_set<OpenMultiDiEdge>
query_edges(OpenMultiDiEdgeQuery const &) const override;
Expand All @@ -63,7 +63,7 @@ class AdjacencyOpenMultiDiGraph : virtual public IOpenMultiDiGraph {
AdjacencyOutputEdges outputs;
};

CHECK_NOT_ABSTRACT(AdjacencyOpenMultiDiGraph);
// CHECK_NOT_ABSTRACT(AdjacencyOpenMultiDiGraph);

} // namespace FlexFlow

Expand Down
6 changes: 6 additions & 0 deletions lib/utils/include/utils/graph/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ std::vector<Node> add_nodes(Graph &, int);
std::vector<Node> add_nodes(UndirectedGraph &, int);
std::vector<Node> add_nodes(DiGraph &, int);
std::vector<Node> add_nodes(MultiDiGraph &, int);
std::vector<Node> add_nodes(MultiDiGraphView &, int);

std::vector<NodePort> add_node_ports(MultiDiGraph &, int);

Expand Down Expand Up @@ -108,6 +109,9 @@ std::unordered_set<MultiDiInput> get_inputs(MultiDiGraphView const &);

std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraphView const &,
Node const &);

std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraph const &,
Node const &);
std::unordered_set<DirectedEdge> get_incoming_edges(DiGraphView const &,
Node const &);
std::unordered_set<UpwardOpenMultiDiEdge>
Expand All @@ -119,6 +123,8 @@ std::unordered_set<UpwardOpenMultiDiEdge>

std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraphView const &,
std::unordered_set<Node>);
std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraph const &,
std::unordered_set<Node>);
std::unordered_set<DirectedEdge>
get_incoming_edges(DiGraphView const &, std::unordered_set<Node> const &);

Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/digraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct DiGraphView : virtual public GraphView {
private:
IDiGraphView &get_ptr() const;

friend struct GraphInternal;
// friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView);

Expand Down Expand Up @@ -70,7 +70,7 @@ struct DiGraph : virtual DiGraphView {
private:
IDiGraph &get_ptr() const;

friend struct GraphInternal;
// friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph);

Expand Down
9 changes: 4 additions & 5 deletions lib/utils/include/utils/graph/labelled/labelled_open.decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ struct LabelledOpenMultiDiGraph {
void add_edge(InputMultiDiEdge const &e);
void add_edge(OutputMultiDiEdge const &e);

void add_label(MultiDiEdge const &e, EdgeLabel const &l);
void add_label(InputMultiDiEdge const &e, EdgeLabel const &l);
void add_label(OutputMultiDiEdge const &e, EdgeLabel const &l);

void add_edge(MultiDiEdge const &e, EdgeLabel const &l);
// void add_edge(InputMultiDiEdge const &e, EdgeLabel const &l);
// void add_edge(OutputMultiDiEdge const &e, EdgeLabel const &l);

EdgeLabel &at(MultiDiEdge const &e);
EdgeLabel const &at(MultiDiEdge const &e) const;

Expand All @@ -111,7 +110,7 @@ struct LabelledOpenMultiDiGraph {
create();

private:
LabelledOpenMultiDiGraph(cow_ptr_t<Interface> ptr);
LabelledOpenMultiDiGraph(cow_ptr_t<Interface> ptr) : ptr(ptr) {}

private:
cow_ptr_t<Interface> ptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ struct ILabelledOpenMultiDiGraphView
: public IOpenMultiDiGraphView,
public ILabelledMultiDiGraphView<NodeLabel, EdgeLabel> {
public:
std::unordered_set<MultiDiEdge>
query_edges(MultiDiEdgeQuery const &q) const final {
return map_over_unordered_set(
[](OpenMultiDiEdge const &e) { return get<MultiDiEdge>(e); },
IOpenMultiDiGraphView::query_edges(
static_cast<OpenMultiDiEdgeQuery>(q)));
using IOpenMultiDiGraphView::query_edges; // Add this line

std::unordered_set<MultiDiEdge> query_edges(MultiDiEdgeQuery const &q) const {
// return IOpenMultiDiGraphView::query_edges(q);
return IOpenMultiDiGraphView::query_edges(q);
}

using ILabelledMultiDiGraphView<NodeLabel, EdgeLabel>::at;
Expand Down
16 changes: 10 additions & 6 deletions lib/utils/include/utils/graph/labelled/node_labelled.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,20 @@ struct NodeLabelledMultiDiGraph
}

NodeLabel &at(Node const &n) {
return nl.get_mutable()->get_label(n);
return get_nodelabel_ptr().get_label(n);
}

std::unordered_set<Node> query_nodes(NodeQuery const &q) const {
return get_ptr().query_nodes();
return get_ptr().query_nodes(q);
}

std::unordered_set<MultiDiEdge> query_edges(MultiDiEdge const &q) const {
return get_ptr().query_edges();
std::unordered_set<MultiDiEdge> query_edges(MultiDiEdgeQuery const &q) const {
return get_ptr().query_edges(q);
}

Node add_node(NodeLabel const &l) {
Node n = get_ptr().add_node();
nl->add_label(n, l);
get_nodelabel_ptr().add_label(n, l);
return n;
}

Expand All @@ -114,13 +114,17 @@ struct NodeLabelledMultiDiGraph

protected:
NodeLabelledMultiDiGraph(cow_ptr_t<Interface> ptr, cow_ptr_t<NodeLabelIf> nl)
: NodeLabelledMultiDiGraphView<NodeLabel>(ptr), nl(nl) {} //todo: this may have some problem, because it seems we don't have constructor method NodeLabelledMultiDiGraphView<NodeLabel>(ptr
: GraphView(ptr), nl(nl) {}

Interface &get_ptr() const {
return *std::reinterpret_pointer_cast<Interface>(
GraphView::ptr.get_mutable());
}

NodeLabelIf &get_nodelabel_ptr() const {
return *std::reinterpret_pointer_cast<NodeLabelIf>(nl.get_mutable());
}

cow_ptr_t<NodeLabelIf> nl;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraph<int>);
Expand Down
11 changes: 8 additions & 3 deletions lib/utils/include/utils/graph/labelled/node_labelled_open.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_NODE_LABELLED_OPEN
#define _FLEXFLOW_UTILS_GRAPH_LABELLED_NODE_LABELLED_OPEN

#include "utils/graph/labelled/node_labelled.h"
#include "utils/graph/open_graphs.h"

namespace FlexFlow {
Expand Down Expand Up @@ -77,21 +78,21 @@ struct NodeLabelledOpenMultiDiGraph
}

NodeLabel &at(Node const &n) {
return nl->get_label(n);
return get_nodelabel_ptr().get_label(n);
}

std::unordered_set<Node> query_nodes(NodeQuery const &q) const {
return get_ptr().query_nodes(q);
}

std::unordered_set<OpenMultiDiEdge>
query_edges(OpenMultiDiEdge const &q) const {
query_edges(OpenMultiDiEdgeQuery const &q) const {
return get_ptr().query_edges(q);
}

Node add_node(NodeLabel const &l) {
Node n = get_ptr().add_node();
nl.get_mutable()->add_label(n, l);
get_nodelabel_ptr().add_label(n, l);
return n;
}

Expand Down Expand Up @@ -123,6 +124,10 @@ struct NodeLabelledOpenMultiDiGraph
GraphView::ptr.get_mutable());
}

INodeLabel &get_nodelabel_ptr() const {
return *std::reinterpret_pointer_cast<INodeLabel>(nl.get_mutable());
}

cow_ptr_t<INodeLabel> nl;
};

Expand Down
12 changes: 7 additions & 5 deletions lib/utils/include/utils/graph/labelled/output_labelled.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct OutputLabelledMultiDiGraph

Node add_node(NodeLabel const &l) {
Node n = get_ptr().add_node();
nl->add_label(n, l);
nl.get_mutable()->add_label(n, l);
return n;
}

Expand All @@ -96,8 +96,8 @@ struct OutputLabelledMultiDiGraph
return nl->get_label(n);
}

void add_output(MultiDiOutput const &o, OutputLabel const &l) {
ol->add_label(o, l);
void add_edge(MultiDiOutput const &o, OutputLabel const &l) {
ol.get_mutable()->add_label(o, l);
};

void add_edge(MultiDiOutput const &o, MultiDiInput const &i) {
Expand All @@ -109,7 +109,7 @@ struct OutputLabelledMultiDiGraph
}

OutputLabel &at(MultiDiOutput const &o) {
return ol->get_label(o);
return ol.get_mutable()->get_label(o);
}

OutputLabel const &at(MultiDiOutput const &o) const {
Expand Down Expand Up @@ -139,7 +139,9 @@ struct OutputLabelledMultiDiGraph
cow_ptr_t<INodeLabel> nl,
cow_ptr_t<IOutputLabel> ol)
: OutputLabelledMultiDiGraphView<NodeLabel, OutputLabel>(ptr), nl(nl),
ol(ol) {}
ol(ol) {
} // this exists some problem, interface is IMultiDiGraph, but
// OutputLabelledMultiDiGraphView needs IOutputLabelledMultiDiGraphView

private:
Interface &get_ptr() const {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_GRAPH_INTERFACES_H
#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_GRAPH_INTERFACES_H

#include "node_labelled_interfaces.h"
#include "node_labelled.h"

namespace FlexFlow {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "node_labelled.h"
#include "utils/graph/adjacency_openmultidigraph.h"
#include "utils/graph/labelled/node_labelled_open.h"

namespace FlexFlow {

Expand Down
Loading
Loading