Skip to content

Commit

Permalink
Do not enforce min/max on sequence index column
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed May 31, 2024
1 parent 29cf341 commit c4e2ce3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
3 changes: 3 additions & 0 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def _preprocess(self, data):
preprocessed = super()._preprocess(data)

if self._sequence_index:
sequence_index_transformer = self.get_transformers()[self._sequence_index]
if sequence_index_transformer.enforce_min_max_values:
sequence_index_transformer.enforce_min_max_values = False
preprocessed = self._transform_sequence_index(preprocessed)

return preprocessed
Expand Down
44 changes: 44 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,47 @@ def test_par_missing_sequence_index():
# Assert
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()


def test_par_unique_sequence_index_with_enforce_min_max():
"""Test to see if there are duplicate sequence index values
when sequence_length is higher than real data
"""
# Setup
test_id = list(range(10))
s_key = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
visits = [
'2021-01-01', '2021-01-03', '2021-01-05', '2021-01-07', '2021-01-09',
'2021-09-11', '2021-09-17', '2021-10-01', '2021-10-08', '2021-11-01'
]
pre_date = [
'2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05',
'2021-04-01', '2021-04-02', '2021-04-03', '2021-04-04', '2021-04-05'
]
test_df = pd.DataFrame({
'id': test_id,
's_key': s_key,
'visits': visits,
'pre_date': pre_date
})
test_df[['visits', 'pre_date']] = test_df[['visits', 'pre_date']].apply(
pd.to_datetime, format='%Y-%m-%d', errors='coerce')
test_df['pre_date'] = pd.to_datetime(test_df['pre_date'], unit='ns').astype(int)
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(test_df)
metadata.update_column(column_name='pre_date', sdtype='numerical')
metadata.update_column(column_name='s_key', sdtype='id')
metadata.set_sequence_key('s_key')
metadata.set_sequence_index('visits')
synthesizer = PARSynthesizer(metadata, enforce_min_max_values=True,
enforce_rounding=False, epochs=100, verbose=True)

# Run
synthesizer.fit(test_df)
synth_df = synthesizer.sample(num_sequences=50, sequence_length=50)

# Assert
for i in synth_df['s_key'].unique():
seq_df = synth_df[synth_df['s_key'] == i]
has_duplicates = seq_df['visits'].duplicated().any()
assert not has_duplicates
3 changes: 3 additions & 0 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def test_preprocess(self, base_preprocess_mock):
par._transform_sequence_index = Mock()
par.auto_assign_transformers = Mock()
par.update_transformers = Mock()
get_transform_mock = Mock()
get_transform_mock.return_value = {'time': Mock()}
par.get_transformers = get_transform_mock
par._data_processor._prepared_for_fitting = True
data = self.get_data()

Expand Down

0 comments on commit c4e2ce3

Please sign in to comment.