Skip to content

Commit

Permalink
Change code place
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Aug 11, 2023
1 parent 1e8fd04 commit 6ecc336
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 42 deletions.
26 changes: 8 additions & 18 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import visualize_graph
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -561,29 +562,15 @@ def visualize(self, show_table_details='full', show_relationship_labels=True,
edges = []
if show_table_details == 'full':
for table_name, table_meta in self.tables.items():
column_dict = table_meta.columns.items()
columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict]
nodes[table_name] = {
'columns': r'\l'.join(columns),
'columns': create_columns_node(table_meta.columns),
'primary_key': f'Primary key: {table_meta.primary_key}'
}

elif show_table_details == 'summarized':
default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
for table_name, table_meta in self.tables.items():
count_dict = defaultdict(int)
for column_name, meta in table_meta.columns.items():
sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype']
count_dict[sdtype] += 1

count_dict = dict(sorted(count_dict.items()))
columns = ['Columns']
columns.extend([
fr'    • {sdtype} : {count}'
for sdtype, count in count_dict.items()
])
nodes[table_name] = {
'columns': r'\l'.join(columns),
'columns': create_summarized_columns_node(table_meta.columns),
'primary_key': f'Primary key: {table_meta.primary_key}'
}

Expand All @@ -610,7 +597,10 @@ def visualize(self, show_table_details='full', show_relationship_labels=True,
if show_table_details:
foreign_keys = r'\l'.join(info.get('foreign_keys', []))
keys = r'\l'.join([info['primary_key'], foreign_keys])
label = fr"{{{table}|{info['columns']}\l|{keys}\l}}"
if foreign_keys:
label = fr"{{{table}|{info['columns']}\l|{keys}\l}}"
else:
label = fr"{{{table}|{info['columns']}\l|{keys}}}"

else:
label = f'{table}'
Expand Down
24 changes: 4 additions & 20 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import logging
import re
import warnings
from collections import defaultdict
from copy import deepcopy
from datetime import datetime

from sdv.metadata.anonymization import SDTYPE_ANONYMIZERS, is_faker_function
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.visualization import visualize_graph
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import cast_to_iterable, load_data_from_csv

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -498,26 +498,10 @@ def visualize(self, show_table_details='full', output_filepath=None):
raise ValueError("'show_table_details' should be 'full' or 'summarized'.")

if show_table_details == 'full':
columns = [
fr"{name} : {meta.get('sdtype')}\l"
for name, meta in self.columns.items()
]
node = ''.join(columns)
node = fr'{create_columns_node(self.columns)}\l'

elif show_table_details == 'summarized':
default_sdtypes = ['id', 'numerical', 'categorical', 'datetime', 'boolean']
count_dict = defaultdict(int)
for column_name, meta in self.columns.items():
sdtype = 'other' if meta['sdtype'] not in default_sdtypes else meta['sdtype']
count_dict[sdtype] += 1

count_dict = dict(sorted(count_dict.items()))
columns = [r'Columns\l']
columns.extend([
fr'    • {sdtype} : {count}\l'
for sdtype, count in count_dict.items()
])
node = ''.join(columns)
node = fr'{create_summarized_columns_node(self.columns)}\l'

if self.primary_key:
node = fr'{node}|Primary key: {self.primary_key}\l'
Expand Down
49 changes: 49 additions & 0 deletions sdv/metadata/visualization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,58 @@
"""Functions for Metadata visualization."""

import warnings
from collections import defaultdict

import graphviz

# sdtypes that get their own line in the summarized columns node; any other
# sdtype is counted under 'other' by ``create_summarized_columns_node``.
DEFAULT_SDTYPES = ['id', 'numerical', 'categorical', 'datetime', 'boolean']


def create_columns_node(columns):
    """Build the ``graphviz`` node text listing every column and its sdtype.

    Each column is rendered as ``<name> : <sdtype>`` and the entries are joined
    with the ``\\l`` separator. No trailing separator is appended.

    Args:
        columns (dict):
            A dict mapping each column name to a dictionary that holds the
            ``sdtype`` of that column. A missing ``sdtype`` is rendered as
            ``None``.

    Returns:
        str:
            Text to be used as the columns section of the node. Empty string
            when ``columns`` is empty.
    """
    lines = []
    for column_name, column_meta in columns.items():
        sdtype = column_meta.get('sdtype')
        lines.append(f'{column_name} : {sdtype}')

    return r'\l'.join(lines)


def create_summarized_columns_node(columns):
    """Build the ``graphviz`` node text summarizing column counts per sdtype.

    Columns whose sdtype is not listed in ``DEFAULT_SDTYPES`` are grouped under
    ``other``. The result starts with a ``Columns`` header followed by one
    bullet line per sdtype, sorted alphabetically, all joined with the ``\\l``
    separator. No trailing separator is appended.

    Args:
        columns (dict):
            A dict mapping each column name to a dictionary that holds the
            ``sdtype`` of that column.

    Returns:
        str:
            Text to be used as the columns section of the node.

    Raises:
        KeyError:
            If a column's dictionary has no ``sdtype`` entry.
    """
    sdtype_counts = {}
    for meta in columns.values():
        sdtype = meta['sdtype']
        bucket = sdtype if sdtype in DEFAULT_SDTYPES else 'other'
        sdtype_counts[bucket] = sdtype_counts.get(bucket, 0) + 1

    lines = ['Columns']
    # Buckets are unique keys, so sorting the keys matches sorting the items.
    for bucket in sorted(sdtype_counts):
        lines.append(fr'    • {bucket} : {sdtype_counts[bucket]}')

    return r'\l'.join(lines)


def _get_graphviz_extension(filepath):
if filepath:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ def test_visualize_show_relationship_and_details(self, visualize_graph_mock):
'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
)
expected_nodes = {
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
'payments': expected_payments_label,
'sessions': expected_sessions_label,
'transactions': expected_transactions_label
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def test_visualize_show_relationship_and_details_summarized(self, visualize_grap
)
expected_user_label = (
'{users|Columns\\l    • categorical : 1\\l    • id : '
'1\\l|Primary key: id\\l\\l}'
'1\\l|Primary key: id\\l}'
)
expected_nodes = {
'users': expected_user_label,
Expand Down Expand Up @@ -1339,7 +1339,7 @@ def test_visualize_show_relationship_and_details_warning(self, visualize_graph_m
'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
)
expected_nodes = {
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
'payments': expected_payments_label,
'sessions': expected_sessions_label,
'transactions': expected_transactions_label
Expand Down Expand Up @@ -1479,7 +1479,7 @@ def test_visualize_show_table_details_only(self, visualize_graph_mock):
'datetime\\l|Primary key: transaction_id\\lForeign key (sessions): session_id\\l}'
)
expected_nodes = {
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l\\l}',
'users': '{users|id : id\\lcountry : categorical\\l|Primary key: id\\l}',
'payments': expected_payments_label,
'sessions': expected_sessions_label,
'transactions': expected_transactions_label
Expand Down

0 comments on commit 6ecc336

Please sign in to comment.