diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index ed08891fb..50a38c316 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -694,7 +694,22 @@ def load(cls, filepath): The loaded synthesizer. """ with open(filepath, 'rb') as f: - synthesizer = cloudpickle.load(f) + try: + synthesizer = cloudpickle.load(f) + except RuntimeError as e: + err_msg = ( + 'Attempting to deserialize object on a CUDA device but ' + 'torch.cuda.is_available() is False. If you are running on a CPU-only machine,' + " please use torch.load with map_location=torch.device('cpu') " + 'to map your storages to the CPU.' + ) + if str(e) == err_msg: + raise SamplingError( + 'This synthesizer was created on a machine with GPU but the current ' + 'machine is CPU-only. This feature is currently unsupported. We recommend' + ' sampling on the same GPU-enabled machine.' + ) + raise e check_synthesizer_version(synthesizer) check_sdv_versions_and_warn(synthesizer) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index de474d1c7..2fc8c18d0 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -492,7 +492,22 @@ def load(cls, filepath): The loaded synthesizer. """ with open(filepath, 'rb') as f: - synthesizer = cloudpickle.load(f) + try: + synthesizer = cloudpickle.load(f) + except RuntimeError as e: + err_msg = ( + 'Attempting to deserialize object on a CUDA device but ' + 'torch.cuda.is_available() is False. If you are running on a CPU-only machine,' + " please use torch.load with map_location=torch.device('cpu') " + 'to map your storages to the CPU.' + ) + if str(e) == err_msg: + raise SamplingError( + 'This synthesizer was created on a machine with GPU but the current ' + 'machine is CPU-only. This feature is currently unsupported. We recommend' + ' sampling on the same GPU-enabled machine.' + ) + raise e check_synthesizer_version(synthesizer) check_sdv_versions_and_warn(synthesizer) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index c4fe83d6e..6880de0ef 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1603,3 +1603,35 @@ def test_load(self, mock_file, cloudpickle_mock, 'SYNTHESIZER CLASS NAME': 'Mock', 'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', }) + + @patch('builtins.open') + @patch('sdv.multi_table.base.cloudpickle') + def test_load_runtime_error(self, cloudpickle_mock, mock_open): + """Test that the synthesizer's load method errors with the correct message.""" + # Setup + cloudpickle_mock.load.side_effect = RuntimeError(( + 'Attempting to deserialize object on a CUDA device but ' + 'torch.cuda.is_available() is False. If you are running on a CPU-only machine,' + " please use torch.load with map_location=torch.device('cpu') " + 'to map your storages to the CPU.' + )) + + # Run and Assert + err_msg = re.escape( + 'This synthesizer was created on a machine with GPU but the current machine is' + ' CPU-only. This feature is currently unsupported. We recommend sampling on ' + 'the same GPU-enabled machine.' + ) + with pytest.raises(SamplingError, match=err_msg): + BaseMultiTableSynthesizer.load('synth.pkl') + + @patch('builtins.open') + @patch('sdv.multi_table.base.cloudpickle') + def test_load_runtime_error_no_change(self, cloudpickle_mock, mock_open): + """Test that the synthesizer's load method errors with the correct message.""" + # Setup + cloudpickle_mock.load.side_effect = RuntimeError('Error') + + # Run and Assert + with pytest.raises(RuntimeError, match='Error'): + BaseMultiTableSynthesizer.load('synth.pkl') diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 3074d8506..1ae03b541 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1914,6 +1914,38 @@ def test_load_custom_constraint_classes(self): ['Custom', 'Constr', 'UpperPlus'] ) + @patch('builtins.open') + @patch('sdv.single_table.base.cloudpickle') + def test_load_runtime_error(self, cloudpickle_mock, mock_open): + """Test that the synthesizer's load method errors with the correct message.""" + # Setup + cloudpickle_mock.load.side_effect = RuntimeError(( + 'Attempting to deserialize object on a CUDA device but ' + 'torch.cuda.is_available() is False. If you are running on a CPU-only machine,' + " please use torch.load with map_location=torch.device('cpu') " + 'to map your storages to the CPU.' + )) + + # Run and Assert + err_msg = re.escape( + 'This synthesizer was created on a machine with GPU but the current machine is' + ' CPU-only. This feature is currently unsupported. We recommend sampling on ' + 'the same GPU-enabled machine.' + ) + with pytest.raises(SamplingError, match=err_msg): + BaseSingleTableSynthesizer.load('synth.pkl') + + @patch('builtins.open') + @patch('sdv.single_table.base.cloudpickle') + def test_load_runtime_error_no_change(self, cloudpickle_mock, mock_open): + """Test that the synthesizer's load method errors with the correct message.""" + # Setup + cloudpickle_mock.load.side_effect = RuntimeError('Error') + + # Run and Assert + with pytest.raises(RuntimeError, match='Error'): + BaseSingleTableSynthesizer.load('synth.pkl') + def test_add_custom_constraint_class(self): """Test that this method calls the ``DataProcessor``'s method.""" # Setup