diff --git a/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp b/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp index 70cce324e2..207529a303 100644 --- a/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp +++ b/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp @@ -189,9 +189,14 @@ class CuStateVecCircuitSimulator return; } - // User state provided... + // Check if the pointer is a device pointer + cudaPointerAttributes attributes; + HANDLE_CUDA_ERROR(cudaPointerGetAttributes(&attributes, state)); - // FIXME handle case where pointer is a device pointer + if (attributes.type == cudaMemoryTypeDevice) { + throw std::invalid_argument( + "[CuStateVecCircuitSimulator] Incompatible host pointer"); + } // First allocation, so just set the user provided data here ScopedTraceWithContext( @@ -200,6 +205,7 @@ class CuStateVecCircuitSimulator HANDLE_CUDA_ERROR(cudaMemcpy(deviceStateVector, state, stateDimension * sizeof(CudaDataType), cudaMemcpyHostToDevice)); + return; } @@ -221,8 +227,15 @@ class CuStateVecCircuitSimulator n_blocks, threads_per_block, otherState, (1UL << count)); HANDLE_CUDA_ERROR(cudaGetLastError()); } else { + // Check if the pointer is a device pointer + cudaPointerAttributes attributes; + HANDLE_CUDA_ERROR(cudaPointerGetAttributes(&attributes, state)); + + if (attributes.type == cudaMemoryTypeDevice) { + throw std::invalid_argument( + "[CuStateVecCircuitSimulator] Incompatible host pointer"); + } - // FIXME Handle case where data is already on GPU HANDLE_CUDA_ERROR(cudaMemcpy(otherState, state, (1UL << count) * sizeof(CudaDataType), cudaMemcpyHostToDevice));