diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index fc3942a26..537d8bba2 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -70,6 +70,10 @@ def _set_metadata_dict(self, metadata, single_table_name=None): else: if single_table_name is None: single_table_name = self.DEFAULT_SINGLE_TABLE_NAME + warnings.warn( + 'No table name was provided to metadata containing only one table. ' + f'Assigning name: {single_table_name}' + ) self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata) def _get_single_table_name(self): diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 7c6e2a17e..38224da59 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -409,6 +409,8 @@ def update_columns(self, table_name, column_names, **kwargs): **kwargs: Any key word arguments that describe metadata for the columns. """ + if not isinstance(column_names, list): + raise InvalidMetadataError('Please pass in a list to column_names arg.') self._validate_table_exists(table_name) table = self.tables.get(table_name) table.update_columns(column_names, **kwargs) @@ -832,8 +834,8 @@ def validate_data(self, data): * all foreign keys belong to a primay key Args: - data (pd.DataFrame): - The data to validate. + data (dict): + A dictionary of table names to pd.DataFrames. Raises: InvalidDataError: @@ -843,6 +845,9 @@ def validate_data(self, data): A warning is being raised if ``datetime_format`` is missing from a column represented as ``object`` in the dataframe and its sdtype is ``datetime``. """ + if not isinstance(data, dict): + raise InvalidMetadataError('Please pass in a dictionary mapping tables to dataframes.') + errors = [] errors += self._validate_missing_tables(data) errors += self._validate_all_tables(data) @@ -880,7 +885,7 @@ def get_column_names(self, table_name, **kwargs): Args: table_name (str): - The name of the table to get column names for.s + The name of the table to get column names for. **kwargs: Metadata keywords to filter on, for example sdtype='id' or pii=True. diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index f722826dd..3978bddf1 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -77,7 +77,8 @@ def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): synthesizer_parameters = self._table_parameters.get(table_name, {}) - metadata = Metadata.load_from_dict(table_metadata.to_dict()) + metadata_dict = {'tables': {table_name: table_metadata.to_dict()}} + metadata = Metadata.load_from_dict(metadata_dict) self._table_synthesizers[table_name] = self._synthesizer( metadata=metadata, locales=self.locales, **synthesizer_parameters ) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index a43816100..e3634d503 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -3064,3 +3064,42 @@ def test_anonymize(self, mock_load): 'parent_primary_key': 'col1', 'child_foreign_key': 'col2', } + + def test_update_columns_no_list_error(self): + """Test that ``update_columns`` only takes in list and that an error is thrown.""" + # Setup + metadata = MultiTableMetadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='numerical') + + error_msg = re.escape('Please pass in a list to column_names arg.') + # Run and Assert + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.update_columns('table', 'col1', sdtype='categorical') + + def test_validate_data_without_dict(self): + """Test that ``validate_data`` only takes in dict and that an error is thrown otherwise.""" + # Setup + metadata = MultiTableMetadata.load_from_dict({ + 'tables': { + 'table_1': { + 'columns': { + 'col_1': {'sdtype': 'numerical'}, + 'col_2': {'sdtype': 'categorical'}, + 'latitude': {'sdtype': 'latitude'}, + 'longitude': {'sdtype': 'longitude'}, + } + } + } + }) + data = pd.DataFrame({ + 'col_1': [1, 2, 3], + 'col_2': ['a', 'b', 'c'], + 'latitude': [1, 2, 3], + 'longitude': [1, 2, 3], + }) + error_msg = re.escape('Please pass in a dictionary mapping tables to dataframes.') + + # Run and Assert + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.validate_data(data)