From e933ced9dcf21a62f2aa7a6126688d6c449ae9e3 Mon Sep 17 00:00:00 2001 From: Jordan Matelsky Date: Tue, 4 Jun 2024 16:02:42 -0400 Subject: [PATCH] Add test and fix for issue #55 --- grand/backends/_sqlbackend.py | 10 +++++----- grand/backends/test_backends.py | 20 +++++++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/grand/backends/_sqlbackend.py b/grand/backends/_sqlbackend.py index 8a6591d..2f71493 100644 --- a/grand/backends/_sqlbackend.py +++ b/grand/backends/_sqlbackend.py @@ -306,14 +306,14 @@ def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator: """ columns = [ - self._node_table.c[self._edge_source_key], - self._node_table.c[self._edge_target_key], + self._edge_table.c[self._edge_source_key], + self._edge_table.c[self._edge_target_key], ] if include_metadata: - columns.append(self._node_table.c["_metadata"]) + columns.append(self._edge_table.c["_metadata"]) - sql = self._node_table.select().with_only_columns(columns) + sql = self._edge_table.select().with_only_columns(*columns) return self._connection.execute(sql).fetchall() def get_node_by_id(self, node_name: Hashable): @@ -676,4 +676,4 @@ def commit(self): self._connection.commit() def close(self): - self._connection.close() \ No newline at end of file + self._connection.close() diff --git a/grand/backends/test_backends.py b/grand/backends/test_backends.py index 5b345e6..cb0d098 100644 --- a/grand/backends/test_backends.py +++ b/grand/backends/test_backends.py @@ -131,6 +131,7 @@ ), ) + # @pytest.mark.parametrize("backend", backend_test_params) class TestBackendPersistence: def test_sqlite_persistence(self): @@ -138,23 +139,22 @@ def test_sqlite_persistence(self): return dbpath = "grand_peristence_test_temp.db" - url = "sqlite:///"+dbpath + url = "sqlite:///" + dbpath - #arrange + # arrange backend = SQLBackend(db_url=url, directed=True) - node0 = backend.add_node("A",{"foo":"bar"}) + node0 = backend.add_node("A", {"foo": "bar"}) backend.commit() backend.close() - #act + # act backend = SQLBackend(db_url=url, directed=True) nodes = list(backend.all_nodes_as_iterable()) - #assert + # assert assert node0 in nodes - #cleanup + # cleanup os.remove(dbpath) - @pytest.mark.parametrize("backend", backend_test_params) class TestBackend: def test_can_create(self, backend): @@ -344,6 +344,12 @@ def test_can_get_edge_metadata(self, backend): G.nx.add_edge("foo", "bar", baz=True) assert list(G.nx.edges(data=True)) == [("foo", "bar", {"baz": True})] + def test_can_get_edges(self, backend): + backend, kwargs = backend + G = Graph(backend=backend(**kwargs)) + G.nx.add_edge("foo", "bar", baz=True) + assert list(G.backend.all_edges_as_iterable()) == [("foo", "bar")] + def test_edge_dne_raises(self, backend): backend, kwargs = backend G = Graph(backend=backend(**kwargs))