Skip to content

Commit

Permalink
Verify err msg matches
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed May 20, 2024
1 parent 7749f05 commit f483afb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
17 changes: 12 additions & 5 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,12 +696,19 @@ def load(cls, filepath):
with open(filepath, 'rb') as f:
try:
synthesizer = cloudpickle.load(f)
except RuntimeError:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
except RuntimeError as e:
err_msg = (
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
)
if str(e) == err_msg:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
17 changes: 12 additions & 5 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,19 @@ def load(cls, filepath):
with open(filepath, 'rb') as f:
try:
synthesizer = cloudpickle.load(f)
except RuntimeError:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
except RuntimeError as e:
err_msg = (
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
)
if str(e) == err_msg:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,12 @@ def test_load(self, mock_file, cloudpickle_mock,
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError
cloudpickle_mock.load.side_effect = RuntimeError((
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
))

# Run and Assert
err_msg = re.escape(
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,12 @@ def test_load_custom_constraint_classes(self):
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError
cloudpickle_mock.load.side_effect = RuntimeError((
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
))

# Run and Assert
err_msg = re.escape(
Expand Down

0 comments on commit f483afb

Please sign in to comment.