Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aggregation functions (COUNT, SUM, MIN, MAX, AVG) #45

Merged
merged 25 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3d6caa0
Adds support for multigraphs
jackboyla May 8, 2024
1f2d658
Refactors `_is_edge_attr_match`
jackboyla May 8, 2024
aed457e
Filters relations by __label__ during `_lookup`
jackboyla May 8, 2024
849ad2f
Bundles relation attributes together for lookup
jackboyla May 9, 2024
2823281
Refactors and adds inline docs
jackboyla May 9, 2024
ee801b3
Adds tests for multigraph support
jackboyla May 9, 2024
cb2a4e9
Cleans up inline docs
jackboyla May 9, 2024
3595706
Removes slicing list twice to avoid two copies in memory
jackboyla May 9, 2024
da81cfd
Supports WHERE clause for relationships in multigraphs
jackboyla May 9, 2024
577d843
Adds test for multigraph with WHERE clause on single edge
jackboyla May 9, 2024
e759563
Accounts for WHERE with string node attributes in MultiDiGraphs
jackboyla May 21, 2024
b76b825
Unifies all unit tests to work with both DiGraphs and MultiDiGraphs
jackboyla May 21, 2024
46f5261
Merge branch 'master' into unify-tests-for-digraph-and-multidigraph
jackboyla May 22, 2024
6748db7
Completes multidigraph test for WHERE on node attribute
jackboyla May 22, 2024
022a438
Supports logical OR for relationship matching
jackboyla May 22, 2024
05f98b3
Adds tests for logical OR in MATCH for relationships
jackboyla May 22, 2024
b106914
Merge remote-tracking branch 'origin/master' into logical-or-for-rela…
jackboyla May 24, 2024
351eb6e
Implements aggregation functions
jackboyla Jun 7, 2024
72db2a8
Removes unused code
jackboyla Jun 7, 2024
aa007b9
Adds agg function results to `_return_requests`
jackboyla Jun 10, 2024
963fa8f
Handles `None` values appropriately for MIN and MAX
jackboyla Jun 10, 2024
346a044
Adds tests for agg functions and adjusts existing tests to new output
jackboyla Jun 10, 2024
3d7ebae
Adds examples page
jackboyla Jun 10, 2024
ddd4db4
Adds test for multiple agg functions
jackboyla Jun 10, 2024
d37f3d2
Removes commented code
jackboyla Jun 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ RETURN A.club, B.club
""")
```

See [examples.md](examples.md) for more!

### Example Usage with SQL

Create your own "Sqlite for Neo4j"! This example uses [grand-graph](https://github.com/aplbrain/grand) to run queries in SQL:
Expand Down Expand Up @@ -81,6 +83,7 @@ RETURN
| Graph mutations (e.g. `DELETE`, `SET`,...) | 🛣 | |
| `DISTINCT` | ✅ Thanks @jackboyla! | |
| `ORDER BY` | ✅ Thanks @jackboyla! | |
| Aggregation functions (`COUNT`, `SUM`, `MIN`, `MAX`, `AVG`) | ✅ Thanks @jackboyla! | |

| | | |
| -------------- | -------------- | ---------------- |
Expand Down
66 changes: 66 additions & 0 deletions examples.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@

## Multigraph

```python
from grandcypher import GrandCypher
import networkx as nx

host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June")
host.add_edge("b", "a", __labels__={"paid"}, amount=6)
host.add_edge("b", "a", __labels__={"paid"}, value=14)
host.add_edge("a", "b", __labels__={"friends"}, years=9)
host.add_edge("a", "b", __labels__={"paid"}, amount=40)

qry = """
MATCH (n)-[r:paid]->(m)
RETURN n.name, m.name, r.amount
"""
res = GrandCypher(host).run(qry)
print(res)

'''
{
'n.name': ['Alice', 'Bob'],
'm.name': ['Bob', 'Alice'],
'r.amount': [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}]
}
'''
```

## Aggregation Functions

```python
from grandcypher import GrandCypher
import networkx as nx

host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June")
host.add_edge("b", "a", __labels__={"paid"}, amount=6)
host.add_edge("b", "a", __labels__={"paid"}, value=14)
host.add_edge("a", "b", __labels__={"friends"}, years=9)
host.add_edge("a", "b", __labels__={"paid"}, amount=40)

qry = """
MATCH (n)-[r:paid]->(m)
RETURN n.name, m.name, SUM(r.amount)
"""
res = GrandCypher(host).run(qry)
print(res)

'''
{
'n.name': ['Alice', 'Bob'],
'm.name': ['Bob', 'Alice'],
'SUM(r.amount)': [{'paid': 52, 'friends': 0}, {'paid': 6}]
}
'''
```




136 changes: 118 additions & 18 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import grandiso

from lark import Lark, Transformer, v_args, Token
from lark import Lark, Transformer, v_args, Token, Tree


_OPERATORS = {
Expand Down Expand Up @@ -81,7 +81,13 @@



return_clause : "return"i distinct_return? entity_id ("," entity_id)*

return_clause : "return"i distinct_return? return_item ("," return_item)*
return_item : entity_id | aggregation_function | entity_id "." attribute_id

aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")"
AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN"
attribute_id : CNAME

distinct_return : "DISTINCT"i
limit_clause : "limit"i NUMBER
Expand All @@ -107,8 +113,8 @@
edge_match : LEFT_ANGLE? "--" RIGHT_ANGLE?
| LEFT_ANGLE? "-[]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" CNAME "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" CNAME ":" TYPE "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" ":" TYPE "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" CNAME ":" type_list "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" ":" type_list "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" "*" MIN_HOP "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" "*" MIN_HOP ".." MAX_HOP "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" CNAME "*" MIN_HOP "]-" RIGHT_ANGLE?
Expand All @@ -118,6 +124,7 @@
| LEFT_ANGLE? "-[" CNAME ":" TYPE "*" MIN_HOP "]-" RIGHT_ANGLE?
| LEFT_ANGLE? "-[" CNAME ":" TYPE "*" MIN_HOP ".." MAX_HOP "]-" RIGHT_ANGLE?

type_list : TYPE ( "|" TYPE )*

LEFT_ANGLE : "<"
RIGHT_ANGLE : ">"
Expand Down Expand Up @@ -228,10 +235,14 @@ def _is_edge_attr_match(
motif_edges = _aggregate_edge_labels(motif_edges)
host_edges = _aggregate_edge_labels(host_edges)

motif_types = motif_edges.get('__labels__', set())
host_types = host_edges.get('__labels__', set())

if motif_types and not motif_types.intersection(host_types):
return False

for attr, val in motif_edges.items():
if attr == "__labels__":
if val and val - host_edges.get("__labels__", set()):
return False
continue
if host_edges.get(attr) != val:
return False
Expand Down Expand Up @@ -277,6 +288,7 @@ def _get_entity_from_host(
edge_data = host.get_edge_data(*entity_name)
if not edge_data:
return None # print(f"Nothing found for {entity_name} {entity_attribute}")

if entity_attribute:
# looking for edge attribute:
if isinstance(host, nx.MultiDiGraph):
Expand Down Expand Up @@ -371,6 +383,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._matche_paths = None
self._return_requests = []
self._return_edges = {}
self._aggregate_functions = []
self._distinct = False
self._order_by = None
self._order_by_attributes = set()
Expand Down Expand Up @@ -478,9 +491,10 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
for r in ret:
r_attr = {}
for i, v in r.items():
r_attr[i] = v.get(entity_attribute, None)
r_attr[(i, list(v.get('__labels__'))[0])] = v.get(entity_attribute, None)
# eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
ret_with_attr.append(r_attr)

ret = ret_with_attr

result[data_path] = list(ret)[offset_limit]
Expand All @@ -492,9 +506,19 @@ def return_clause(self, clause):
# collect all entity identifiers to be returned
for item in clause:
if item:
if not isinstance(item, str):
item = str(item.value)
self._return_requests.append(item)
item = item.children[0] if isinstance(item, Tree) else item
if isinstance(item, Tree) and item.data == "aggregation_function":
func = str(item.children[0].value) # AGGREGATE_FUNC
entity = str(item.children[1].value)
if len(item.children) > 2:
entity += "." + str(item.children[2].children[0].value)
self._aggregate_functions.append((func, entity))
self._return_requests.append(entity)
else:
if not isinstance(item, str):
item = str(item.value)
self._return_requests.append(item)


def order_clause(self, order_clause):
self._order_by = []
Expand All @@ -520,12 +544,73 @@ def skip_clause(self, skip):
skip = int(skip[-1])
self._skip = skip


def aggregate(self, func, results, entity, group_keys):
    """Apply the aggregation function ``func`` to one returned column.

    Rows are grouped by the values of the non-aggregated return columns,
    then the aggregation is computed per group and per edge label.

    Parameters:
        func (str): One of ``"COUNT"``, ``"SUM"``, ``"AVG"``, ``"MAX"``, ``"MIN"``.
        results (dict): Mapping of return-column name -> list of per-row values.
            ``results[entity]`` holds one dict per row keyed by
            ``(edge_index, label)`` tuples (the shape produced by ``_lookup``).
        entity (str): The column being aggregated, e.g. ``"r.amount"``.
        group_keys (list): Names of the columns that define the grouping.

    Returns:
        list: One dict per group (in first-seen row order), mapping
        edge label -> aggregated value. For ``MAX``/``MIN`` a label whose
        values are all ``None`` maps to ``None``.
    """
    # Bucket the rows of `entity` by their group-key tuple.
    grouped_data = {}
    for i in range(len(results[entity])):
        group_tuple = tuple(results[key][i] for key in group_keys if key in results)
        grouped_data.setdefault(group_tuple, []).append(results[entity][i])

    def _collate_data(data, unique_labels, func):
        # For COUNT/SUM/AVG a missing (None) value is treated as 0 and still
        # contributes an entry; for MAX/MIN, None values are dropped entirely.
        if func in ("COUNT", "SUM", "AVG"):
            return {
                label: [(v or 0) for rel in data for k, v in rel.items() if k[1] == label]
                for label in unique_labels
            }
        # MAX / MIN
        return {
            label: [v for rel in data for k, v in rel.items() if k[1] == label and v is not None]
            for label in unique_labels
        }

    # Apply the aggregation function per group, per edge label.
    aggregate_results = {}
    for group, data in grouped_data.items():
        # data => e.g. [{(0, 'paid'): 70, (1, 'paid'): 90}]
        unique_labels = {k[1] for rel in data for k in rel}
        collated_data = _collate_data(data, unique_labels, func)
        if func == "COUNT":
            aggregate_results[group] = {
                label: len(vals) for label, vals in collated_data.items()
            }
        elif func == "SUM":
            aggregate_results[group] = {
                label: sum(vals) for label, vals in collated_data.items()
            }
        elif func == "AVG":
            aggregate_results[group] = {
                label: (sum(vals) / len(vals)) if vals else 0
                for label, vals in collated_data.items()
            }
        elif func == "MAX":
            # default=None: a label whose values were all None collates to an
            # empty list, and bare max() would raise ValueError on it.
            aggregate_results[group] = {
                label: max(vals, default=None) for label, vals in collated_data.items()
            }
        elif func == "MIN":
            aggregate_results[group] = {
                label: min(vals, default=None) for label, vals in collated_data.items()
            }

    return list(aggregate_results.values())

def returns(self, ignore_limit=False):

results = self._lookup(
self._return_requests + list(self._order_by_attributes),
offset_limit=slice(0, None),
)
if len(self._aggregate_functions) > 0:
group_keys = [key for key in results.keys() if not any(key.endswith(func[1]) for func in self._aggregate_functions)]

aggregated_results = {}
for func, entity in self._aggregate_functions:
aggregated_data = self.aggregate(func, results, entity, group_keys)
func_key = f"{func}({entity})"
aggregated_results[func_key] = aggregated_data
self._return_requests.append(func_key)
results.update(aggregated_results)
if self._order_by:
results = self._apply_order_by(results)
if self._distinct:
Expand Down Expand Up @@ -775,10 +860,21 @@ def entity_id(self, entity_id):
return ".".join(entity_id)
return entity_id.value

def edge_match(self, edge_name):
direction = cname = min_hop = max_hop = edge_type = None
def edge_match(self, edge_tokens):
def flatten_tokens(edge_tokens):
flat_tokens = []
for token in edge_tokens:
if isinstance(token, Tree):
flat_tokens.extend(flatten_tokens(token.children)) # Recursively flatten the tree
else:
flat_tokens.append(token)
return flat_tokens

direction = cname = min_hop = max_hop = None
edge_types = []
edge_tokens = flatten_tokens(edge_tokens)

for token in edge_name:
for token in edge_tokens:
if token.type == "MIN_HOP":
min_hop = int(token.value)
elif token.type == "MAX_HOP":
Expand All @@ -790,15 +886,19 @@ def edge_match(self, edge_name):
elif token.type == "RIGHT_ANGLE":
direction = "r"
elif token.type == "TYPE":
edge_type = token.value
edge_types.append(token.value)
else:
cname = token

direction = direction if direction is not None else "b"
if (min_hop is not None or max_hop is not None) and (direction == "b"):
raise TypeError("not support edge hopping for bidirectional edge")
raise TypeError("Bidirectional edge does not support edge hopping")

# Handle the case where no edge types are specified, defaulting to a generic type if needed
if edge_types == []:
edge_types = None

return (cname, edge_type, direction, min_hop, max_hop)
return (cname, edge_types, direction, min_hop, max_hop)

def node_match(self, node_name):
cname = node_type = json_data = None
Expand Down Expand Up @@ -845,7 +945,7 @@ def match_clause(self, match_clause: Tuple):
if maxh > self._max_hop:
raise ValueError(f"max hop is caped at 100, found {maxh}!")
if t:
t = set([t])
t = set([t] if type(t) is str else t)
self._motif.add_edges_from(
edges, __min_hop__=minh, __max_hop__=maxh, __is_hop__=ish, __labels__=t
)
Expand Down
Loading
Loading