From 6ecc33673ca16a3287037ce2d31556f536386183 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 11 Aug 2023 22:08:38 +0200 Subject: [PATCH] Change code place --- sdv/metadata/multi_table.py | 26 ++++--------- sdv/metadata/single_table.py | 24 ++---------- sdv/metadata/visualization.py | 49 +++++++++++++++++++++++++ tests/unit/metadata/test_multi_table.py | 8 ++-- 4 files changed, 65 insertions(+), 42 deletions(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 256815c45..c7838500c 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -10,7 +10,8 @@ from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.metadata.utils import read_json, validate_file_does_not_exist -from sdv.metadata.visualization import visualize_graph +from sdv.metadata.visualization import ( + create_columns_node, create_summarized_columns_node, visualize_graph) from sdv.utils import cast_to_iterable LOGGER = logging.getLogger(__name__) @@ -561,29 +562,15 @@ def visualize(self, show_table_details='full', show_relationship_labels=True, edges = [] if show_table_details == 'full': for table_name, table_meta in self.tables.items(): - column_dict = table_meta.columns.items() - columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict] nodes[table_name] = { - 'columns': r'\l'.join(columns), + 'columns': create_columns_node(table_meta.columns), 'primary_key': f'Primary key: {table_meta.primary_key}' } elif show_table_details == 'summarized': - default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean'] for table_name, table_meta in self.tables.items(): - count_dict = defaultdict(int) - for column_name, meta in table_meta.columns.items(): - sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype'] - count_dict[sdtype] += 1 - - count_dict = dict(sorted(count_dict.items())) - columns = ['Columns'] - columns.extend([ - fr'    • {sdtype} : {count}' - for sdtype, count in count_dict.items() - ]) nodes[table_name] = { - 'columns': r'\l'.join(columns), + 'columns': create_summarized_columns_node(table_meta.columns), 'primary_key': f'Primary key: {table_meta.primary_key}' } @@ -610,7 +597,10 @@ def visualize(self, show_table_details='full', show_relationship_labels=True, if show_table_details: foreign_keys = r'\l'.join(info.get('foreign_keys', [])) keys = r'\l'.join([info['primary_key'], foreign_keys]) - label = fr"{{{table}|{info['columns']}\l|{keys}\l}}" + if foreign_keys: + label = fr"{{{table}|{info['columns']}\l|{keys}\l}}" + else: + label = fr"{{{table}|{info['columns']}\l|{keys}}}" else: label = f'{table}' diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 552ccba90..dbd049328 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -4,7 +4,6 @@ import logging import re import warnings -from collections import defaultdict from copy import deepcopy from datetime import datetime @@ -12,7 +11,8 @@ from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.utils import read_json, validate_file_does_not_exist -from sdv.metadata.visualization import visualize_graph +from sdv.metadata.visualization import ( + create_columns_node, create_summarized_columns_node, visualize_graph) from sdv.utils import cast_to_iterable, load_data_from_csv LOGGER = logging.getLogger(__name__) @@ -498,26 +498,10 @@ def visualize(self, show_table_details='full', output_filepath=None): raise ValueError("'show_table_details' should be 'full' or 'summarized'.") if show_table_details == 'full': - columns = [ - fr"{name} : {meta.get('sdtype')}\l" - for name, meta in self.columns.items() - ] - node = ''.join(columns) + node = fr'{create_columns_node(self.columns)}\l' elif show_table_details == 'summarized': - default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean'] - count_dict = defaultdict(int) - for column_name, meta in self.columns.items(): - sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype'] - count_dict[sdtype] += 1 - - count_dict = dict(sorted(count_dict.items())) - columns = [r'Columns\l'] - columns.extend([ - fr'    • {sdtype} : {count}\l' - for sdtype, count in count_dict.items() - ]) - node = ''.join(columns) + node = fr'{create_summarized_columns_node(self.columns)}\l' if self.primary_key: node = fr'{node}|Primary key: {self.primary_key}\l' diff --git a/sdv/metadata/visualization.py b/sdv/metadata/visualization.py index c8e0a7e61..d5fa3eac8 100644 --- a/sdv/metadata/visualization.py +++ b/sdv/metadata/visualization.py @@ -1,9 +1,58 @@ """Functions for Metadata visualization.""" import warnings +from collections import defaultdict import graphviz +DEFAULT_SDTYPES = ['id', 'numerical', 'categorical', 'datetime', 'boolean'] + + +def create_columns_node(columns): + """Convert columns into text for ``graphviz`` node. + + Args: + columns (dict): + A dict mapping the column names with a dictionary containing the ``sdtype`` of the + column name. + + Returns: + str: + String representing the node that will be printed for the given columns. + """ + columns = [ + fr"{name} : {meta.get('sdtype')}" + for name, meta in columns.items() + ] + return r'\l'.join(columns) + + +def create_summarized_columns_node(columns): + """Convert columns into summarized text for ``graphviz`` node. + + Args: + columns (dict): + A dict mapping the column names with a dictionary containing the ``sdtype`` of the + column name. + + Returns: + str: + String representing the node that will be printed for the given columns. + """ + count_dict = defaultdict(int) + for column_name, meta in columns.items(): + sdtype = 'other' if meta['sdtype'] not in DEFAULT_SDTYPES else meta['sdtype'] + count_dict[sdtype] += 1 + + count_dict = dict(sorted(count_dict.items())) + columns = ['Columns'] + columns.extend([ + fr'    • {sdtype} : {count}' + for sdtype, count in count_dict.items() + ]) + + return r'\l'.join(columns) + def _get_graphviz_extension(filepath): if filepath: diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index c1543fecc..c0ece36cb 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1235,7 +1235,7 @@ def test_visualize_show_relationship_and_details(self, visualize_graph_mock): 'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}' ) expected_nodes = { - 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}', + 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}', 'payments': expected_payments_label, 'sessions': expected_sessions_label, 'transactions': expected_transactions_label @@ -1283,7 +1283,7 @@ def test_visualize_show_relationship_and_details_summarized(self, visualize_grap ) expected_user_label = ( '{users|Columns\\l    • categorical : 1\\l    • id : ' - '1\\l|Primary key: id\\l\\l}' + '1\\l|Primary key: id\\l}' ) expected_nodes = { 'users': expected_user_label, @@ -1339,7 +1339,7 @@ def test_visualize_show_relationship_and_details_warning(self, visualize_graph_m 'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}' ) expected_nodes = { - 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}', + 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}', 'payments': expected_payments_label, 'sessions': expected_sessions_label, 'transactions': expected_transactions_label @@ -1479,7 +1479,7 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock): 'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}' ) expected_nodes = { - 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}', + 'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}', 'payments': expected_payments_label, 'sessions': expected_sessions_label, 'transactions': expected_transactions_label