Skip to content

Commit

Permalink
Include validation check for single table auto_assign_transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed May 21, 2024
1 parent d70f3bb commit 6b997ae
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def auto_assign_transformers(self, data):
data (pandas.DataFrame):
The raw data (before any transformations) that will be used to fit the model.
"""
self.metadata._validate_metadata_matches_data(data)
self._data_processor.prepare_for_fitting(data)

def get_transformers(self):
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,28 @@ def test_auto_assign_transformers_missing_table(self):
with pytest.raises(ValueError, match=err_msg):
instance.auto_assign_transformers(data)

def test_auto_assign_transformers_missing_column(self):
    """Test that an error is raised when the data columns do not match the metadata.

    The ``nesreca`` table passed in only has a ``col1`` column, which the expected
    message reports as absent from the metadata, while the metadata columns
    ``id_nesreca``, ``nesreca_val`` and ``upravna_enota`` are reported as absent
    from the data. ``auto_assign_transformers`` must raise an ``InvalidDataError``.
    """
    # Setup
    metadata = get_multi_table_metadata()
    synthesizer = HMASynthesizer(metadata)
    data = {
        'nesreca': pd.DataFrame({'col1': [1, 2]}),
        'oseba': pd.DataFrame({'col2': [1, 2]}),
    }

    # ``pytest.raises(match=...)`` does a regex search, so the literal message
    # (which contains brackets) has to be escaped.
    # NOTE(review): only the ``nesreca`` mismatch is expected in the message;
    # presumably validation stops at the first failing table -- confirm.
    expected_message = re.escape(
        'The provided data does not match the metadata:\n'
        "The columns ['col1'] are not present in the metadata.\n\n"
        "The metadata columns ['id_nesreca', 'nesreca_val', 'upravna_enota'] "
        'are not present in the data.'
    )

    # Run and Assert
    with pytest.raises(InvalidDataError, match=expected_message):
        synthesizer.auto_assign_transformers(data)

def test_get_transformers(self):
"""Test that each table of the data calls its single table get_transformers method."""
# Setup
Expand Down
24 changes: 23 additions & 1 deletion tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from sdv import version
from sdv.constraints.errors import AggregateConstraintsError
from sdv.errors import ConstraintsNotMetError, SamplingError, SynthesizerInputError, VersionError
from sdv.errors import (
ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError, VersionError)
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sampling.tabular import Condition
from sdv.single_table import (
Expand Down Expand Up @@ -218,6 +219,27 @@ def test_auto_assign_transformers(self):
# Assert
instance._data_processor.prepare_for_fitting.assert_called_once_with(data)

def test_auto_assign_transformers_with_invalid_data(self):
    """Test that an error is raised when the data does not match the metadata.

    The metadata declares a single categorical column ``a`` while the data only
    contains a column ``b``, so ``auto_assign_transformers`` must raise an
    ``InvalidDataError`` describing both mismatches.
    """
    # Setup
    metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            'a': {'sdtype': 'categorical'},
        }
    })
    synthesizer = GaussianCopulaSynthesizer(metadata)

    # Input data that does not match the metadata. The expected message only
    # mentions column names, so fixed values are used instead of random ones
    # to keep the test deterministic.
    data = pd.DataFrame({'b': ['M', 'F'] * 5})

    # ``pytest.raises(match=...)`` does a regex search, so the literal message
    # (which contains brackets) has to be escaped.
    expected_message = re.escape(
        'The provided data does not match the metadata:\n'
        "The columns ['b'] are not present in the metadata.\n\n"
        "The metadata columns ['a'] are not present in the data."
    )

    # Run and Assert
    with pytest.raises(InvalidDataError, match=expected_message):
        synthesizer.auto_assign_transformers(data)

def test_get_transformers(self):
"""Test that this returns the field transformers from the ``HyperTransformer``."""
# Setup
Expand Down

0 comments on commit 6b997ae

Please sign in to comment.