From 504072af05b018344e67e425a7416dcf2cba6dc7 Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev
Date: Wed, 9 Aug 2023 22:10:36 +0200
Subject: [PATCH 1/3] Add singletable visualization

---
 sdv/metadata/single_table.py             | 60 ++++++++++++++++
 tests/unit/metadata/test_single_table.py | 87 ++++++++++++++++++++++++
 2 files changed, 147 insertions(+)

diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py
index 8ccc6b7d9..162967574 100644
--- a/sdv/metadata/single_table.py
+++ b/sdv/metadata/single_table.py
@@ -4,6 +4,7 @@
 import logging
 import re
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from datetime import datetime
 
@@ -11,6 +12,7 @@
 from sdv.metadata.errors import InvalidMetadataError
 from sdv.metadata.metadata_upgrader import convert_metadata
 from sdv.metadata.utils import read_json, validate_file_does_not_exist
+from sdv.metadata.visualization import visualize_graph
 from sdv.utils import cast_to_iterable, load_data_from_csv
 
 LOGGER = logging.getLogger(__name__)
@@ -477,6 +479,64 @@ def validate(self):
                 + '\n'.join([str(e) for e in errors])
             )
 
+    def visualize(self, show_table_details='full', output_filepath=None):
+        """Create a visualization of the multi-table dataset.
+
+        Args:
+            show_table_details (str):
+                If 'full', show the column names, primary, alternate and sequence keys are all
+                shown along with the table names. If 'summarized' primary, alternate and sequence
+                keys are shown and a count of the different sdtypes is shown. Defaults to 'full'.
+            output_filepath (str):
+                Full path of where to savve the visualization. If None, the visualization is not
+                saved. Defaults to None.
+
+        Returns:
+            ``graphviz.Digraph`` object.
+
+        Raises:
+            ValueError:
+                If ``show_table_details`` is neither 'full' nor 'summarized'.
+        """
+        if show_table_details not in ('full', 'summarized'):
+            raise ValueError("'show_table_details' should be 'full' or 'summarized'.")
+
+        if show_table_details == 'full':
+            columns = [
+                fr"{name} : {meta.get('sdtype')}\l"
+                for name, meta in self.columns.items()
+            ]
+            node = ''.join(columns)
+
+        elif show_table_details == 'summarized':
+            default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
+            count_dict = defaultdict(int)
+            for column_name, meta in self.columns.items():
+                sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype']
+                count_dict[sdtype] += 1
+
+            count_dict = dict(sorted(count_dict.items()))
+            columns = [r'Columns\l']
+            columns.extend([
+                fr'    • {sdtype} : {count}\l'
+                for sdtype, count in count_dict.items()
+            ])
+            node = ''.join(columns)
+
+        if self.primary_key:
+            node = fr'{node}|Primary key: {self.primary_key}\l'
+
+        if self.sequence_key:
+            node = fr'{node}|Sequence key: {self.sequence_key}\l'
+
+        if self.alternate_keys:
+            alternate_keys = [fr'    • {key}\l' for key in self.alternate_keys]
+            alternate_keys = ''.join(alternate_keys)
+            node = fr'{node}|Alternate keys:\l {alternate_keys}'
+
+        node = {'': f'{{{node}}}'}
+        return visualize_graph(node, [], output_filepath)
+
     def save_to_json(self, filepath):
         """Save the current ``SingleTableMetadata`` in to a ``json`` file.
 
diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py
index 49ddfcd5d..5c3defa6f 100644
--- a/tests/unit/metadata/test_single_table.py
+++ b/tests/unit/metadata/test_single_table.py
@@ -1729,6 +1729,93 @@ def test___repr__(self, mock_json):
         mock_json.dumps.assert_called_once_with(instance.to_dict(), indent=4)
         assert res == mock_json.dumps.return_value
 
+    def test_visualize_with_invalid_input(self):
+        """Test that a ``ValueError`` is being raised when ``show_table_details`` is incorrect."""
+        # Setup
+        instance = SingleTableMetadata()
+
+        # Run and Assert
+        error_msg = "'show_table_details' should be 'full' or 'summarized'."
+        with pytest.raises(ValueError, match=error_msg):
+            instance.visualize(None)
+
+    @patch('sdv.metadata.single_table.visualize_graph')
+    def test_visualize_metadata_full(self, mock_visualize_graph):
+        """Test the ``visualize`` method when ``show_table_details`` is 'full'."""
+        # Setup
+        instance = SingleTableMetadata()
+        instance.columns = {
+            'name': {'sdtype': 'categorical'},
+            'age': {'sdtype': 'numerical'},
+            'start_date': {'sdtype': 'datetime'},
+            'phrase': {'sdtype': 'id'},
+        }
+
+        # Run
+        result = instance.visualize('full')
+
+        # Assert
+        assert result == mock_visualize_graph.return_value
+        expected_node = {
+            '': '{name : categorical\\lage : numerical\\lstart_date : datetime\\lphrase : id\\l}'
+        }
+        mock_visualize_graph.assert_called_once_with(expected_node, [], None)
+
+    @patch('sdv.metadata.single_table.visualize_graph')
+    def test_visualize_metadata_summarized(self, mock_visualize_graph):
+        """Test the ``visualize`` method when ``show_table_details`` is 'summarized'."""
+        # Setup
+        instance = SingleTableMetadata()
+        instance.columns = {
+            'name': {'sdtype': 'categorical'},
+            'age': {'sdtype': 'numerical'},
+            'start_date': {'sdtype': 'datetime'},
+            'phrase': {'sdtype': 'id'},
+        }
+
+        # Run
+        result = instance.visualize('summarized')
+
+        # Assert
+        assert result == mock_visualize_graph.return_value
+        node = (
+            '{Columns\\l    • categorical : 1\\l    • datetime : 1\\l  '
+            '  • id : 1\\l    • numerical : 1\\l}'
+        )
+        expected_node = {'': node}
+        mock_visualize_graph.assert_called_once_with(expected_node, [], None)
+
+    @patch('sdv.metadata.single_table.visualize_graph')
+    def test_visualize_metadata_with_primary_alternate_and_sequence_key(self,
+                                                                        mock_visualize_graph):
+        """Test the ``visualize`` method when there are primary, alternate and sequence keys."""
+        # Setup
+        instance = SingleTableMetadata()
+        instance.columns = {
+            'name': {'sdtype': 'categorical'},
+            'timestamp': {'sdtype': 'datetime'},
+            'age': {'sdtype': 'numerical'},
+            'start_date': {'sdtype': 'datetime'},
+            'phrase': {'sdtype': 'id'},
+            'passport': {'sdtype': 'id'}
+        }
+        instance.primary_key = 'passport'
+        instance.alternate_keys = ['phrase', 'name']
+        instance.sequence_key = 'timestamp'
+
+        # Run
+        result = instance.visualize('full')
+
+        # Assert
+        assert result == mock_visualize_graph.return_value
+        node = (
+            '{name : categorical\\ltimestamp : datetime\\lage : numerical\\lstart_date : '
+            'datetime\\lphrase : id\\lpassport : id\\l|Primary key: passport\\l|Sequence key: '
+            'timestamp\\l|Alternate keys:\\l     • phrase\\l    • name\\l}'
+        )
+        expected_node = {'': node}
+        mock_visualize_graph.assert_called_once_with(expected_node, [], None)
+
     @patch('sdv.metadata.single_table.read_json')
     @patch('sdv.metadata.single_table.convert_metadata')
     @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict')

From 1e8fd04ade5a50a40c0eb6d2f4fe6b500cc98ead Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev
Date: Fri, 11 Aug 2023 15:52:47 +0200
Subject: [PATCH 2/3] Address comments

---
 sdv/metadata/single_table.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py
index 162967574..552ccba90 100644
--- a/sdv/metadata/single_table.py
+++ b/sdv/metadata/single_table.py
@@ -480,15 +480,15 @@ def validate(self):
             )
 
     def visualize(self, show_table_details='full', output_filepath=None):
-        """Create a visualization of the multi-table dataset.
+        """Create a visualization of the single-table dataset.
 
         Args:
             show_table_details (str):
-                If 'full', show the column names, primary, alternate and sequence keys are all
-                shown along with the table names. If 'summarized' primary, alternate and sequence
-                keys are shown and a count of the different sdtypes is shown. Defaults to 'full'.
+                If 'full', the column names, primary, alternate and sequence keys are all
+                shown. If 'summarized', primary, alternate and sequence keys are shown and a
+                count of the different sdtypes. Defaults to 'full'.
             output_filepath (str):
-                Full path of where to savve the visualization. If None, the visualization is not
+                Full path of where to save the visualization. If None, the visualization is not
                 saved. Defaults to None.
 
         Returns:

From 6ecc33673ca16a3287037ce2d31556f536386183 Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev
Date: Fri, 11 Aug 2023 22:08:38 +0200
Subject: [PATCH 3/3] Change code place

---
 sdv/metadata/multi_table.py             | 26 ++++---------
 sdv/metadata/single_table.py            | 24 ++----------
 sdv/metadata/visualization.py           | 49 +++++++++++++++++++++++++
 tests/unit/metadata/test_multi_table.py |  8 ++--
 4 files changed, 65 insertions(+), 42 deletions(-)

diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py
index 256815c45..c7838500c 100644
--- a/sdv/metadata/multi_table.py
+++ b/sdv/metadata/multi_table.py
@@ -10,7 +10,8 @@
 from sdv.metadata.metadata_upgrader import convert_metadata
 from sdv.metadata.single_table import SingleTableMetadata
 from sdv.metadata.utils import read_json, validate_file_does_not_exist
-from sdv.metadata.visualization import visualize_graph
+from sdv.metadata.visualization import (
+    create_columns_node, create_summarized_columns_node, visualize_graph)
 from sdv.utils import cast_to_iterable
 
 LOGGER = logging.getLogger(__name__)
@@ -561,29 +562,15 @@ def visualize(self, show_table_details='full', show_relationship_labels=True,
         edges = []
         if show_table_details == 'full':
             for table_name, table_meta in self.tables.items():
-                column_dict = table_meta.columns.items()
-                columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict]
                 nodes[table_name] = {
-                    'columns': r'\l'.join(columns),
+                    'columns': create_columns_node(table_meta.columns),
                     'primary_key': f'Primary key: {table_meta.primary_key}'
                 }
 
         elif show_table_details == 'summarized':
-            default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
             for table_name, table_meta in self.tables.items():
-                count_dict = defaultdict(int)
-                for column_name, meta in table_meta.columns.items():
-                    sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype']
-                    count_dict[sdtype] += 1
-
-                count_dict = dict(sorted(count_dict.items()))
-                columns = ['Columns']
-                columns.extend([
-                    fr'    • {sdtype} : {count}'
-                    for sdtype, count in count_dict.items()
-                ])
                 nodes[table_name] = {
-                    'columns': r'\l'.join(columns),
+                    'columns': create_summarized_columns_node(table_meta.columns),
                     'primary_key': f'Primary key: {table_meta.primary_key}'
                 }
 
@@ -610,7 +597,10 @@ def visualize(self, show_table_details='full', show_relationship_labels=True,
             if show_table_details:
                 foreign_keys = r'\l'.join(info.get('foreign_keys', []))
                 keys = r'\l'.join([info['primary_key'], foreign_keys])
-                label = fr"{{{table}|{info['columns']}\l|{keys}\l}}"
+                if foreign_keys:
+                    label = fr"{{{table}|{info['columns']}\l|{keys}\l}}"
+                else:
+                    label = fr"{{{table}|{info['columns']}\l|{keys}}}"
             else:
                 label = f'{table}'
 
diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py
index 552ccba90..dbd049328 100644
--- a/sdv/metadata/single_table.py
+++ b/sdv/metadata/single_table.py
@@ -4,7 +4,6 @@
 import logging
 import re
 import warnings
-from collections import defaultdict
 from copy import deepcopy
 from datetime import datetime
 
@@ -12,7 +11,8 @@
 from sdv.metadata.errors import InvalidMetadataError
 from sdv.metadata.metadata_upgrader import convert_metadata
 from sdv.metadata.utils import read_json, validate_file_does_not_exist
-from sdv.metadata.visualization import visualize_graph
+from sdv.metadata.visualization import (
+    create_columns_node, create_summarized_columns_node, visualize_graph)
 from sdv.utils import cast_to_iterable, load_data_from_csv
 
 LOGGER = logging.getLogger(__name__)
@@ -498,26 +498,10 @@ def visualize(self, show_table_details='full', output_filepath=None):
             raise ValueError("'show_table_details' should be 'full' or 'summarized'.")
 
         if show_table_details == 'full':
-            columns = [
-                fr"{name} : {meta.get('sdtype')}\l"
-                for name, meta in self.columns.items()
-            ]
-            node = ''.join(columns)
+            node = fr'{create_columns_node(self.columns)}\l'
 
         elif show_table_details == 'summarized':
-            default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
-            count_dict = defaultdict(int)
-            for column_name, meta in self.columns.items():
-                sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype']
-                count_dict[sdtype] += 1
-
-            count_dict = dict(sorted(count_dict.items()))
-            columns = [r'Columns\l']
-            columns.extend([
-                fr'    • {sdtype} : {count}\l'
-                for sdtype, count in count_dict.items()
-            ])
-            node = ''.join(columns)
+            node = fr'{create_summarized_columns_node(self.columns)}\l'
 
         if self.primary_key:
             node = fr'{node}|Primary key: {self.primary_key}\l'
diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py
index c8e0a7e61..d5fa3eac8 100644
--- a/sdv/metadata/visualization.py
+++ b/sdv/metadata/visualization.py
@@ -1,9 +1,58 @@
 """Functions for Metadata visualization."""
 
 import warnings
+from collections import defaultdict
 
 import graphviz
 
+DEFAULT_SDTYPES = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
+
+
+def create_columns_node(columns):
+    """Convert columns into text for ``graphviz`` node.
+
+    Args:
+        columns (dict):
+            A dict mapping the column names with a dictionary containing the ``sdtype`` of the
+            column name.
+
+    Returns:
+        str:
+            String representing the node that will be printed for the given columns.
+    """
+    columns = [
+        fr"{name} : {meta.get('sdtype')}"
+        for name, meta in columns.items()
+    ]
+    return r'\l'.join(columns)
+
+
+def create_summarized_columns_node(columns):
+    """Convert columns into summarized text for ``graphviz`` node.
+
+    Args:
+        columns (dict):
+            A dict mapping the column names with a dictionary containing the ``sdtype`` of the
+            column name.
+
+    Returns:
+        str:
+            String representing the node that will be printed for the given columns.
+    """
+    count_dict = defaultdict(int)
+    for meta in columns.values():
+        sdtype = 'other' if meta['sdtype'] not in DEFAULT_SDTYPES else meta['sdtype']
+        count_dict[sdtype] += 1
+
+    count_dict = dict(sorted(count_dict.items()))
+    columns = ['Columns']
+    columns.extend([
+        fr'    • {sdtype} : {count}'
+        for sdtype, count in count_dict.items()
+    ])
+
+    return r'\l'.join(columns)
+
 
 def _get_graphviz_extension(filepath):
     if filepath:
diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py
index c1543fecc..c0ece36cb 100644
--- a/tests/unit/metadata/test_multi_table.py
+++ b/tests/unit/metadata/test_multi_table.py
@@ -1235,7 +1235,7 @@ def test_visualize_show_relationship_and_details(self, visualize_graph_mock):
             'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
         )
         expected_nodes = {
-            'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
+            'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
             'payments': expected_payments_label,
             'sessions': expected_sessions_label,
             'transactions': expected_transactions_label
@@ -1283,7 +1283,7 @@ def test_visualize_show_relationship_and_details_summarized(self, visualize_grap
         )
         expected_user_label = (
             '{users|Columns\\l    • categorical : 1\\l    • id : '
-            '1\\l|Primary key: id\\l\\l}'
+            '1\\l|Primary key: id\\l}'
         )
         expected_nodes = {
             'users': expected_user_label,
@@ -1339,7 +1339,7 @@ def test_visualize_show_relationship_and_details_warning(self, visualize_graph_m
             'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
         )
         expected_nodes = {
-            'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
+            'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
             'payments': expected_payments_label,
             'sessions': expected_sessions_label,
             'transactions': expected_transactions_label
@@ -1479,7 +1479,7 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock):
             'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
         )
         expected_nodes = {
-            'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
+            'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
             'payments': expected_payments_label,
             'sessions': expected_sessions_label,
             'transactions': expected_transactions_label