diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 7c6e2a17e..61be08db8 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -126,10 +126,10 @@ def _validate_relationship_sdtypes( ) def _validate_circular_relationships( - self, parent, children=None, parents=None, child_map=None, errors=None + self, parent, children=None, visited=None, child_map=None, errors=None ): """Validate that there is no circular relationship in the metadata.""" - parents = set() if parents is None else parents + visited = set() if visited is None else visited if children is None: children = child_map[parent] @@ -137,15 +137,15 @@ def _validate_circular_relationships( errors.append(parent) for child in children: - if child in parents: - break + if child in visited: + continue - parents.add(child) + visited.add(child) self._validate_circular_relationships( parent, children=child_map.get(child, set()), child_map=child_map, - parents=parents, + visited=visited, errors=errors, ) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index a43816100..43c9941a6 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1418,6 +1418,102 @@ def test_validate_data_datetime_warning(self): with pytest.warns(UserWarning, match=warning_msg): metadata.validate_data(data) + def test_add_relationship_circular_graph(self): + """Test that an error is raised when a circular relationship is detected. + + The graph has the cycle B->C->D->B. + Besides the cycle, the other relationships are: B->A, C->A, D->A. + """ + # Setup + metadata = MultiTableMetadata() + metadata.add_table('A') + metadata.add_column('A', 'id', sdtype='id') + metadata.add_column('A', 'fk', sdtype='id') + metadata.set_primary_key('A', 'id') + + metadata.add_table('B') + metadata.add_column('B', 'id', sdtype='id') + metadata.add_column('B', 'fk', sdtype='id') + metadata.set_primary_key('B', 'id') + + metadata.add_table('C') + metadata.add_column('C', 'id', sdtype='id') + metadata.add_column('C', 'fk', sdtype='id') + metadata.set_primary_key('C', 'id') + + metadata.add_table('D') + metadata.add_column('D', 'id', sdtype='id') + metadata.add_column('D', 'fk', sdtype='id') + metadata.set_primary_key('D', 'id') + + metadata.add_relationship('B', 'C', 'id', 'fk') + metadata.add_relationship('B', 'A', 'id', 'fk') + + metadata.add_relationship('C', 'D', 'id', 'fk') + metadata.add_relationship('C', 'A', 'id', 'fk') + + metadata.add_relationship('D', 'A', 'id', 'fk') + + # Run and Assert + error_msg = re.escape( + 'The relationships in the dataset describe a ' + "circular dependency between tables ['B', 'C', 'D']." + ) + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.add_relationship('D', 'B', 'id', 'fk') + + def test_add_relationship_circular_graph_complex(self): + """Test that an error is raised when a circular relationship is detected. + + The graph has the cycle C->E->D->C. + Besides the cycle, the other relationships are: C->B, D->B, E->B, E->A, A->B. + """ + # Setup + metadata = MultiTableMetadata() + metadata.add_table('A') + metadata.add_column('A', 'id', sdtype='id') + metadata.add_column('A', 'fk', sdtype='id') + metadata.set_primary_key('A', 'id') + + metadata.add_table('B') + metadata.add_column('B', 'id', sdtype='id') + metadata.add_column('B', 'fk', sdtype='id') + metadata.set_primary_key('B', 'id') + + metadata.add_table('C') + metadata.add_column('C', 'id', sdtype='id') + metadata.add_column('C', 'fk', sdtype='id') + metadata.set_primary_key('C', 'id') + + metadata.add_table('D') + metadata.add_column('D', 'id', sdtype='id') + metadata.add_column('D', 'fk', sdtype='id') + metadata.set_primary_key('D', 'id') + + metadata.add_table('E') + metadata.add_column('E', 'id', sdtype='id') + metadata.add_column('E', 'fk', sdtype='id') + metadata.set_primary_key('E', 'id') + + metadata.add_relationship('C', 'B', 'id', 'fk') + metadata.add_relationship('C', 'E', 'id', 'fk') + + metadata.add_relationship('D', 'B', 'id', 'fk') + metadata.add_relationship('D', 'C', 'id', 'fk') + + metadata.add_relationship('A', 'B', 'id', 'fk') + + metadata.add_relationship('E', 'A', 'id', 'fk') + metadata.add_relationship('E', 'B', 'id', 'fk') + + # Run and Assert + error_msg = re.escape( + 'The relationships in the dataset describe a ' + "circular dependency between tables ['C', 'D', 'E']." + ) + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.add_relationship('E', 'D', 'id', 'fk') + @patch('sdv.metadata.multi_table.SingleTableMetadata') def test_add_table(self, table_metadata_mock): """Test that the method adds the table name to ``instance.tables``."""