Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for multigraphs #42

Merged
merged 10 commits into from
May 14, 2024
126 changes: 94 additions & 32 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

"""

from typing import Dict, List, Callable, Tuple
from typing import Dict, List, Callable, Tuple, Union
from collections import OrderedDict
import random
import string
Expand Down Expand Up @@ -119,7 +119,6 @@
| LEFT_ANGLE? "-[" CNAME ":" TYPE "*" MIN_HOP ".." MAX_HOP "]-" RIGHT_ANGLE?



LEFT_ANGLE : "<"
RIGHT_ANGLE : ">"
EQUAL : "="
Expand Down Expand Up @@ -198,14 +197,16 @@ def _is_node_attr_match(

@lru_cache()
def _is_edge_attr_match(
motif_edge_id: Tuple[str, str],
host_edge_id: Tuple[str, str],
motif: nx.Graph,
host: nx.Graph,
motif_edge_id: Tuple[str, str, Union[int, str]],
host_edge_id: Tuple[str, str, Union[int, str]],
motif: Union[nx.Graph, nx.MultiDiGraph],
host: Union[nx.Graph, nx.MultiDiGraph]
) -> bool:
"""
Check if an edge in the host graph matches the attributes in the motif.
This also check the __labels__ of edges.
Check if an edge in the host graph matches the attributes in the motif,
including the special '__labels__' set attribute.
This function formats edges into
nx.MultiDiGraph format i.e {0: first_relation, 1: ...}.

Arguments:
motif_edge_id (str): The motif edge ID
Expand All @@ -215,23 +216,50 @@ def _is_edge_attr_match(

Returns:
bool: True if the host edge matches the attributes in the motif

"""
motif_edge = motif.edges[motif_edge_id]
host_edge = host.edges[host_edge_id]
motif_u, motif_v = motif_edge_id
host_u, host_v = host_edge_id

# Format edges for both DiGraph and MultiDiGraph
motif_edges = _get_edge_attributes(motif, motif_u, motif_v)
host_edges = _get_edge_attributes(host, host_u, host_v)

for attr, val in motif_edge.items():
# Aggregate all __labels__ into one set
motif_edges = _aggregate_edge_labels(motif_edges)
host_edges = _aggregate_edge_labels(host_edges)

for attr, val in motif_edges.items():
if attr == "__labels__":
if val and val - host_edge.get("__labels__", set()):
if val and val - host_edges.get("__labels__", set()):
return False
continue
if host_edge.get(attr) != val:
if host_edges.get(attr) != val:
return False

return True


def _get_entity_from_host(host: nx.DiGraph, entity_name, entity_attribute=None):
def _get_edge_attributes(graph: Union[nx.Graph, nx.MultiDiGraph], u, v) -> Dict:
    """
    Retrieve the attributes of the edge(s) from ``u`` to ``v`` in a uniform,
    multigraph-style shape, so callers can treat simple graphs and
    multigraphs the same way.

    Arguments:
        graph: The graph to read the edge(s) from
        u: The ID of the first endpoint
        v: The ID of the second endpoint

    Returns:
        Dict: A mapping of edge key to that edge's attribute dict. For a
            multigraph this is ``graph[u][v]`` as-is; for a simple graph the
            single edge's attributes are wrapped as ``{0: attributes}``.

    """
    # Ask the graph itself rather than testing for one concrete class:
    # is_multigraph() is also true for undirected nx.MultiGraph, whose
    # graph[u][v] is already keyed by edge key and must not be wrapped again.
    if graph.is_multigraph():
        return graph[u][v]
    return {0: graph[u][v]}  # Mock single edge for simple (non-multi) graphs

def _aggregate_edge_labels(edges: Dict) -> Dict:
"""
Aggregate '__labels__' attributes from edges into a single set.
"""
aggregated = {"__labels__": set()}
for edge_id, attrs in edges.items():
if "__labels__" in attrs and attrs["__labels__"]:
aggregated["__labels__"].update(attrs["__labels__"])
elif "__labels__" not in attrs:
aggregated[edge_id] = attrs
return aggregated

def _get_entity_from_host(host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None):
if entity_name in host.nodes():
# We are looking for a node mapping in the target graph:
if entity_attribute:
Expand All @@ -248,7 +276,10 @@ def _get_entity_from_host(host: nx.DiGraph, entity_name, entity_attribute=None):
return None # print(f"Nothing found for {entity_name} {entity_attribute}")
if entity_attribute:
# looking for edge attribute:
return edge_data.get(entity_attribute, None)
if isinstance(host, nx.MultiDiGraph):
return [r.get(entity_attribute, None) for r in edge_data.values()]
else:
return edge_data.get(entity_attribute, None)
else:
return host.get_edge_data(*entity_name)

Expand Down Expand Up @@ -279,7 +310,7 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:


def cond_(should_be, entity_id, operator, value) -> CONDITION:
def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:
def inner(match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list) -> bool:
host_entity_id = entity_id.split(".")
if host_entity_id[0] in match:
host_entity_id[0] = match[host_entity_id[0]]
Expand All @@ -290,7 +321,13 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:
else:
raise IndexError(f"Entity {host_entity_id} not in graph.")
try:
val = operator(_get_entity_from_host(host, *host_entity_id), value)
if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
val = any(operator(r_val, value) for r_val in r_vals)
else:
val = operator(_get_entity_from_host(host, *host_entity_id), value)

except:
val = False
if val != should_be:
Expand Down Expand Up @@ -323,7 +360,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._target_graph = target_graph
self._paths = []
self._where_condition: CONDITION = None
self._motif = nx.DiGraph()
self._motif = nx.MultiDiGraph()
self._matches = None
self._matche_paths = None
self._return_requests = []
Expand Down Expand Up @@ -383,9 +420,9 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
ret.append(path)

else:
mapping_u, mapping_v = self._return_edges[data_path]
mapping_u, mapping_v = self._return_edges[data_path.split('.')[0]]
# We are looking for an edge mapping in the target graph:
is_hop = self._motif.edges[(mapping_u, mapping_v)]["__is_hop__"]
is_hop = self._motif.edges[(mapping_u, mapping_v, 0)]["__is_hop__"]
ret = (
_get_edge(
self._target_graph, mapping, match_path, mapping_u, mapping_v
Expand All @@ -395,13 +432,38 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
ret = (r[0] if is_hop else r for r in ret)
# we keep the original list if len > 2 (edge hop 2+)

# Get all edge labels from the motif -- this is used to filter the relations for multigraphs
motif_edge_labels = set()
for edge in self._motif.get_edge_data(mapping_u, mapping_v).values():
if edge.get('__labels__', None):
motif_edge_labels.update(edge['__labels__'])

if entity_attribute:
# Get the correct entity from the target host graph,
# and then return the attribute:
ret = (r.get(entity_attribute, None) for r in ret)
if isinstance(self._motif, nx.MultiDiGraph) and len(motif_edge_labels) > 0:
# filter the retrieved edge(s) based on the motif edge labels
filtered_ret = []
for r in ret:

if any([i.get('__labels__', None).issubset(motif_edge_labels) for i in r.values()]):
filtered_ret.append(r)

ret = filtered_ret

# get the attribute from the retrieved edge(s)
ret_with_attr = []
for r in ret:
r_attr = {}
for i, v in r.items():
r_attr[i] = v.get(entity_attribute, None)
ret_with_attr.append(r_attr)

ret = ret_with_attr

result[data_path] = list(ret)[offset_limit]


return result

def return_clause(self, clause):
Expand Down Expand Up @@ -606,7 +668,7 @@ def _is_limit(self, count):
# Check if limit reached
return self._limit and count >= (self._limit + self._skip)

def _edge_hop_motifs(self, motif: nx.DiGraph) -> List[Tuple[nx.Graph, dict]]:
def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]]:
"""generate a list of edge-hop-expanded motif with edge-hop-map.

Arguments:
Expand All @@ -618,29 +680,29 @@ def _edge_hop_motifs(self, motif: nx.DiGraph) -> List[Tuple[nx.Graph, dict]]:
where a real edge path can have more than 2 elements (hop >= 2)
or it can have 2 identical elements (hop = 0).
"""
new_motif = nx.DiGraph()
new_motif = nx.MultiDiGraph()
for n in motif.nodes:
if motif.out_degree(n) == 0 and motif.in_degree(n) == 0:
new_motif.add_node(n, **motif.nodes[n])
motifs: List[Tuple[nx.DiGraph, dict]] = [(new_motif, {})]
for u, v in motif.edges:
for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)])
new_motifs = []
min_hop = motif.edges[u, v]["__min_hop__"]
max_hop = motif.edges[u, v]["__max_hop__"]
edge_type = motif.edges[u, v]["__labels__"]
min_hop = motif.edges[u, v, k]["__min_hop__"]
max_hop = motif.edges[u, v, k]["__max_hop__"]
edge_type = motif.edges[u, v, k]["__labels__"]
hops = []
if min_hop == 0:
new_motif = nx.DiGraph()
new_motif = nx.MultiDiGraph()
new_motif.add_node(u, **motif.nodes[u])
new_motifs.append((new_motif, {(u, v): (u, u)}))
elif min_hop >= 1:
for _ in range(1, min_hop):
hops.append(shortuuid())
for _ in range(max(min_hop, 1), max_hop):
new_edges = [u] + hops + [v]
new_motif = nx.DiGraph()
new_motif = nx.MultiDiGraph()
new_motif.add_edges_from(
list(zip(new_edges[:-1], new_edges[1:])), __labels__=edge_type
zip(new_edges, new_edges[1:]), __labels__=edge_type
)
new_motif.add_node(u, **motif.nodes[u])
new_motif.add_node(v, **motif.nodes[v])
Expand Down
123 changes: 123 additions & 0 deletions grandcypher/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,129 @@ def test_order_by_with_non_returned_field(self):
assert res["n.name"] == ["Carol", "Alice", "Bob"]


class TestMultigraphRelations:
    """
    Queries run against nx.MultiDiGraph hosts, where a node pair may be
    joined by several parallel edges.

    Edge attributes are expected back as one dict per matched node pair,
    keyed by integer (presumably the networkx edge key, assigned in
    insertion order -- confirm against networkx).
    """

    def test_query_with_multiple_relations(self):
        """A label-filtered MATCH returns only the node pairs joined by a "friends" edge."""
        host = nx.MultiDiGraph()
        host.add_node("a", name="Alice", age=25)
        host.add_node("b", name="Bob", age=30)
        host.add_node("c", name="Charlie", age=25)
        host.add_node("d", name="Diana", age=25)

        # Adding edges with labels for different types of relationship_type
        host.add_edge("a", "b", __labels__={"friends"})
        host.add_edge("a", "b", __labels__={"colleagues"})
        host.add_edge("a", "c", __labels__={"colleagues"})
        host.add_edge("b", "d", __labels__={"family"})
        host.add_edge("c", "d", __labels__={"family"})
        host.add_edge("c", "d", __labels__={"friends"})
        host.add_edge("d", "a", __labels__={"friends"})
        host.add_edge("d", "a", __labels__={"colleagues"})

        qry = """
MATCH (n)-[r:friends]->(m)
RETURN n.name, m.name
"""
        res = GrandCypher(host).run(qry)
        # a->b, c->d and d->a each have a "friends" edge; a->c and b->d do not.
        assert res["n.name"] == ['Alice', 'Charlie', 'Diana']
        assert res["m.name"] == ['Bob', 'Diana', 'Alice']

    def test_multiple_edges_specific_attribute(self):
        """
        Returning an attribute of a labelled relation yields one dict per
        matched pair. The expected dict lists every parallel a->b edge, not
        only the one labelled "friend".
        """
        host = nx.MultiDiGraph()
        host.add_node("a", name="Alice", age=30)
        host.add_node("b", name="Bob", age=30)
        host.add_edge("a", "b", __labels__={"colleague"}, years=3)
        host.add_edge("a", "b", __labels__={"friend"}, years=5)
        host.add_edge("a", "b", __labels__={"enemy"}, hatred=10)

        qry = """
MATCH (a)-[r:friend]->(b)
RETURN a.name, b.name, r.years
"""
        res = GrandCypher(host).run(qry)
        assert res["a.name"] == ["Alice"]
        assert res["b.name"] == ["Bob"]
        assert res["r.years"] == [{0: 3, 1: 5, 2: None}]  # should return None when attr is missing

    def test_edge_directionality(self):
        """
        Edges a->b and b->a are matched as separate directed pairs, each
        reporting only the edges that run in its own direction.
        """
        host = nx.MultiDiGraph()
        host.add_node("a", name="Alice", age=25)
        host.add_node("b", name="Bob", age=30)
        host.add_edge("a", "b", __labels__={"friend"}, years=1)
        host.add_edge("b", "a", __labels__={"colleague"}, years=2)
        host.add_edge("b", "a", __labels__={"mentor"}, years=4)

        qry = """
MATCH (a)-[r]->(b)
RETURN a.name, b.name, r.__labels__, r.years
"""
        res = GrandCypher(host).run(qry)
        assert res["a.name"] == ["Alice", "Bob"]
        assert res["b.name"] == ["Bob", "Alice"]
        # First entry is the single Alice->Bob edge; second holds both Bob->Alice edges.
        assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}]
        assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}]


    def test_query_with_missing_edge_attribute(self):
        """
        An attribute that an edge does not define is reported as None for
        that edge. Checked for a labelled query on two different attributes
        and for an unlabelled query.
        """
        host = nx.MultiDiGraph()
        host.add_node("a", name="Alice", age=30)
        host.add_node("b", name="Bob", age=40)
        host.add_node("c", name="Charlie", age=50)
        host.add_edge("a", "b", __labels__={"friend"}, years=3)
        host.add_edge("a", "c", __labels__={"colleague"}, years=10)
        host.add_edge("b", "c", __labels__={"colleague"}, duration=10)
        host.add_edge("b", "c", __labels__={"mentor"}, years=2)

        # "duration" exists only on the b->c colleague edge.
        qry = """
MATCH (a)-[r:colleague]->(b)
RETURN a.name, b.name, r.duration
"""
        res = GrandCypher(host).run(qry)
        assert res["a.name"] == ["Alice", "Bob"]
        assert res["b.name"] == ["Charlie", "Charlie"]
        assert res["r.duration"] == [{0: None}, {0: 10, 1: None}]  # should return None when attr is missing

        # Same match, but "years" is missing from the b->c colleague edge instead.
        qry = """
MATCH (a)-[r:colleague]->(b)
RETURN a.name, b.name, r.years
"""
        res = GrandCypher(host).run(qry)
        assert res["a.name"] == ["Alice", "Bob"]
        assert res["b.name"] == ["Charlie", "Charlie"]
        assert res["r.years"] == [{0: 10}, {0: None, 1: 2}]

        # Without a label filter the a->b friend edge is matched as well.
        qry = """
MATCH (a)-[r]->(b)
RETURN a.name, b.name, r.__labels__, r.duration
"""
        res = GrandCypher(host).run(qry)
        assert res["a.name"] == ['Alice', 'Alice', 'Bob']
        assert res["b.name"] == ['Bob', 'Charlie', 'Charlie']
        assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}}, {0: {'colleague'}, 1: {'mentor'}}]
        assert res["r.duration"] == [{0: None}, {0: None}, {0: 10, 1: None}]

    def test_multigraph_single_edge_where(self):
        """
        A WHERE clause on relation attributes keeps a node pair when one of
        its parallel edges satisfies the condition. The kept pair still
        reports all of its edges.
        """
        host = nx.MultiDiGraph()
        host.add_node("a", name="Alice", age=25)
        host.add_node("b", name="Bob", age=30)
        host.add_node("c", name="Christine", age=30)
        host.add_edge("a", "b", __labels__={"friend"}, years=1, friendly="very")
        host.add_edge("b", "a", __labels__={"colleague"}, years=2)
        host.add_edge("b", "a", __labels__={"mentor"}, years=4)
        host.add_edge("b", "c", __labels__={"chef"}, years=12)

        qry = """
MATCH (a)-[r]->(b)
WHERE r.friendly == "very" OR r.years == 2
RETURN a.name, b.name, r.__labels__, r.years, r.friendly
"""
        res = GrandCypher(host).run(qry)
        # a->b passes on friendly == "very"; b->a passes on its years == 2 edge;
        # b->c (years=12, no "friendly") is filtered out.
        assert res["a.name"] == ["Alice", "Bob"]
        assert res["b.name"] == ["Bob", "Alice"]
        assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}]
        assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}]
        assert res["r.friendly"] == [{0: 'very'}, {0: None, 1: None}]


class TestVariableLengthRelationship:
def test_single_variable_length_relationship(self):
host = nx.DiGraph()
Expand Down
Loading