Skip to content

Commit

Permalink
Add test and fix for issue #55
Browse files Browse the repository at this point in the history
  • Loading branch information
j6k4m8 committed Jun 4, 2024
1 parent e71be46 commit e933ced
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
10 changes: 5 additions & 5 deletions grand/backends/_sqlbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,14 @@ def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator:
"""

columns = [
self._node_table.c[self._edge_source_key],
self._node_table.c[self._edge_target_key],
self._edge_table.c[self._edge_source_key],
self._edge_table.c[self._edge_target_key],
]

if include_metadata:
columns.append(self._node_table.c["_metadata"])
columns.append(self._edge_table.c["_metadata"])

sql = self._node_table.select().with_only_columns(columns)
sql = self._edge_table.select().with_only_columns(*columns)
return self._connection.execute(sql).fetchall()

def get_node_by_id(self, node_name: Hashable):
Expand Down Expand Up @@ -676,4 +676,4 @@ def commit(self):
self._connection.commit()

def close(self):
self._connection.close()
self._connection.close()
20 changes: 13 additions & 7 deletions grand/backends/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,30 +131,30 @@
),
)


# @pytest.mark.parametrize("backend", backend_test_params)
class TestBackendPersistence:
def test_sqlite_persistence(self):
if not _CAN_IMPORT_SQL:
return

dbpath = "grand_peristence_test_temp.db"
url = "sqlite:///"+dbpath
url = "sqlite:///" + dbpath

#arrange
# arrange
backend = SQLBackend(db_url=url, directed=True)
node0 = backend.add_node("A",{"foo":"bar"})
node0 = backend.add_node("A", {"foo": "bar"})
backend.commit()
backend.close()
#act
# act
backend = SQLBackend(db_url=url, directed=True)
nodes = list(backend.all_nodes_as_iterable())
#assert
# assert
assert node0 in nodes
#cleanup
# cleanup
os.remove(dbpath)



@pytest.mark.parametrize("backend", backend_test_params)
class TestBackend:
def test_can_create(self, backend):
Expand Down Expand Up @@ -344,6 +344,12 @@ def test_can_get_edge_metadata(self, backend):
G.nx.add_edge("foo", "bar", baz=True)
assert list(G.nx.edges(data=True)) == [("foo", "bar", {"baz": True})]

def test_can_get_edges(self, backend):
backend, kwargs = backend
G = Graph(backend=backend(**kwargs))
G.nx.add_edge("foo", "bar", baz=True)
assert list(G.backend.all_edges_as_iterable()) == [("foo", "bar")]

def test_edge_dne_raises(self, backend):
backend, kwargs = backend
G = Graph(backend=backend(**kwargs))
Expand Down

0 comments on commit e933ced

Please sign in to comment.