diff --git a/README.md b/README.md index 07a583f..c82ab92 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ RETURN A.club, B.club """) ``` +See [examples.md](examples.md) for more! + ### Example Usage with SQL Create your own "Sqlite for Neo4j"! This example uses [grand-graph](https://github.com/aplbrain/grand) to run queries in SQL: @@ -81,6 +83,7 @@ RETURN | Graph mutations (e.g. `DELETE`, `SET`,...) | 🛣 | | | `DISTINCT` | ✅ Thanks @jackboyla! | | | `ORDER BY` | ✅ Thanks @jackboyla! | | +| Aggregation functions (`COUNT`, `SUM`, `MIN`, `MAX`, `AVG`) | ✅ Thanks @jackboyla! | | | | | | | -------------- | -------------- | ---------------- | diff --git a/examples.md b/examples.md new file mode 100644 index 0000000..0122329 --- /dev/null +++ b/examples.md @@ -0,0 +1,66 @@ + +## Multigraph + +```python +from grandcypher import GrandCypher +import networkx as nx + +host = nx.MultiDiGraph() +host.add_node("a", name="Alice", age=25) +host.add_node("b", name="Bob", age=30) +host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") +host.add_edge("b", "a", __labels__={"paid"}, amount=6) +host.add_edge("b", "a", __labels__={"paid"}, value=14) +host.add_edge("a", "b", __labels__={"friends"}, years=9) +host.add_edge("a", "b", __labels__={"paid"}, amount=40) + +qry = """ +MATCH (n)-[r:paid]->(m) +RETURN n.name, m.name, r.amount +""" +res = GrandCypher(host).run(qry) +print(res) + +''' +{ + 'n.name': ['Alice', 'Bob'], + 'm.name': ['Bob', 'Alice'], + 'r.amount': [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}] +} +''' +``` + +## Aggregation Functions + +```python +from grandcypher import GrandCypher +import networkx as nx + +host = nx.MultiDiGraph() +host.add_node("a", name="Alice", age=25) +host.add_node("b", name="Bob", age=30) +host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") +host.add_edge("b", "a", __labels__={"paid"}, amount=6) +host.add_edge("b", "a", __labels__={"paid"}, value=14) +host.add_edge("a", "b", __labels__={"friends"}, years=9) +host.add_edge("a", "b", __labels__={"paid"}, amount=40) + +qry = """ +MATCH (n)-[r:paid]->(m) +RETURN n.name, m.name, SUM(r.amount) +""" +res = GrandCypher(host).run(qry) +print(res) + +''' +{ + 'n.name': ['Alice', 'Bob'], + 'm.name': ['Bob', 'Alice'], + 'SUM(r.amount)': [{'paid': 52, 'friends': 0}, {'paid': 6}] +} +''' +``` + + + + diff --git a/grandcypher/__init__.py b/grandcypher/__init__.py index abd2824..e582af2 100644 --- a/grandcypher/__init__.py +++ b/grandcypher/__init__.py @@ -16,7 +16,7 @@ import grandiso -from lark import Lark, Transformer, v_args, Token +from lark import Lark, Transformer, v_args, Token, Tree _OPERATORS = { @@ -81,7 +81,13 @@ -return_clause : "return"i distinct_return? entity_id ("," entity_id)* + +return_clause : "return"i distinct_return? return_item ("," return_item)* +return_item : entity_id | aggregation_function | entity_id "." attribute_id + +aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")" +AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN" +attribute_id : CNAME distinct_return : "DISTINCT"i limit_clause : "limit"i NUMBER @@ -107,8 +113,8 @@ edge_match : LEFT_ANGLE? "--" RIGHT_ANGLE? | LEFT_ANGLE? "-[]-" RIGHT_ANGLE? | LEFT_ANGLE? "-[" CNAME "]-" RIGHT_ANGLE? - | LEFT_ANGLE? "-[" CNAME ":" TYPE "]-" RIGHT_ANGLE? - | LEFT_ANGLE? "-[" ":" TYPE "]-" RIGHT_ANGLE? + | LEFT_ANGLE? "-[" CNAME ":" type_list "]-" RIGHT_ANGLE? + | LEFT_ANGLE? "-[" ":" type_list "]-" RIGHT_ANGLE? | LEFT_ANGLE? "-[" "*" MIN_HOP "]-" RIGHT_ANGLE? | LEFT_ANGLE? "-[" "*" MIN_HOP ".." MAX_HOP "]-" RIGHT_ANGLE? | LEFT_ANGLE? "-[" CNAME "*" MIN_HOP "]-" RIGHT_ANGLE? @@ -118,6 +124,7 @@ | LEFT_ANGLE? "-[" CNAME ":" TYPE "*" MIN_HOP "]-" RIGHT_ANGLE? | LEFT_ANGLE? "-[" CNAME ":" TYPE "*" MIN_HOP ".." MAX_HOP "]-" RIGHT_ANGLE? +type_list : TYPE ( "|" TYPE )* LEFT_ANGLE : "<" RIGHT_ANGLE : ">" @@ -228,10 +235,14 @@ def _is_edge_attr_match( motif_edges = _aggregate_edge_labels(motif_edges) host_edges = _aggregate_edge_labels(host_edges) + motif_types = motif_edges.get('__labels__', set()) + host_types = host_edges.get('__labels__', set()) + + if motif_types and not motif_types.intersection(host_types): + return False + for attr, val in motif_edges.items(): if attr == "__labels__": - if val and val - host_edges.get("__labels__", set()): - return False continue if host_edges.get(attr) != val: return False @@ -277,6 +288,7 @@ def _get_entity_from_host( edge_data = host.get_edge_data(*entity_name) if not edge_data: return None # print(f"Nothing found for {entity_name} {entity_attribute}") + if entity_attribute: # looking for edge attribute: if isinstance(host, nx.MultiDiGraph): @@ -371,6 +383,7 @@ def __init__(self, target_graph: nx.Graph, limit=None): self._matche_paths = None self._return_requests = [] self._return_edges = {} + self._aggregate_functions = [] self._distinct = False self._order_by = None self._order_by_attributes = set() @@ -478,9 +491,10 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: for r in ret: r_attr = {} for i, v in r.items(): - r_attr[i] = v.get(entity_attribute, None) + r_attr[(i, list(v.get('__labels__'))[0])] = v.get(entity_attribute, None) + # eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}] ret_with_attr.append(r_attr) - + ret = ret_with_attr result[data_path] = list(ret)[offset_limit] @@ -492,9 +506,19 @@ def return_clause(self, clause): # collect all entity identifiers to be returned for item in clause: if item: - if not isinstance(item, str): - item = str(item.value) - self._return_requests.append(item) + item = item.children[0] if isinstance(item, Tree) else item + if isinstance(item, Tree) and item.data == "aggregation_function": + func = str(item.children[0].value) # AGGREGATE_FUNC + entity = str(item.children[1].value) + if len(item.children) > 2: + entity += "." + str(item.children[2].children[0].value) + self._aggregate_functions.append((func, entity)) + self._return_requests.append(entity) + else: + if not isinstance(item, str): + item = str(item.value) + self._return_requests.append(item) + def order_clause(self, order_clause): self._order_by = [] @@ -520,12 +544,73 @@ def skip_clause(self, skip): skip = int(skip[-1]) self._skip = skip + + def aggregate(self, func, results, entity, group_keys): + # Collect data based on group keys + grouped_data = {} + for i in range(len(results[entity])): + group_tuple = tuple(results[key][i] for key in group_keys if key in results) + if group_tuple not in grouped_data: + grouped_data[group_tuple] = [] + grouped_data[group_tuple].append(results[entity][i]) + + def _collate_data(data, unique_labels, func): + # for ["COUNT", "SUM", "AVG"], we treat None as 0 + if func in ["COUNT", "SUM", "AVG"]: + collated_data = { + label: [(v or 0) for rel in data for k, v in rel.items() if k[1] == label] for label in unique_labels + } + # for ["MAX", "MIN"], we treat None as non-existent + elif func in ["MAX", "MIN"]: + collated_data = { + label: [v for rel in data for k, v in rel.items() if (k[1] == label and v is not None)] for label in unique_labels + } + + return collated_data + + # Apply aggregation function + aggregate_results = {} + for group, data in grouped_data.items(): + # data => [{(0, 'paid'): 70, (1, 'paid'): 90}] + unique_labels = set([k[1] for rel in data for k in rel.keys()]) + collated_data = _collate_data(data, unique_labels, func) + if func == "COUNT": + count_data = {label: len(data) for label, data in collated_data.items()} + aggregate_results[group] = count_data + elif func == "SUM": + sum_data = {label: sum(data) for label, data in collated_data.items()} + aggregate_results[group] = sum_data + elif func == "AVG": + sum_data = {label: sum(data) for label, data in collated_data.items()} + count_data = {label: len(data) for label, data in collated_data.items()} + avg_data = {label: sum_data[label] / count_data[label] if count_data[label] > 0 else 0 for label in sum_data} + aggregate_results[group] = avg_data + elif func == "MAX": + max_data = {label: max(data) for label, data in collated_data.items()} + aggregate_results[group] = max_data + elif func == "MIN": + min_data = {label: min(data) for label, data in collated_data.items()} + aggregate_results[group] = min_data + + aggregate_results = [v for v in aggregate_results.values()] + return aggregate_results + def returns(self, ignore_limit=False): results = self._lookup( self._return_requests + list(self._order_by_attributes), offset_limit=slice(0, None), ) + if len(self._aggregate_functions) > 0: + group_keys = [key for key in results.keys() if not any(key.endswith(func[1]) for func in self._aggregate_functions)] + + aggregated_results = {} + for func, entity in self._aggregate_functions: + aggregated_data = self.aggregate(func, results, entity, group_keys) + func_key = f"{func}({entity})" + aggregated_results[func_key] = aggregated_data + self._return_requests.append(func_key) + results.update(aggregated_results) if self._order_by: results = self._apply_order_by(results) if self._distinct: @@ -775,10 +860,21 @@ def entity_id(self, entity_id): return ".".join(entity_id) return entity_id.value - def edge_match(self, edge_name): - direction = cname = min_hop = max_hop = edge_type = None + def edge_match(self, edge_tokens): + def flatten_tokens(edge_tokens): + flat_tokens = [] + for token in edge_tokens: + if isinstance(token, Tree): + flat_tokens.extend(flatten_tokens(token.children)) # Recursively flatten the tree + else: + flat_tokens.append(token) + return flat_tokens + + direction = cname = min_hop = max_hop = None + edge_types = [] + edge_tokens = flatten_tokens(edge_tokens) - for token in edge_name: + for token in edge_tokens: if token.type == "MIN_HOP": min_hop = int(token.value) elif token.type == "MAX_HOP": @@ -790,15 +886,19 @@ def edge_match(self, edge_name): elif token.type == "RIGHT_ANGLE": direction = "r" elif token.type == "TYPE": - edge_type = token.value + edge_types.append(token.value) else: cname = token direction = direction if direction is not None else "b" if (min_hop is not None or max_hop is not None) and (direction == "b"): - raise TypeError("not support edge hopping for bidirectional edge") + raise TypeError("Bidirectional edge does not support edge hopping") + + # Handle the case where no edge types are specified, defaulting to a generic type if needed + if edge_types == []: + edge_types = None - return (cname, edge_type, direction, min_hop, max_hop) + return (cname, edge_types, direction, min_hop, max_hop) def node_match(self, node_name): cname = node_type = json_data = None @@ -845,7 +945,7 @@ def match_clause(self, match_clause: Tuple): if maxh > self._max_hop: raise ValueError(f"max hop is caped at 100, found {maxh}!") if t: - t = set([t]) + t = set([t] if type(t) is str else t) self._motif.add_edges_from( edges, __min_hop__=minh, __max_hop__=maxh, __is_hop__=ish, __labels__=t ) diff --git a/grandcypher/test_queries.py b/grandcypher/test_queries.py index 450b28b..afea528 100644 --- a/grandcypher/test_queries.py +++ b/grandcypher/test_queries.py @@ -909,8 +909,8 @@ def test_multiple_edges_specific_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice"] assert res["b.name"] == ["Bob"] - assert res["r.years"] == [{0: 3, 1: 5, 2: None}] # should return None when attr is missing - + assert res["r.years"] == [{(0, 'colleague'): 3, (1, 'friend'): 5, (2, 'enemy'): None}] # should return None when attr is missing + def test_edge_directionality(self): host = nx.MultiDiGraph() host.add_node("a", name="Alice", age=25) @@ -926,9 +926,8 @@ def test_edge_directionality(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Bob", "Alice"] - assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}] - assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}] - + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}] + assert res["r.years"] == [{(0, 'friend'): 1}, {(0, 'colleague'): 2, (1, 'mentor'): 4}] def test_query_with_missing_edge_attribute(self): host = nx.MultiDiGraph() @@ -947,7 +946,7 @@ def test_query_with_missing_edge_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Charlie", "Charlie"] - assert res["r.duration"] == [{0: None}, {0: 10, 1: None}] # should return None when attr is missing + assert res["r.duration"] == [{(0, 'colleague'): None}, {(0, 'colleague'): 10, (1, 'mentor'): None}] qry = """ MATCH (a)-[r:colleague]->(b) @@ -956,7 +955,7 @@ def test_query_with_missing_edge_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Charlie", "Charlie"] - assert res["r.years"] == [{0: 10}, {0: None, 1: 2}] + assert res["r.years"] == [{(0, 'colleague'): 10}, {(0, 'colleague'): None, (1, 'mentor'): 2}] qry = """ MATCH (a)-[r]->(b) @@ -965,8 +964,8 @@ def test_query_with_missing_edge_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ['Alice', 'Alice', 'Bob'] assert res["b.name"] == ['Bob', 'Charlie', 'Charlie'] - assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}}, {0: {'colleague'}, 1: {'mentor'}}] - assert res["r.duration"] == [{0: None}, {0: None}, {0: 10, 1: None}] + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}] + assert res["r.duration"] == [{(0, 'friend'): None}, {(0, 'colleague'): None}, {(0, 'colleague'): 10, (1, 'mentor'): None}] def test_multigraph_single_edge_where(self): host = nx.MultiDiGraph() @@ -986,9 +985,9 @@ def test_multigraph_single_edge_where(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice", "Bob"] assert res["b.name"] == ["Bob", "Alice"] - assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}] - assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}] - assert res["r.friendly"] == [{0: 'very'}, {0: None, 1: None}] + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}, {(0, 'colleague'): {'colleague'}, (1, 'mentor'): {'mentor'}}] + assert res["r.years"] == [{(0, 'friend'): 1}, {(0, 'colleague'): 2, (1, 'mentor'): 4}] + assert res["r.friendly"] == [{(0, 'friend'): 'very'}, {(0, 'colleague'): None, (1, 'mentor'): None}] def test_multigraph_where_node_attribute(self): host = nx.MultiDiGraph() @@ -1008,9 +1007,136 @@ def test_multigraph_where_node_attribute(self): res = GrandCypher(host).run(qry) assert res["a.name"] == ["Alice"] assert res["b.name"] == ["Bob"] - assert res["r.__labels__"] == [{0: {'friend'}}] - assert res["r.years"] == [{0: 1}] - assert res["r.friendly"] == [{0: 'very'}] + assert res["r.__labels__"] == [{(0, 'friend'): {'friend'}}] + assert res["r.years"] == [{(0, 'friend'): 1}] + assert res["r.friendly"] == [{(0, 'friend'): 'very'}] + + def test_multigraph_multiple_same_edge_labels(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"friends"}, years=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, r.amount + """ + res = GrandCypher(host).run(qry) + assert res["n.name"] == ["Alice", "Bob"] + assert res["m.name"] == ["Bob", "Alice"] + # the second "paid" edge between Bob -> Alice has no "amount" attribute, so it should be None + assert res["r.amount"] == [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}] + + def test_multigraph_aggregation_function_sum(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + host.add_edge("b", "a", __labels__={"paid"}, value=14) + host.add_edge("a", "b", __labels__={"friends"}, years=9) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, SUM(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res['SUM(r.amount)'] == [{'friends': 0, 'paid': 52}, {'paid': 6}] + + def test_multigraph_aggregation_function_avg(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June") + host.add_edge("b", "a", __labels__={"paid"}, amount=6, message="Thanks") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, AVG(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["AVG(r.amount)"] == [{'paid': 26}, {'paid': 6}] + + def test_multigraph_aggregation_function_min(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + host.add_edge("a", "b", __labels__={"paid"}, value=4) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, MIN(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["MIN(r.amount)"] == [{'paid': 12}, {'paid': 6}] + + def test_multigraph_aggregation_function_max(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, MAX(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["MAX(r.amount)"] == [{'paid': 40}, {'paid': 6}] + + qry = """ + MATCH (n)-[r:owes]->(m) + RETURN n.name, m.name, MAX(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["MAX(r.amount)"] == [{'owes': 39}] + + def test_multigraph_aggregation_function_count(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, COUNT(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["COUNT(r.amount)"] == [{'paid': 2}, {'paid': 1}] + + def test_multigraph_multiple_aggregation_functions(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice", age=25) + host.add_node("b", name="Bob", age=30) + host.add_node("c", name="Christine") + host.add_edge("a", "b", __labels__={"paid"}, amount=40) + host.add_edge("a", "b", __labels__={"paid"}, amount=12) + host.add_edge("a", "c", __labels__={"owes"}, amount=39) + host.add_edge("b", "a", __labels__={"paid"}, amount=6) + + qry = """ + MATCH (n)-[r:paid]->(m) + RETURN n.name, m.name, COUNT(r.amount), SUM(r.amount) + """ + res = GrandCypher(host).run(qry) + assert res["COUNT(r.amount)"] == [{'paid': 2}, {'paid': 1}] + assert res["SUM(r.amount)"] == [{'paid': 52}, {'paid': 6}] class TestVariableLengthRelationship: @@ -1661,3 +1787,94 @@ def test_path(self, graph_type): res = GrandCypher(host).run(qry) assert len(res["P"][0]) == 5 + + +class TestMatchWithOrOperatorInRelationships: + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) + def test_match_with_single_or_operator(self, graph_type): + host = graph_type() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_node("c", name="Carol") + host.add_edge("a", "b", __labels__={"LOVES"}) + host.add_edge("b", "c", __labels__={"WORKS_WITH"}) + + qry = """ + MATCH (n1)-[r:LOVES|WORKS_WITH]->(n2) + RETURN n1.name, n2.name + """ + res = GrandCypher(host).run(qry) + assert res["n1.name"] == ["Alice", "Bob"] + assert res["n2.name"] == ["Bob", "Carol"] + + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) + def test_match_with_multiple_or_operators(self, graph_type): + host = graph_type() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_node("c", name="Carol") + host.add_node("d", name="Derek") + host.add_edge("a", "b", __labels__={"LOVES"}) + host.add_edge("a", "c", __labels__={"KNOWS"}) + host.add_edge("b", "c", __labels__={"LIVES_NEAR"}) + host.add_edge("b", "d", __labels__={"WORKS_WITH"}) + + qry = """ + MATCH (n1)-[r:LOVES|KNOWS|LIVES_NEAR]->(n2) + RETURN n1.name, n2.name + """ + res = GrandCypher(host).run(qry) + assert res["n1.name"] == ["Alice", "Alice", "Bob"] + assert res["n2.name"] == ["Bob", "Carol", "Carol"] + + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) + def test_match_with_or_operator_and_other_conditions(self, graph_type): + host = graph_type() + host.add_node("a", name="Alice", age=30) + host.add_node("b", name="Bob", age=25) + host.add_node("c", name="Carol", age=40) + host.add_edge("a", "b", __labels__={"LOVES"}) + host.add_edge("a", "c", __labels__={"KNOWS"}) + host.add_edge("b", "c", __labels__={"WORKS_WITH"}) + + qry = """ + MATCH (n1)-[r:LOVES|KNOWS]->(n2) + WHERE n1.age > 28 AND n2.age > 35 + RETURN n1.name, n2.name + """ + res = GrandCypher(host).run(qry) + assert res["n1.name"] == ["Alice"] + assert res["n2.name"] == ["Carol"] + + @pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES) + def test_no_results_when_no_matching_edges(self, graph_type): + host = graph_type() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_edge("a", "b", __labels__={"WORKS_WITH"}) + + qry = """ + MATCH (n1)-[r:IN_CITY|HAS_ROUTE]->(n2) + RETURN n1.name, n2.name + """ + res = GrandCypher(host).run(qry) + assert len(res["n1.name"]) == 0 # No results because no edges match + + def test_multigraph_match_with_single_or_operator(self): + host = nx.MultiDiGraph() + host.add_node("a", name="Alice") + host.add_node("b", name="Bob") + host.add_node("c", name="Carol") + host.add_node("d", name="Derek") + host.add_edge("a", "b", __labels__={"LOVES"}) + host.add_edge("b", "c", __labels__={"WORKS_WITH"}) + host.add_edge("b", "c", __labels__={"DISLIKES"}) + host.add_edge("b", "d", __labels__={"DISLIKES"}) + + qry = """ + MATCH (n1)-[r:IS_SUING|DISLIKES]->(n2) + RETURN n1.name, n2.name + """ + res = GrandCypher(host).run(qry) + assert res["n1.name"] == ["Bob", "Bob"] + assert res["n2.name"] == ["Carol", "Derek"]