Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for simple PARSynthesizer constraints #2044

Merged
merged 7 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inspect
import logging
import uuid
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -135,13 +134,53 @@ def get_parameters(self):
return instantiated_parameters

def add_constraints(self, constraints):
    """Add constraints to the synthesizer.

    The ``PARSynthesizer`` only accepts a list of constraints that follows these rules:

    1) The constraints cover either only context columns or only non-context columns.
       A list that touches both kinds of column is rejected.
    2) No overlapping constraints (no column may be referenced more than once).
    3) No custom constraints (``load_custom_constraint_classes`` always raises).

    Args:
        constraints (list):
            List of constraints described as dictionaries in the following format:
                * ``constraint_class``: Name of the constraint to apply.
                * ``constraint_parameters``: A dictionary with the constraint parameters.

    Raises:
        SynthesizerInputError:
            If a column is referenced more than once across the given constraints, or
            if the referenced columns mix context and non-context columns.
    """
    context_columns = set(self.context_columns)
    seen_columns = set()
    for constraint in constraints:
        constraint_parameters = constraint['constraint_parameters']
        for param, value in constraint_parameters.items():
            # Every parameter whose name contains ``column_name`` (``column_name``,
            # ``column_names``, ...) is read as one column or a list of columns.
            if 'column_name' not in param:
                continue

            columns = value if isinstance(value, list) else [value]
            for column in columns:
                if column in seen_columns:
                    raise SynthesizerInputError(
                        'The PARSynthesizer cannot accommodate multiple constraints '
                        'that overlap on the same columns.'
                    )

                seen_columns.add(column)

    # Mixed means: at least one referenced column is a context column, but not all of them.
    constrained_context = seen_columns & context_columns
    if constrained_context and constrained_context != seen_columns:
        raise SynthesizerInputError(
            'The PARSynthesizer cannot accommodate constraints '
            'with a mix of context and non-context columns.'
        )

    super().add_constraints(constraints)

def load_custom_constraint_classes(self, filepath, class_names):
    """Raise an error because custom constraints can't be used in the ``PARSynthesizer``.

    Both arguments are accepted but never read; the method raises unconditionally.

    Args:
        filepath (str):
            Ignored.
        class_names (list):
            Ignored.

    Raises:
        SynthesizerInputError:
            Always.
    """
    raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.')

def _validate_context_columns(self, data):
errors = []
Expand Down
57 changes: 57 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime
import re

import numpy as np
import pandas as pd
import pytest
from deepecho import load_demo

from sdv.datasets.demo import download_demo
from sdv.errors import SynthesizerInputError
from sdv.metadata import SingleTableMetadata
from sdv.sequential import PARSynthesizer

Expand Down Expand Up @@ -284,6 +287,60 @@ def test_par_missing_sequence_index():
assert (sampled.dtypes == data.dtypes).all()


def test_constraints_on_par():
    """Check that simple constraints work on PARSynthesizer and mixed ones are rejected."""
    # Setup
    def build_constraint(constraint_class, column_name):
        # All constraints in this test share the same parameter layout.
        return {
            'constraint_class': constraint_class,
            'constraint_parameters': {
                'column_name': column_name,
                'strict_boundaries': True
            }
        }

    real_data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')
    synthesizer = PARSynthesizer(metadata, epochs=5, context_columns=['Sector', 'Industry'])

    market_constraint = build_constraint('Positive', 'MarketCap')
    volume_constraint = build_constraint('Positive', 'Volume')
    # 'Sector' is one of the context columns passed to the synthesizer above.
    context_constraint = build_constraint('Mock', 'Sector')

    # Run
    synthesizer.add_constraints([volume_constraint, market_constraint])
    synthesizer.fit(real_data)
    samples = synthesizer.sample(50, 10)

    # Assert
    for column in ('MarketCap', 'Volume'):
        assert not (samples[column] < 0).any().any()

    mixed_constraint_error_msg = re.escape(
        'The PARSynthesizer cannot accommodate constraints '
        'with a mix of context and non-context columns.'
    )
    with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
        synthesizer.add_constraints([volume_constraint, context_constraint])


def test_par_unique_sequence_index_with_enforce_min_max():
"""Test to see if there are duplicate sequence index values
when sequence_length is higher than real data
Expand Down
84 changes: 71 additions & 13 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,82 @@ def test___init___no_sequence_key(self):
verbose=False
)

def test_add_constraints(self):
    """Test that only simple constraints can be added to PARSynthesizer."""
    # Setup
    metadata = self.get_metadata()
    synthesizer = PARSynthesizer(metadata=metadata, context_columns=['name', 'measurement'])

    # 'name' and 'measurement' are context columns; 'gender' and 'time' are not.
    name_constraint = {
        'constraint_class': 'Mock',
        'constraint_parameters': {
            'column_name': 'name'
        }
    }
    measurement_constraint = {
        'constraint_class': 'Mock',
        'constraint_parameters': {
            'column_name': 'measurement'
        }
    }
    gender_constraint = {
        'constraint_class': 'Mock',
        'constraint_parameters': {
            'column_name': 'gender'
        }
    }
    time_constraint = {
        'constraint_class': 'Mock',
        'constraint_parameters': {
            'column_name': 'time'
        }
    }
    # A single constraint that spans a context and a non-context column.
    multi_constraint = {
        'constraint_class': 'Mock',
        'constraint_parameters': {
            'column_names': ['name', 'time']
        }
    }
    overlapping_error_msg = re.escape(
        'The PARSynthesizer cannot accommodate multiple constraints '
        'that overlap on the same columns.'
    )
    mixed_constraint_error_msg = re.escape(
        'The PARSynthesizer cannot accommodate constraints '
        'with a mix of context and non-context columns.'
    )
    mixed_cases = [
        [name_constraint, gender_constraint],
        [time_constraint, measurement_constraint],
        [multi_constraint],
    ]
    overlapping_cases = [
        [multi_constraint, name_constraint],
        [name_constraint, name_constraint],
        [gender_constraint, gender_constraint],
    ]

    # Run and Assert
    for constraints in mixed_cases:
        with pytest.raises(SynthesizerInputError, match=mixed_constraint_error_msg):
            synthesizer.add_constraints(constraints)

    for constraints in overlapping_cases:
        with pytest.raises(SynthesizerInputError, match=overlapping_error_msg):
            synthesizer.add_constraints(constraints)

def test_load_custom_constraint_classes(self):
    """Test that if custom constraint is being added, an error is raised."""
    # Setup
    metadata = self.get_metadata()
    synthesizer = PARSynthesizer(metadata=metadata)
    error_message = re.escape(
        'The PARSynthesizer cannot accommodate custom constraints.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=error_message):
        synthesizer.load_custom_constraint_classes(filepath='test', class_names=[])

def test_get_parameters(self):
"""Test that it returns every ``init`` parameter without the ``metadata``."""
Expand Down
Loading