diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 7ffd3eef3..83c80ae25 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -70,9 +70,10 @@ def _set_temp_numpy_seed(self): def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) + synthesizer_parameters = {'locales': self.locales} + synthesizer_parameters.update(self._table_parameters.get(table_name, {})) self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, locales=self.locales, **synthesizer_parameters + metadata=table_metadata, **synthesizer_parameters ) self._table_synthesizers[table_name]._data_processor.table_name = table_name diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 464681938..48a32cd30 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -40,7 +40,9 @@ def test__initialize_models(self): locales = ['en_CA', 'fr_CA'] instance = Mock() instance._table_synthesizers = {} - instance._table_parameters = {'nesreca': {'default_distribution': 'gamma'}} + instance._table_parameters = { + 'nesreca': {'default_distribution': 'gamma', 'locales': ['en_US']}, + } instance.locales = locales instance.metadata = get_multi_table_metadata() @@ -57,7 +59,7 @@ def test__initialize_models(self): call( metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales, + locales=['en_US'], ), call(metadata=instance.metadata.tables['oseba'], locales=locales), call(metadata=instance.metadata.tables['upravna_enota'], locales=locales),