Skip to content

Commit

Permalink
Adds ORDER BY (#41)
Browse files Browse the repository at this point in the history
* Implements `DISTINCT`

* Initial implementation of `ORDER BY` with modified `DISTINCT`

* Adds test cases for `ORDER BY`

* Updates README

* Adds test for `ORDER BY` with no direction provided

* Removes crusty bits

* Refactors `returns()`

* Updates README
  • Loading branch information
jackboyla authored May 3, 2024
1 parent 6f4e32c commit 664ab96
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 20 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ RETURN
| `(:Type)` node-labels | ✅ Thanks @khoale88! | |
| `[:Type]` edge-labels | ✅ Thanks @khoale88! | |
| Graph mutations (e.g. `DELETE`, `SET`,...) | 🛣 | |
| `DISTINCT` | ✅ Thanks @jackboyla! | |
| `ORDER BY` | ✅ Thanks @jackboyla! | |

| | | |
| -------------- | -------------- | ---------------- |
Expand Down
117 changes: 98 additions & 19 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from typing import Dict, List, Callable, Tuple
from collections import OrderedDict
import random
import string
from functools import lru_cache
Expand Down Expand Up @@ -39,8 +40,8 @@
r"""
start : query
query : many_match_clause where_clause return_clause
| many_match_clause return_clause
query : many_match_clause where_clause return_clause order_clause? skip_clause? limit_clause?
| many_match_clause return_clause order_clause? skip_clause? limit_clause?
many_match_clause : (match_clause)+
Expand Down Expand Up @@ -81,14 +82,21 @@
return_clause : "return"i distinct_return? entity_id ("," entity_id)*
| "return"i distinct_return? entity_id ("," entity_id)* limit_clause
| "return"i distinct_return? entity_id ("," entity_id)* skip_clause
| "return"i distinct_return? entity_id ("," entity_id)* skip_clause limit_clause
distinct_return : "DISTINCT"i
limit_clause : "limit"i NUMBER
skip_clause : "skip"i NUMBER
order_clause : "order"i "by"i order_items
order_items : order_item ("," order_item)*
order_item : entity_id order_direction?
order_direction : "ASC"i -> asc
| "DESC"i -> desc
| -> no_direction
?entity_id : CNAME
| CNAME "." CNAME
Expand Down Expand Up @@ -321,6 +329,8 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._return_requests = []
self._return_edges = {}
self._distinct = False
self._order_by = None
self._order_by_attributes = set()
self._limit = limit
self._skip = 0
self._max_hop = 100
Expand Down Expand Up @@ -402,6 +412,20 @@ def return_clause(self, clause):
item = str(item.value)
self._return_requests.append(item)


def order_clause(self, order_clause):
    """Record the fields and directions of a parsed ``ORDER BY`` clause.

    Resets ``self._order_by`` to a list of ``(field, direction)`` tuples
    (e.g. ``[('n.age', 'DESC'), ...]``), in the order the items appear in
    the query, and adds every field to ``self._order_by_attributes``.

    Args:
        order_clause: children of the ``order_clause`` rule; the first
            child is the ``order_items`` node whose children are the
            individual ``order_item`` nodes.
    """
    self._order_by = []
    for item in order_clause[0].children:
        # The field name is the first child of an order_item.
        field = str(item.children[0])
        # Only an explicit DESC marker selects descending order. An item
        # with an ASC marker, a ``no_direction`` marker, or no direction
        # child at all sorts ascending.
        is_descending = (
            len(item.children) > 1
            and str(item.children[1].data).lower() == 'desc'
        )
        direction = 'DESC' if is_descending else 'ASC'

        self._order_by.append((field, direction))
        self._order_by_attributes.add(field)

def distinct_return(self, distinct):
    """Mark the query as ``RETURN DISTINCT``; the ``distinct`` argument is not used."""
    self._distinct = True

Expand All @@ -414,21 +438,76 @@ def skip_clause(self, skip):
self._skip = skip

def returns(self, ignore_limit=False):
    """Build the final result columns for the query.

    The full, unpaginated result set is looked up once, then processed in
    this order: ``ORDER BY`` (if any), ``DISTINCT`` (if requested),
    ``SKIP``/``LIMIT``, and finally projection onto the requested columns.

    Args:
        ignore_limit: when true, ``LIMIT`` is not applied (``SKIP`` still is).

    Returns:
        A dict mapping each requested return key to its list of values.
    """
    requested = list(self._return_requests)
    # Attributes used only for ordering must be fetched too, but a key that
    # is already being returned does not need to be requested twice.
    order_only = [
        attr for attr in self._order_by_attributes if attr not in requested
    ]

    # Ordering and de-duplication need every row, so no slicing happens at
    # lookup time; pagination is applied afterwards.
    results = self._lookup(requested + order_only, offset_limit=slice(0, None))

    if self._order_by:
        results = self._apply_order_by(results)
    if self._distinct:
        results = self._apply_distinct(results)
    results = self._apply_pagination(results, ignore_limit)

    # Exclude order-by-only attributes from the final results.
    return {
        key: values
        for key, values in results.items()
        if key in self._return_requests
    }

def _apply_order_by(self, results):
if self._order_by:
sort_lists = [(results[field], direction) for field, direction in self._order_by if field in results]

if sort_lists:
# Generate a list of indices sorted by the specified fields
indices = range(len(next(iter(results.values())))) # Safe because all lists are assumed to be of the same length
for sort_list, direction in reversed(sort_lists): # reverse to ensure the first sort key is primary
indices = sorted(indices, key=lambda i: sort_list[i], reverse=(direction == 'DESC'))

# Reorder all lists in results using sorted indices
for key in results:
results[key] = [results[key][i] for i in indices]

return results

def _apply_distinct(self, results):
if self._order_by:
assert self._order_by_attributes.issubset(self._return_requests), "In a WITH/RETURN with DISTINCT or an aggregation, it is not possible to access variables declared before the WITH/RETURN"

# ordered dict to maintain the first occurrence of each unique tuple based on return requests
unique_rows = OrderedDict()

# Iterate over each 'row' by index
for i in range(len(next(iter(results.values())))): # assume all columns are of the same length
# create a tuple key of all the values from return requests for this row
row_key = tuple(results[key][i] for key in self._return_requests if key in results)

if row_key not in unique_rows:
unique_rows[row_key] = i # store the index of the first occurrence of this unique row

# construct the results based on unique indices collected
distinct_results = {key: [] for key in self._return_requests}
for row_key, index in unique_rows.items():
for _, key in enumerate(self._return_requests):
distinct_results[key].append(results[key][index])

return distinct_results

def _apply_pagination(self, results, ignore_limit):
# apply LIMIT and SKIP (if set) after ordering
if self._limit is not None and not ignore_limit:
start_index = self._skip
end_index = start_index + self._limit
for key in results.keys():
results[key] = results[key][start_index:end_index]
# else just apply SKIP (if set)
else:
for key in results.keys():
start_index = self._skip
results[key] = results[key][start_index:]

return results

Expand Down Expand Up @@ -471,8 +550,8 @@ def _get_true_matches(self):
self_matches.append(match)
self_matche_paths.append(edge_hop_map)

# Check if limit reached
if self._is_limit(len(self_matches)):
# Check if limit reached; stop ONLY IF we are not ordering
if self._is_limit(len(self_matches)) and not self._order_by:
complete = True
break

Expand Down
156 changes: 155 additions & 1 deletion grandcypher/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import networkx as nx
import pytest

from . import _GrandCypherGrammar, _GrandCypherTransformer, GrandCypher

Expand Down Expand Up @@ -589,7 +590,7 @@ def test_complex_where(self):


class TestDistinct:
def test_basic_distinct(self):
def test_basic_distinct1(self):
host = nx.DiGraph()
host.add_node("a", name="Alice")
host.add_node("b", name="Bob")
Expand All @@ -603,6 +604,23 @@ def test_basic_distinct(self):
assert len(res["n.name"]) == 2 # should return "Alice" and "Bob" only once
assert "Alice" in res["n.name"] and "Bob" in res["n.name"]

def test_basic_distinct2(self):
    """DISTINCT on one attribute collapses nodes that share that value."""
    host = nx.DiGraph()
    host.add_node("a", name="Alice", age=25)
    host.add_node("b", name="Bob", age=30)
    host.add_node("c", name="Carol", age=25)
    # A second "Carol" needs its own node id: re-adding "c" would only
    # overwrite the attributes of the node above instead of creating a
    # duplicate name for DISTINCT to remove.
    host.add_node("f", name="Carol", age=21)
    host.add_node("d", name="Alice", age=25)
    host.add_node("e", name="Greg", age=32)

    qry = "\nMATCH (n)\nRETURN DISTINCT n.name\n"
    res = GrandCypher(host).run(qry)
    # six nodes, four distinct names: each name is returned exactly once
    assert len(res["n.name"]) == 4
    assert "Alice" in res["n.name"] and "Bob" in res["n.name"] and "Carol" in res["n.name"] and "Greg" in res["n.name"]


def test_distinct_with_relationships(self):
host = nx.DiGraph()
Expand Down Expand Up @@ -671,6 +689,142 @@ def test_distinct_with_attributes(self):
assert "Alice" in res["n.name"] and "Bob" in res["n.name"]


class TestOrderBy:
    """Queries using ``ORDER BY``, alone and combined with LIMIT, SKIP and DISTINCT."""

    # (node id, name, age) fixtures shared by several tests
    _THREE_PEOPLE = (
        ("a", "Alice", 25),
        ("b", "Bob", 30),
        ("c", "Carol", 20),
    )
    _WITH_DUPLICATE_ALICE = (
        ("a", "Alice", 25),
        ("b", "Bob", 30),
        ("c", "Carol", 25),
        ("d", "Alice", 25),
        ("e", "Greg", 32),
    )

    @staticmethod
    def _graph(people):
        """Build a DiGraph with one node per ``(node_id, name, age)`` entry."""
        host = nx.DiGraph()
        for node_id, name, age in people:
            host.add_node(node_id, name=name, age=age)
        return host

    def test_order_by_single_field_ascending(self):
        host = self._graph(self._THREE_PEOPLE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name\n"
            "ORDER BY n.age ASC\n"
        )
        res = GrandCypher(host).run(qry)
        assert res["n.name"] == ["Carol", "Alice", "Bob"]

    def test_order_by_single_field_descending(self):
        host = self._graph(self._THREE_PEOPLE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name\n"
            "ORDER BY n.age DESC\n"
        )
        res = GrandCypher(host).run(qry)
        assert res["n.name"] == ["Bob", "Alice", "Carol"]

    def test_order_by_single_field_no_direction_provided(self):
        host = self._graph(self._THREE_PEOPLE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name\n"
            "ORDER BY n.age\n"
        )
        res = GrandCypher(host).run(qry)
        assert res["n.name"] == ["Carol", "Alice", "Bob"]

    def test_order_by_multiple_fields(self):
        host = self._graph(
            (
                ("a", "Alice", 25),
                ("b", "Bob", 30),
                ("c", "Carol", 25),
                ("d", "Dave", 25),
            )
        )
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name\n"
            "ORDER BY n.age ASC, n.name DESC\n"
        )
        res = GrandCypher(host).run(qry)
        # within the same age, names come out in descending order
        assert res["n.name"] == ["Dave", "Carol", "Alice", "Bob"]

    def test_order_by_with_limit(self):
        host = self._graph(self._THREE_PEOPLE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name\n"
            "ORDER BY n.age ASC LIMIT 2\n"
        )
        res = GrandCypher(host).run(qry)
        assert res["n.name"] == ["Carol", "Alice"]

    def test_order_by_with_skip(self):
        host = self._graph(self._THREE_PEOPLE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name\n"
            "ORDER BY n.age ASC SKIP 1\n"
        )
        res = GrandCypher(host).run(qry)
        assert res["n.name"] == ["Alice", "Bob"]

    def test_order_by_with_distinct(self):
        host = self._graph(self._WITH_DUPLICATE_ALICE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN DISTINCT n.name, n.age\n"
            "ORDER BY n.age DESC\n"
        )
        res = GrandCypher(host).run(qry)
        # distinct (name, age) rows, ordered by age descending
        assert res["n.name"] == ['Greg', 'Bob', 'Alice', 'Carol']
        assert res["n.age"] == [32, 30, 25, 25]

    def test_error_on_order_by_with_distinct_and_non_returned_field(self):
        host = self._graph(self._WITH_DUPLICATE_ALICE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN DISTINCT n.name\n"
            "ORDER BY n.age DESC\n"
        )
        # 'n.age' is used in ORDER BY but is not part of the DISTINCT RETURN,
        # so running the query must fail
        with pytest.raises(Exception):
            GrandCypher(host).run(qry)

    def test_order_by_with_non_returned_field(self):
        host = self._graph(self._THREE_PEOPLE)
        qry = (
            "\n"
            "MATCH (n)\n"
            "RETURN n.name ORDER BY n.age ASC\n"
        )
        res = GrandCypher(host).run(qry)
        assert res["n.name"] == ["Carol", "Alice", "Bob"]


class TestVariableLengthRelationship:
def test_single_variable_length_relationship(self):
host = nx.DiGraph()
Expand Down

0 comments on commit 664ab96

Please sign in to comment.