Skip to content

Commit

Permalink
Fix constraints evaluating int or float instances
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Sep 6, 2024
1 parent 1934059 commit 633447e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install invoke .[test]
- name: Run integration tests
env:
PYDRIVE_CREDENTIALS: ${{ secrets.PYDRIVE_CREDENTIALS }}

run: |
invoke integration
invoke benchmark-dtypes
9 changes: 9 additions & 0 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path

import pandas as pd
from pandas.api.types import is_float, is_integer
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from rdt.transformers.utils import _GENERATORS

Expand Down Expand Up @@ -439,3 +440,11 @@ def get_possible_chars(regex, num_subpatterns=None):
possible_chars += _get_chars_for_option(option, params)

return possible_chars


def _is_numerical(value):
"""Determine if the input is a numerical type or not."""
try:
return is_integer(value) or is_float(value)
except Exception:
return False
12 changes: 5 additions & 7 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import numpy as np
import pandas as pd

from sdv._utils import _convert_to_timedelta, _create_unique_name, _is_datetime_type
from sdv._utils import _convert_to_timedelta, _create_unique_name, _is_datetime_type, _is_numerical
from sdv.constraints.base import Constraint
from sdv.constraints.errors import (
AggregateConstraintsError,
Expand Down Expand Up @@ -604,7 +604,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs):
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
value = kwargs.get('value')
if sdtype == 'numerical':
if not isinstance(value, (int, float)):
if not _is_numerical(value):
raise ConstraintMetadataError("'value' must be an int or float.")

elif sdtype == 'datetime':
Expand Down Expand Up @@ -632,7 +632,7 @@ def _validate_init_inputs(column_name, value, relation):
if relation not in ['>', '>=', '<', '<=']:
raise ValueError('`relation` must be one of the following: `>`, `>=`, `<`, `<=`')

if not (isinstance(value, (int, float)) or value_is_datetime):
if not (_is_numerical(value) or value_is_datetime):
raise ValueError('`value` must be a number or a string that represents a datetime.')

if value_is_datetime and not isinstance(value, str):
Expand Down Expand Up @@ -1071,9 +1071,7 @@ def _validate_init_inputs(low_value, high_value):
if values_are_datetimes and not values_are_strings:
raise ValueError('Datetime must be represented as a string.')

values_are_numerical = bool(
isinstance(low_value, (int, float)) and isinstance(high_value, (int, float))
)
values_are_numerical = bool(_is_numerical(low_value) and _is_numerical(high_value))
if not (values_are_numerical or values_are_datetimes):
raise ValueError(
'``low_value`` and ``high_value`` must be a number or a string that '
Expand All @@ -1092,7 +1090,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high_value = kwargs.get('high_value')
low_value = kwargs.get('low_value')
if sdtype == 'numerical':
if not isinstance(high_value, (int, float)) or not isinstance(low_value, (int, float)):
if not _is_numerical(high_value) or not _is_numerical(low_value):
raise ConstraintMetadataError(
"Both 'high_value' and 'low_value' must be ints or floats"
)
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_get_datetime_format,
_get_root_tables,
_is_datetime_type,
_is_numerical,
_validate_foreign_keys_not_null,
check_sdv_versions_and_warn,
check_synthesizer_version,
Expand Down Expand Up @@ -713,3 +714,33 @@ def test_get_possible_chars():
nums = [str(i) for i in range(10)]
lowercase_letters = list(string.ascii_lowercase)
assert possible_chars == prefix + nums + ['_'] + lowercase_letters


def test__is_numerical():
    """Test that ``_is_numerical`` is truthy for a NumPy integer and for ``np.nan``."""
    # Setup
    numerical_inputs = (np.int16(10), np.nan)

    # Run
    integer_outcome, nan_outcome = [_is_numerical(candidate) for candidate in numerical_inputs]

    # Assert
    assert integer_outcome
    assert nan_outcome


def test__is_numerical_string():
    """Test that ``_is_numerical`` returns ``False`` for non-numerical inputs.

    The inputs covered are a plain string and a datetime produced by ``pd.to_datetime``.
    """
    # Setup
    non_numerical_inputs = ('None', pd.to_datetime('2012-01-01'))

    # Run
    string_outcome, datetime_outcome = [
        _is_numerical(candidate) for candidate in non_numerical_inputs
    ]

    # Assert
    assert string_outcome is False
    assert datetime_outcome is False

0 comments on commit 633447e

Please sign in to comment.