
Add detect_from_csvs and detect_from_dataframes methods to MultiTableMetadata (#1533)

* define detection methods + tests

* address comments

* use Pathlib

* modify test to use tmp_path

* test detect_table_from_csv

* use load_data_from_csv
R-Palazzo committed Aug 14, 2023
1 parent 0433374 commit 72f7e1f
Showing 3 changed files with 275 additions and 5 deletions.
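
For context, here is a minimal usage sketch of the two detection methods this commit introduces. It is not part of the diff: the toy hotels/guests DataFrames are hypothetical and the commented-out folder path is a placeholder; only the signatures (detect_from_dataframes(data=...) and detect_from_csvs(folder_name=...)) come from the changes below.

import pandas as pd

from sdv.metadata import MultiTableMetadata

# Hypothetical toy tables; any dict mapping table names to DataFrames works.
hotels = pd.DataFrame({'hotel_id': ['h1', 'h2'], 'rating': [4.5, 3.0]})
guests = pd.DataFrame({'guest_email': ['a@example.com'], 'hotel_id': ['h1']})

metadata = MultiTableMetadata()
metadata.detect_from_dataframes(data={'hotels': hotels, 'guests': guests})

# Alternatively, point at a folder of CSV files (placeholder path);
# each file stem becomes a table name.
# metadata.detect_from_csvs(folder_name='path/to/csvs')

print(metadata.to_dict())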
42 changes: 40 additions & 2 deletions sdv/metadata/multi_table.py
@@ -5,14 +5,17 @@
import warnings
from collections import defaultdict
from copy import deepcopy
from pathlib import Path

import pandas as pd

from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable
from sdv.utils import cast_to_iterable, load_data_from_csv

LOGGER = logging.getLogger(__name__)

@@ -344,6 +347,19 @@ def detect_table_from_dataframe(self, table_name, data):
self.tables[table_name] = table
self._log_detected_table(table)

def detect_from_dataframes(self, data):
"""Detect the metadata for all tables in a dictionary of dataframes.
Args:
data (dict):
Dictionary of table names to dataframes.
"""
if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')

for table_name, dataframe in data.items():
self.detect_table_from_dataframe(table_name, dataframe)

def detect_table_from_csv(self, table_name, filepath):
"""Detect the metadata for a table from a csv file.
@@ -355,11 +371,33 @@ def detect_table_from_csv(self, table_name, filepath):
"""
self._validate_table_not_detected(table_name)
table = SingleTableMetadata()
data = table._load_data_from_csv(filepath)
data = load_data_from_csv(filepath)
table._detect_columns(data)
self.tables[table_name] = table
self._log_detected_table(table)

def detect_from_csvs(self, folder_name):
"""Detect the metadata for all tables in a folder of csv files.
Args:
folder_name (str):
Name of the folder to detect the metadata from.
"""
folder_path = Path(folder_name)

if folder_path.is_dir():
csv_files = list(folder_path.rglob('*.csv'))
else:
raise ValueError(f"The folder '{folder_name}' does not exist.")

if not csv_files:
raise ValueError(f"No CSV files detected in the folder '{folder_name}'.")

for csv_file in csv_files:
table_name = csv_file.stem
self.detect_table_from_csv(table_name, str(csv_file))

def set_primary_key(self, table_name, column_name):
"""Set the primary key of a table.
136 changes: 136 additions & 0 deletions tests/integration/metadata/test_multi_table.py
@@ -2,6 +2,7 @@

import json

from sdv.datasets.demo import download_demo
from sdv.metadata import MultiTableMetadata


@@ -132,3 +133,138 @@ def test_upgrade_metadata(tmp_path):
assert new_metadata['tables'] == expected_metadata['tables']
for relationship in new_metadata['relationships']:
assert relationship in expected_metadata['relationships']


def test_detect_from_dataframes():
"""Test the ``detect_from_dataframes`` method."""
# Setup
real_data, _ = download_demo(
modality='multi_table',
dataset_name='fake_hotels'
)

metadata = MultiTableMetadata()

# Run
metadata.detect_from_dataframes(real_data)

# Assert
expected_metadata = {
'tables': {
'hotels': {
'columns': {
'hotel_id': {'sdtype': 'categorical'},
'city': {'sdtype': 'categorical'},
'state': {'sdtype': 'categorical'},
'rating': {'sdtype': 'numerical'},
'classification': {'sdtype': 'categorical'}
}
},
'guests': {
'columns': {
'guest_email': {'sdtype': 'categorical'},
'hotel_id': {'sdtype': 'categorical'},
'has_rewards': {'sdtype': 'boolean'},
'room_type': {'sdtype': 'categorical'},
'amenities_fee': {'sdtype': 'numerical'},
'checkin_date': {'sdtype': 'categorical'},
'checkout_date': {'sdtype': 'categorical'},
'room_rate': {'sdtype': 'numerical'},
'billing_address': {'sdtype': 'categorical'},
'credit_card_number': {'sdtype': 'numerical'}
}
}
},
'relationships': [],
'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'
}

assert metadata.to_dict() == expected_metadata


def test_detect_from_csvs(tmp_path):
"""Test the ``detect_from_csvs`` method."""
# Setup
real_data, _ = download_demo(
modality='multi_table',
dataset_name='fake_hotels'
)

metadata = MultiTableMetadata()

for table_name, dataframe in real_data.items():
csv_path = tmp_path / f'{table_name}.csv'
dataframe.to_csv(csv_path, index=False)

# Run
metadata.detect_from_csvs(folder_name=tmp_path)

# Assert
expected_metadata = {
'tables': {
'hotels': {
'columns': {
'hotel_id': {'sdtype': 'categorical'},
'city': {'sdtype': 'categorical'},
'state': {'sdtype': 'categorical'},
'rating': {'sdtype': 'numerical'},
'classification': {'sdtype': 'categorical'}
}
},
'guests': {
'columns': {
'guest_email': {'sdtype': 'categorical'},
'hotel_id': {'sdtype': 'categorical'},
'has_rewards': {'sdtype': 'boolean'},
'room_type': {'sdtype': 'categorical'},
'amenities_fee': {'sdtype': 'numerical'},
'checkin_date': {'sdtype': 'categorical'},
'checkout_date': {'sdtype': 'categorical'},
'room_rate': {'sdtype': 'numerical'},
'billing_address': {'sdtype': 'categorical'},
'credit_card_number': {'sdtype': 'numerical'}
}
}
},
'relationships': [],
'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'
}

assert metadata.to_dict() == expected_metadata


def test_detect_table_from_csv(tmp_path):
"""Test the ``detect_table_from_csv`` method."""
# Setup
real_data, _ = download_demo(
modality='multi_table',
dataset_name='fake_hotels'
)

metadata = MultiTableMetadata()

for table_name, dataframe in real_data.items():
csv_path = tmp_path / f'{table_name}.csv'
dataframe.to_csv(csv_path, index=False)

# Run
metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv')

# Assert
expected_metadata = {
'tables': {
'hotels': {
'columns': {
'hotel_id': {'sdtype': 'categorical'},
'city': {'sdtype': 'categorical'},
'state': {'sdtype': 'categorical'},
'rating': {'sdtype': 'numerical'},
'classification': {'sdtype': 'categorical'}
}
}
},
'relationships': [],
'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'
}

assert metadata.to_dict() == expected_metadata
102 changes: 99 additions & 3 deletions tests/unit/metadata/test_multi_table.py
@@ -1589,7 +1589,8 @@ def test_update_column_table_does_not_exist(self):

@patch('sdv.metadata.multi_table.LOGGER')
@patch('sdv.metadata.multi_table.SingleTableMetadata')
def test_detect_table_from_csv(self, single_table_mock, log_mock):
@patch('sdv.metadata.multi_table.load_data_from_csv')
def test_detect_table_from_csv(self, load_csv_mock, single_table_mock, log_mock):
"""Test the ``detect_table_from_csv`` method.
If the table does not already exist, a ``SingleTableMetadata`` instance
@@ -1604,7 +1605,7 @@ def test_detect_table_from_csv(self, single_table_mock, log_mock):
# Setup
metadata = MultiTableMetadata()
fake_data = Mock()
single_table_mock.return_value._load_data_from_csv.return_value = fake_data
load_csv_mock.return_value = fake_data
single_table_mock.return_value.to_dict.return_value = {
'columns': {'a': {'sdtype': 'numerical'}}
}
@@ -1613,7 +1614,7 @@ def test_detect_table_from_csv(self, single_table_mock, log_mock):
metadata.detect_table_from_csv('table', 'path.csv')

# Assert
single_table_mock.return_value._load_data_from_csv.assert_called_once_with('path.csv')
load_csv_mock.assert_called_once_with('path.csv')
single_table_mock.return_value._detect_columns.assert_called_once_with(fake_data)
assert metadata.tables == {'table': single_table_mock.return_value}

@@ -1656,6 +1657,59 @@ def test_detect_table_from_csv_table_already_exists(self):
with pytest.raises(InvalidMetadataError, match=error_message):
metadata.detect_table_from_csv('table', 'path.csv')

def test_detect_from_csvs(self, tmp_path):
"""Test the ``detect_from_csvs`` method.
The method should call ``detect_table_from_csv`` for each csv in the folder.
"""
# Setup
instance = MultiTableMetadata()
instance.detect_table_from_csv = Mock()

data1 = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})
data2 = pd.DataFrame({'col1': [5, 6], 'col2': [7, 8]})

filepath1 = tmp_path / 'table1.csv'
filepath2 = tmp_path / 'table2.csv'
data1.to_csv(filepath1, index=False)
data2.to_csv(filepath2, index=False)

json_filepath = tmp_path / 'not_csv.json'
with open(json_filepath, 'w') as json_file:
json_file.write('{"key": "value"}')

# Run
instance.detect_from_csvs(tmp_path)

# Assert
expected_calls = [
call('table1', str(filepath1)),
call('table2', str(filepath2))
]

instance.detect_table_from_csv.assert_has_calls(expected_calls, any_order=True)
assert instance.detect_table_from_csv.call_count == 2

def test_detect_from_csvs_no_csv(self, tmp_path):
"""Test the ``detect_from_csvs`` method with no csv file in the folder."""
# Setup
instance = MultiTableMetadata()

json_filepath = tmp_path / 'not_csv.json'
with open(json_filepath, 'w') as json_file:
json_file.write('{"key": "value"}')

# Run and Assert
expected_message = re.escape("No CSV files detected in the folder '{}'.".format(tmp_path))
with pytest.raises(ValueError, match=expected_message):
instance.detect_from_csvs(tmp_path)

expected_message_folder = re.escape(
"The folder '{}' does not exist.".format('not_a_folder')
)
with pytest.raises(ValueError, match=expected_message_folder):
instance.detect_from_csvs('not_a_folder')

@patch('sdv.metadata.multi_table.LOGGER')
@patch('sdv.metadata.multi_table.SingleTableMetadata')
def test_detect_table_from_dataframe(self, single_table_mock, log_mock):
@@ -1723,6 +1777,48 @@ def test_detect_table_from_dataframe_table_already_exists(self):
with pytest.raises(InvalidMetadataError, match=error_message):
metadata.detect_table_from_dataframe('table', pd.DataFrame())

def test_detect_from_dataframes(self):
"""Test ``detect_from_dataframes``.
Expected to call ``detect_table_from_dataframe`` for each table name and dataframe
in the input.
"""
# Setup
metadata = MultiTableMetadata()
metadata.detect_table_from_dataframe = Mock()

guests_table = pd.DataFrame()
hotels_table = pd.DataFrame()

# Run
metadata.detect_from_dataframes(
data={
'guests': guests_table,
'hotels': hotels_table
}
)

# Assert
metadata.detect_table_from_dataframe.assert_any_call('guests', guests_table)
metadata.detect_table_from_dataframe.assert_any_call('hotels', hotels_table)

def test_detect_from_dataframes_no_dataframes(self):
"""Test ``detect_from_dataframes`` with no dataframes in the input.
Expected to raise an error.
"""
# Setup
metadata = MultiTableMetadata()

# Run and Assert
expected_message = 'The provided dictionary must contain only pandas DataFrame objects.'

with pytest.raises(ValueError, match=expected_message):
metadata.detect_from_dataframes(data={})

with pytest.raises(ValueError, match=expected_message):
metadata.detect_from_dataframes(data={'a': 1})

def test__validate_table_exists(self):
"""Test ``_validate_table_exists``.
