Skip to content

Commit

Permalink
Feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Aug 21, 2024
1 parent cf75d32 commit 2c2cee6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
5 changes: 2 additions & 3 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,10 @@ def detect_from_csvs(self, folder_name, read_csv_parameters=None):
raise ValueError(f"No CSV files detected in the folder '{folder_name}'.")

data = {}
read_csv_parameters = read_csv_parameters or {}
for csv_file in csv_files:
table_name = csv_file.stem
self.detect_table_from_csv(table_name, str(csv_file), read_csv_parameters)
data[csv_file.stem] = pd.read_csv(str(csv_file), **read_csv_parameters)
data[table_name] = _load_data_from_csv(csv_file, read_csv_parameters)
self.detect_table_from_dataframe(table_name, data[table_name])

self._detect_relationships(data)

Expand Down
21 changes: 14 additions & 7 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2380,18 +2380,20 @@ def test_detect_table_from_csv_table_already_exists(self):
with pytest.raises(InvalidMetadataError, match=error_message):
metadata.detect_table_from_csv('table', 'path.csv')

def test_detect_from_csvs(self, tmp_path):
@patch('sdv.metadata.multi_table._load_data_from_csv')
def test_detect_from_csvs(self, load_data_mock, tmp_path):
"""Test the ``detect_from_csvs`` method.
The method should load each csv in the folder with ``_load_data_from_csv`` and call ``detect_table_from_dataframe`` on the loaded data.
"""
# Setup
instance = MultiTableMetadata()
instance.detect_table_from_csv = Mock()
instance.detect_table_from_dataframe = Mock()
instance._detect_relationships = Mock()

data1 = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})
data2 = pd.DataFrame({'col1': [5, 6], 'col2': [7, 8]})
load_data_mock.side_effect = [data2, data1]

filepath1 = tmp_path / 'table1.csv'
filepath2 = tmp_path / 'table2.csv'
Expand All @@ -2406,13 +2408,18 @@ def test_detect_from_csvs(self, tmp_path):
instance.detect_from_csvs(tmp_path)

# Assert
expected_calls = [
call('table1', str(filepath1), {}),
call('table2', str(filepath2), {}),
expected_calls_load_data = [
call(filepath1, None),
call(filepath2, None),
]
load_data_mock.assert_has_calls(expected_calls_load_data, any_order=True)

instance.detect_table_from_csv.assert_has_calls(expected_calls, any_order=True)
assert instance.detect_table_from_csv.call_count == 2
expected_detect_calls = [
call('table1', data1),
call('table2', data2),
]
instance.detect_table_from_dataframe.assert_has_calls(expected_detect_calls, any_order=True)
assert instance.detect_table_from_dataframe.call_count == 2

instance._detect_relationships.assert_called_once()
table1 = instance._detect_relationships.call_args[0][0]['table1']
Expand Down

0 comments on commit 2c2cee6

Please sign in to comment.