Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SingleTableMetadata visualization #1535

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import visualize_graph
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -561,29 +562,15 @@ def visualize(self, show_table_details='full', show_relationship_labels=True,
edges = []
if show_table_details == 'full':
for table_name, table_meta in self.tables.items():
column_dict = table_meta.columns.items()
columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict]
nodes[table_name] = {
'columns': r'\l'.join(columns),
'columns': create_columns_node(table_meta.columns),
'primary_key': f'Primary key: {table_meta.primary_key}'
}

elif show_table_details == 'summarized':
default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
for table_name, table_meta in self.tables.items():
count_dict = defaultdict(int)
for column_name, meta in table_meta.columns.items():
sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype']
count_dict[sdtype] += 1

count_dict = dict(sorted(count_dict.items()))
columns = ['Columns']
columns.extend([
fr'    • {sdtype} : {count}'
for sdtype, count in count_dict.items()
])
nodes[table_name] = {
'columns': r'\l'.join(columns),
'columns': create_summarized_columns_node(table_meta.columns),
'primary_key': f'Primary key: {table_meta.primary_key}'
}

Expand All @@ -610,7 +597,10 @@ def visualize(self, show_table_details='full', show_relationship_labels=True,
if show_table_details:
foreign_keys = r'\l'.join(info.get('foreign_keys', []))
keys = r'\l'.join([info['primary_key'], foreign_keys])
label = fr"{{{table}|{info['columns']}\l|{keys}\l}}"
if foreign_keys:
label = fr"{{{table}|{info['columns']}\l|{keys}\l}}"
else:
label = fr"{{{table}|{info['columns']}\l|{keys}}}"

else:
label = f'{table}'
Expand Down
40 changes: 40 additions & 0 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable, load_data_from_csv

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -477,6 +479,44 @@ def validate(self):
+ '\n'.join([str(e) for e in errors])
)

def visualize(self, show_table_details='full', output_filepath=None):
"""Create a visualization of the single-table dataset.

Args:
show_table_details (str):
If 'full', the column names, primary, alternate and sequence keys are all
shown. If 'summarized', primary, alternate and sequence keys are shown and a
count of the different sdtypes. Defaults to 'full'.
output_filepath (str):
Full path of where to savve the visualization. If None, the visualization is not
saved. Defaults to None.

Returns:
``graphviz.Digraph`` object.
"""
if show_table_details not in ('full', 'summarized'):
raise ValueError("'show_table_details' should be 'full' or 'summarized'.")

if show_table_details == 'full':
node = fr'{create_columns_node(self.columns)}\l'

elif show_table_details == 'summarized':
node = fr'{create_summarized_columns_node(self.columns)}\l'

if self.primary_key:
node = fr'{node}|Primary key: {self.primary_key}\l'

if self.sequence_key:
node = fr'{node}|Sequence key: {self.sequence_key}\l'

if self.alternate_keys:
alternate_keys = [fr'    • {key}\l' for key in self.alternate_keys]
alternate_keys = ''.join(alternate_keys)
node = fr'{node}|Alternate keys:\l {alternate_keys}'

node = {'': f'{{{node}}}'}
return visualize_graph(node, [], output_filepath)

def save_to_json(self, filepath):
"""Save the current ``SingleTableMetadata`` in to a ``json`` file.

Expand Down
49 changes: 49 additions & 0 deletions sdv/metadata/visualization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,58 @@
"""Functions for Metadata visualization."""

import warnings
from collections import defaultdict

import graphviz

DEFAULT_SDTYPES = ['id', 'numerical', 'categorical', 'datetime', 'boolean']


def create_columns_node(columns):
"""Convert columns into text for ``graphviz`` node.

Args:
columns (dict):
A dict mapping the column names with a dictionary containing the ``sdtype`` of the
column name.

Returns:
str:
String representing the node that will be printed for the given columns.
"""
columns = [
fr"{name} : {meta.get('sdtype')}"
for name, meta in columns.items()
]
return r'\l'.join(columns)


def create_summarized_columns_node(columns):
"""Convert columns into summarized text for ``graphviz`` node.

Args:
columns (dict):
A dict mapping the column names with a dictionary containing the ``sdtype`` of the
column name.

Returns:
str:
String representing the node that will be printed for the given columns.
"""
count_dict = defaultdict(int)
for column_name, meta in columns.items():
sdtype = 'other' if meta['sdtype'] not in DEFAULT_SDTYPES else meta['sdtype']
count_dict[sdtype] += 1

count_dict = dict(sorted(count_dict.items()))
columns = ['Columns']
columns.extend([
fr'    • {sdtype} : {count}'
for sdtype, count in count_dict.items()
])

return r'\l'.join(columns)


def _get_graphviz_extension(filepath):
if filepath:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ def test_visualize_show_relationship_and_details(self, visualize_graph_mock):
'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
)
expected_nodes = {
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
'payments': expected_payments_label,
'sessions': expected_sessions_label,
'transactions': expected_transactions_label
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def test_visualize_show_relationship_and_details_summarized(self, visualize_grap
)
expected_user_label = (
'{users|Columns\\l    • categorical : 1\\l    • id : '
'1\\l|Primary key: id\\l\\l}'
'1\\l|Primary key: id\\l}'
)
expected_nodes = {
'users': expected_user_label,
Expand Down Expand Up @@ -1339,7 +1339,7 @@ def test_visualize_show_relationship_and_details_warning(self, visualize_graph_m
'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
)
expected_nodes = {
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
'payments': expected_payments_label,
'sessions': expected_sessions_label,
'transactions': expected_transactions_label
Expand Down Expand Up @@ -1479,7 +1479,7 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock):
'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
)
expected_nodes = {
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
'payments': expected_payments_label,
'sessions': expected_sessions_label,
'transactions': expected_transactions_label
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,93 @@ def test___repr__(self, mock_json):
mock_json.dumps.assert_called_once_with(instance.to_dict(), indent=4)
assert res == mock_json.dumps.return_value

def test_visualize_with_invalid_input(self):
"""Test that a ``ValueError`` is being raised when ``show_table_details`` is incorrect."""
# Setup
instance = SingleTableMetadata()

# Run and Assert
error_msg = "'show_table_details' should be 'full' or 'summarized'."
with pytest.raises(ValueError, match=error_msg):
instance.visualize(None)

@patch('sdv.metadata.single_table.visualize_graph')
def test_visualize_metadata_full(self, mock_visualize_graph):
"""Test the ``visualize`` method when ``show_table_details`` is 'full'."""
# Setup
instance = SingleTableMetadata()
instance.columns = {
'name': {'sdtype': 'categorical'},
'age': {'sdtype': 'numerical'},
'start_date': {'sdtype': 'datetime'},
'phrase': {'sdtype': 'id'},
}

# Run
result = instance.visualize('full')

# Assert
assert result == mock_visualize_graph.return_value
expected_node = {
'': '{name : categorical\\lage : numerical\\lstart_date : datetime\\lphrase : id\\l}'
}
assert mock_visualize_graph.called_once_with(call(expected_node, [], None))

@patch('sdv.metadata.single_table.visualize_graph')
def test_visualize_metadata_summarized(self, mock_visualize_graph):
"""Test the ``visualize`` method when ``show_table_details`` is 'summarized'."""
# Setup
instance = SingleTableMetadata()
instance.columns = {
'name': {'sdtype': 'categorical'},
'age': {'sdtype': 'numerical'},
'start_date': {'sdtype': 'datetime'},
'phrase': {'sdtype': 'id'},
}

# Run
result = instance.visualize('summarized')

# Assert
assert result == mock_visualize_graph.return_value
node = (
'{Columns\\l    • categorical : 1\\l    • datetime : 1\\l  '
'  • id : 1\\l    • numerical : 1\\l}'
)
expected_node = {'': node}
assert mock_visualize_graph.called_once_with(call(expected_node, [], None))

@patch('sdv.metadata.single_table.visualize_graph')
def test_visualize_metadata_with_primary_alternate_and_sequence_key(self,
mock_visualize_graph):
"""Test the ``visualize`` method when there are primary, alternate and sequence keys."""
# Setup
instance = SingleTableMetadata()
instance.columns = {
'name': {'sdtype': 'categorical'},
'timestamp': {'sdtype': 'datetime'},
'age': {'sdtype': 'numerical'},
'start_date': {'sdtype': 'datetime'},
'phrase': {'sdtype': 'id'},
'passport': {'sdtype': 'id'}
}
instance.primary_key = 'passport'
instance.alternate_keys = ['phrase', 'name']
instance.sequence_key = 'timestamp'

# Run
result = instance.visualize('full')

# Assert
assert result == mock_visualize_graph.return_value
node = (
'{name : categorical\\ltimestamp : datetime\\lage : numerical\\lstart_date : '
'datetime\\lphrase : id\\lpassport : id\\l|Primary key: passport\\l|Sequence key: '
'timestamp\\l|Alternate keys:\\l     • phrase\\l    • name\\l}'
)
expected_node = {'': node}
assert mock_visualize_graph.called_once_with(call(expected_node, [], None))

@patch('sdv.metadata.single_table.read_json')
@patch('sdv.metadata.single_table.convert_metadata')
@patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict')
Expand Down
Loading