Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message when sampling on a non-CPU device #2016

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,14 @@ def load(cls, filepath):
The loaded synthesizer.
"""
with open(filepath, 'rb') as f:
synthesizer = cloudpickle.load(f)
try:
synthesizer = cloudpickle.load(f)
except RuntimeError:
raise SamplingError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we only switch to the new error message after confirming the error that surfaced has that specific message about the current machine being CPU-only?

'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
9 changes: 8 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,14 @@ def load(cls, filepath):
The loaded synthesizer.
"""
with open(filepath, 'rb') as f:
synthesizer = cloudpickle.load(f)
try:
synthesizer = cloudpickle.load(f)
except RuntimeError:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this occur on the sampling instead ? or the error gets raised while you load the synthesizer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the user issue, they ran into this problem when loading the synthesizer, so that's where I added the warning.


check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,3 +1603,19 @@ def test_load(self, mock_file, cloudpickle_mock,
'SYNTHESIZER CLASS NAME': 'Mock',
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

@patch('builtins.open')
@patch('sdv.multi_table.base.cloudpickle')
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError

# Run and Assert
err_msg = re.escape(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
with pytest.raises(SamplingError, match=err_msg):
BaseMultiTableSynthesizer.load('synth.pkl')
16 changes: 16 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,22 @@ def test_load_custom_constraint_classes(self):
['Custom', 'Constr', 'UpperPlus']
)

@patch('builtins.open')
@patch('sdv.single_table.base.cloudpickle')
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError

# Run and Assert
err_msg = re.escape(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
with pytest.raises(SamplingError, match=err_msg):
BaseSingleTableSynthesizer.load('synth.pkl')

def test_add_custom_constraint_class(self):
"""Test that this method calls the ``DataProcessor``'s method."""
# Setup
Expand Down
Loading