Skip to content

Commit

Permalink
clean up _resolve_arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Sep 18, 2024
1 parent 94b2860 commit 26fa6bc
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 122 deletions.
38 changes: 0 additions & 38 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,44 +221,6 @@ def validate_table(self, data, table_name=None):

return self.validate_data({table_name: data})

def _resolve_arguments(self, arg_names, *args, **kwargs):
"""Resolves the arguments from the provided args and kwargs.
Args:
arg_names (list):
List of argument names to resolve.
"""
parameters = {}
is_single_table = len(self.tables) == 1
args_table_name = True
if is_single_table:
parameters['table_name'] = next(iter(self.tables))
if len(arg_names) != len(args):
args_table_name = False

else:
table_name = kwargs.get('table_name')
if table_name is None:
table_name = args[0]
args_table_name = False

parameters['table_name'] = table_name

parameters.update({
arg_name: arg for arg_name, arg in zip(arg_names, args[not args_table_name :])
})
for parameter_name, parameter in parameters.items():
kwargs_value = kwargs.get(parameter_name)
if kwargs_value is not None and kwargs_value != parameter:
raise ValueError(
f"Conflicting values for '{parameter_name}': '{parameter}' and '{kwargs_value}'"
)

kwargs = {key: value for key, value in kwargs.items() if key not in parameters}
parameters.update(kwargs)

return parameters

def get_column_names(self, table_name=None, **kwargs):
"""Return a list of column names that match the given metadata keyword arguments."""
table_name = self._handle_table_name(table_name)
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ def test_update_transformers_with_id_generator():
data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']})

stm = Metadata.detect_from_dataframes({'table': data})
stm.update_column('table', 'user_id', sdtype='id')
stm.set_primary_key('table', 'user_id')
stm.update_column('user_id', 'table', sdtype='id')
stm.set_primary_key('user_id', 'table')

gc = GaussianCopulaSynthesizer(stm)
custom_id = IDGenerator(starting_value=min_value_id)
Expand Down Expand Up @@ -434,7 +434,7 @@ def test_unknown_sdtype():
})

metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('table', 'unknown', sdtype='unknown')
metadata.update_column('unknown', 'table', sdtype='unknown')

synthesizer = GaussianCopulaSynthesizer(metadata)

Expand Down
81 changes: 0 additions & 81 deletions tests/unit/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,87 +644,6 @@ def test_detect_from_dataframe_raises_error_if_not_dataframe(self):
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframe(Mock())

def test__resolve_arguments_single_table(self):
    """Test the ``_resolve_arguments`` method for single table metadata.

    Every supported calling convention (positional or keyword, with or
    without the table name) must resolve to the same parameters, and a
    ``table_name`` keyword that differs from the only table must raise.
    """
    # Setup
    metadata = Mock()
    metadata.tables = ['table1']
    arg_names = ['column_name_1', 'column_name_2']
    column_kwargs = {'column_name_1': 'column1', 'column_name_2': 'column2'}
    expected_result = {
        'table_name': 'table1',
        'column_name_1': 'column1',
        'column_name_2': 'column2',
    }
    expected_error = re.escape(
        "Conflicting values for 'table_name': 'table1' and 'wrong_table'"
    )

    # Run
    results = [
        # Positional, table name omitted.
        Metadata._resolve_arguments(metadata, list(arg_names), 'column1', 'column2'),
        # Positional, table name first.
        Metadata._resolve_arguments(metadata, list(arg_names), 'table1', 'column1', 'column2'),
        # Keywords only, table name omitted.
        Metadata._resolve_arguments(metadata, list(arg_names), **column_kwargs),
        # Keywords only, table name included.
        Metadata._resolve_arguments(
            metadata, list(arg_names), table_name='table1', **column_kwargs
        ),
    ]
    with pytest.raises(ValueError, match=expected_error):
        Metadata._resolve_arguments(
            metadata, list(arg_names), table_name='wrong_table', **column_kwargs
        )

    # Assert
    for result in results:
        assert result == expected_result

def test__resolve_argument_multi_table(self):
    """Test the ``_resolve_arguments`` method for multi table metadata.

    With more than one table the table name has to be supplied, either as
    the first positional value or as the ``table_name`` keyword; both forms
    must resolve to the same parameters.
    """
    # Setup
    metadata = Mock()
    metadata.tables = ['table1', 'table2']
    arg_names = ['column_name_1', 'column_name_2']
    expected_result = {
        'table_name': 'table1',
        'column_name_1': 'column1',
        'column_name_2': 'column2',
    }

    # Run
    positional_result = Metadata._resolve_arguments(
        metadata, list(arg_names), 'table1', 'column1', 'column2'
    )
    keyword_result = Metadata._resolve_arguments(
        metadata,
        list(arg_names),
        table_name='table1',
        column_name_1='column1',
        column_name_2='column2',
    )

    # Assert
    for result in (positional_result, keyword_result):
        assert result == expected_result

params = [
('update_column', ['column_name']),
('update_columns', ['column_names']),
Expand Down

0 comments on commit 26fa6bc

Please sign in to comment.