diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 924f433ae..a5f7264e0 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -600,11 +600,10 @@ def detect_from_csvs(self, folder_name, read_csv_parameters=None): raise ValueError(f"No CSV files detected in the folder '{folder_name}'.") data = {} - read_csv_parameters = read_csv_parameters or {} for csv_file in csv_files: table_name = csv_file.stem - self.detect_table_from_csv(table_name, str(csv_file), read_csv_parameters) - data[csv_file.stem] = pd.read_csv(str(csv_file), **read_csv_parameters) + data[table_name] = _load_data_from_csv(csv_file, read_csv_parameters) + self.detect_table_from_dataframe(table_name, data[table_name]) self._detect_relationships(data) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index df4d9bd3b..2749e277f 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -2380,18 +2380,20 @@ def test_detect_table_from_csv_table_already_exists(self): with pytest.raises(InvalidMetadataError, match=error_message): metadata.detect_table_from_csv('table', 'path.csv') - def test_detect_from_csvs(self, tmp_path): + @patch('sdv.metadata.multi_table._load_data_from_csv') + def test_detect_from_csvs(self, load_data_mock, tmp_path): """Test the ``detect_from_csvs`` method. - The method should call ``detect_table_from_csv`` for each csv in the folder. + The method should load each csv in the folder and call ``detect_table_from_dataframe`` on the loaded data. 
""" # Setup instance = MultiTableMetadata() - instance.detect_table_from_csv = Mock() + instance.detect_table_from_dataframe = Mock() instance._detect_relationships = Mock() data1 = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) data2 = pd.DataFrame({'col1': [5, 6], 'col2': [7, 8]}) + load_data_mock.side_effect = [data2, data1] filepath1 = tmp_path / 'table1.csv' filepath2 = tmp_path / 'table2.csv' @@ -2406,13 +2408,18 @@ def test_detect_from_csvs(self, tmp_path): instance.detect_from_csvs(tmp_path) # Assert - expected_calls = [ - call('table1', str(filepath1), {}), - call('table2', str(filepath2), {}), + expected_calls_load_data = [ + call(filepath1, None), + call(filepath2, None), ] + load_data_mock.assert_has_calls(expected_calls_load_data, any_order=True) - instance.detect_table_from_csv.assert_has_calls(expected_calls, any_order=True) - assert instance.detect_table_from_csv.call_count == 2 + expected_detect_calls = [ + call('table1', data1), + call('table2', data2), + ] + instance.detect_table_from_dataframe.assert_has_calls(expected_detect_calls, any_order=True) + assert instance.detect_table_from_dataframe.call_count == 2 instance._detect_relationships.assert_called_once() table1 = instance._detect_relationships.call_args[0][0]['table1']