Skip to content

Commit

Permalink
Improve metadata detection (#1529)
Browse files Browse the repository at this point in the history
* def unit test determine methods

* add unit test detect_column

* cleaning

* add integration tests

* import

* remove 'u' kind

* update condition and test

* address comments

* address comments

* rebase

* warning datetime detection

* silence pandas warning

* subsample datetime detection

* address comments

* rebase
  • Loading branch information
R-Palazzo authored and amontanez24 committed Sep 27, 2023
1 parent 0bc7508 commit 53bb4a3
Show file tree
Hide file tree
Showing 7 changed files with 736 additions and 370 deletions.
100 changes: 91 additions & 9 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import (
cast_to_iterable, format_invalid_values_string, is_boolean_type, is_datetime_type,
is_numerical_type, load_data_from_csv, validate_datetime_format)
cast_to_iterable, format_invalid_values_string, get_datetime_format, is_boolean_type,
is_datetime_type, is_numerical_type, load_data_from_csv, validate_datetime_format)

LOGGER = logging.getLogger(__name__)

Expand All @@ -35,10 +35,7 @@ class SingleTableMetadata:
}

# Maps a numpy dtype ``kind`` code directly to an sdtype. Only kinds that need
# no further inspection belong here: integer ('i'), float ('f') and object
# ('O') columns are resolved by the ``_determine_sdtype_for_*`` heuristics in
# ``_detect_columns``, so listing them here would bypass those heuristics.
_DTYPES_TO_SDTYPES = {
    'b': 'categorical',
    'M': 'datetime',
}

Expand Down Expand Up @@ -252,11 +249,96 @@ def to_dict(self):

return deepcopy(metadata)

def _determine_sdtype_for_numbers(self, data):
    """Determine the sdtype for a numerical column.

    Columns with five rows or fewer are always reported as ``'numerical'``.
    For longer columns, the non-null values are inspected:

    * whole, non-negative values with at most ``round(len(data) / 10)``
      distinct values are reported as ``'categorical'``;
    * otherwise, whole values where every row is distinct are reported
      as ``'id'``;
    * anything else stays ``'numerical'``.

    Args:
        data (pandas.Series):
            The data to be analyzed.

    Returns:
        str:
            One of ``'numerical'``, ``'categorical'`` or ``'id'``.
    """
    sdtype = 'numerical'
    if len(data) > 5:
        # Null entries are excluded from the value checks below, but they
        # still count towards ``len(data)``.
        is_not_null = ~data.isna()
        whole_values = (data == data.round()).loc[is_not_null].all()
        positive_values = (data >= 0).loc[is_not_null].all()

        # ``nunique`` ignores nulls, so a column containing nulls can never
        # satisfy the ``unique_values == len(data)`` test used for ids.
        unique_values = data.nunique()
        unique_lt_categorical_threshold = unique_values <= round(len(data) / 10)

        if whole_values and positive_values and unique_lt_categorical_threshold:
            sdtype = 'categorical'
        elif unique_values == len(data) and whole_values:
            sdtype = 'id'

    return sdtype

def _determine_sdtype_for_objects(self, data):
    """Determine the sdtype for an object column.

    The column is first probed as a datetime: a format is guessed with
    ``get_datetime_format`` and the values are parsed with ``pd.to_datetime``.
    If that fails (or there are no non-null values to parse), the sdtype is
    chosen from the number of distinct values:

    * five rows or fewer: ``'categorical'``;
    * every row distinct: ``'id'``;
    * at most ``round(len(data) / 5)`` distinct values: ``'categorical'``;
    * anything else: ``'unknown'``.

    Args:
        data (pandas.Series):
            The data to be analyzed.

    Returns:
        str:
            One of ``'datetime'``, ``'categorical'``, ``'id'`` or ``'unknown'``.
    """
    sdtype = None

    # ``pd.to_datetime`` does not raise when every value is null, so an
    # empty or all-null column would be mislabelled as a datetime. Only
    # attempt the datetime probe when there is something to parse.
    candidates = data.dropna()
    if len(candidates) > 0:
        try:
            with warnings.catch_warnings():
                # Silence pandas' UserWarning about falling back to
                # per-element parsing when no format could be guessed.
                warnings.simplefilter('ignore', category=UserWarning)

                # Parsing is expensive; a subsample is enough to decide.
                if len(candidates) > 10000:
                    data_test = candidates.sample(10000)
                else:
                    data_test = candidates

                datetime_format = get_datetime_format(data_test)
                pd.to_datetime(data_test, format=datetime_format, errors='raise')

            sdtype = 'datetime'

        except Exception:
            # Deliberately broad: any failure while guessing or parsing
            # simply means the column is not treated as a datetime.
            sdtype = None

    if sdtype is None:
        if len(data) <= 5:
            sdtype = 'categorical'
        else:
            unique_values = data.nunique()
            if unique_values == len(data):
                sdtype = 'id'
            elif unique_values <= round(len(data) / 5):
                sdtype = 'categorical'
            else:
                sdtype = 'unknown'

    return sdtype

def _detect_columns(self, data):
    """Detect columns sdtype in the data.

    For every column, the dtype kind of its non-null values selects the
    detection strategy: kinds listed in ``_DTYPES_TO_SDTYPES`` map directly
    to an sdtype, integer and float columns go through
    ``_determine_sdtype_for_numbers`` and object columns go through
    ``_determine_sdtype_for_objects``. The result is stored in
    ``self.columns`` under the column name.

    Args:
        data (pandas.DataFrame):
            The data to be analyzed.

    Raises:
        InvalidMetadataError:
            If a column has a dtype kind that none of the strategies handle.
    """
    for field in data:
        column_data = data[field]
        # Nulls are dropped before inferring the dtype so that they do not
        # influence the inferred kind; the heuristics still get the full column.
        clean_data = column_data.dropna()
        dtype = clean_data.infer_objects().dtype.kind

        sdtype = None
        if dtype in self._DTYPES_TO_SDTYPES:
            sdtype = self._DTYPES_TO_SDTYPES[dtype]
        elif dtype in ['i', 'f']:
            sdtype = self._determine_sdtype_for_numbers(column_data)
        elif dtype == 'O':
            sdtype = self._determine_sdtype_for_objects(column_data)

        if sdtype is None:
            raise InvalidMetadataError(
                f"Unsupported data type for column '{field}' (kind: {dtype}). "
                "The valid data types are: 'object', 'int', 'float', 'datetime', 'bool'."
            )

        column_dict = {'sdtype': sdtype}

        if sdtype == 'unknown':
            # Columns that could not be classified are flagged as PII.
            column_dict['pii'] = True
        elif sdtype == 'datetime' and dtype == 'O':
            # Datetimes stored as objects (e.g. strings) need an explicit
            # format; it is guessed from the first 100 rows.
            datetime_format = get_datetime_format(column_data.iloc[:100])
            column_dict['datetime_format'] = datetime_format

        self.columns[field] = deepcopy(column_dict)

def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.
Expand Down
1 change: 1 addition & 0 deletions sdv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_datetime_format(value):

value = value[~value.isna()]
value = value.astype(str).to_numpy()

return _guess_datetime_format_for_array(value)


Expand Down
Loading

0 comments on commit 53bb4a3

Please sign in to comment.