Skip to content

Commit

Permalink
Add CUDA JIT to calc_gamma_van_der_waals() (tardis-sn#127)
Browse files Browse the repository at this point in the history
* Add basic test for `calc_doppler_width()`

* Refactor `calc_doppler_width()` and add test for vectorized implementation

* Typecast to float

* Add unwrapped cuda implementation of doppler_width

Also typecast all global constants to float

* Add wrapped cuda implementation of calc_doppler_width

* Return cupy array by default

* Add tests for non cuda implementation of `calc_gamma_van_der_waals()`.

* Typecast inputs

* Optimize formula for `calc_gamma_van_der_waals()`

* Vectorize `calc_gamma_van_der_waals()`

* Add CUDA implementations of `calc_gamma_van_der_waals()` and associated tests

* change ion_number to not square itself in _calc_gamma_van_der_waals()

* add missing import to test_broadening.py

* get rid of squaresd value overwrites in _calc_gamma_van_der_waals

---------

Co-authored-by: Josh Shields <[email protected]>
  • Loading branch information
smokestacklightnin and jvshields committed Sep 20, 2023
1 parent 6c24b93 commit a20f718
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 4 deletions.
96 changes: 92 additions & 4 deletions stardis/radiation_field/opacities/opacities_solvers/broadening.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def calc_gamma_quadratic_stark_cuda(


@numba.njit
def calc_gamma_van_der_waals(
def _calc_gamma_van_der_waals(
ion_number,
n_eff_upper,
n_eff_lower,
Expand Down Expand Up @@ -435,13 +435,21 @@ def calc_gamma_van_der_waals(
gamma_van_der_waals : float
Broadening parameter for van der Waals broadening.
"""
ion_number, n_eff_upper, n_eff_lower, temperature, h_density, h_mass = (
int(ion_number),
float(n_eff_upper),
float(n_eff_lower),
float(temperature),
float(h_density),
float(h_mass),
)
c6 = (
6.46e-34
* (
n_eff_upper**2 * (5 * n_eff_upper**2 + 1)
- n_eff_lower**2 * (5 * n_eff_lower**2 + 1)
(5 * n_eff_upper**4 + n_eff_upper**2)
- (5 * n_eff_lower**4 + n_eff_lower**2)
)
/ (2 * ion_number**2)
/ (2 * ion_number * ion_number)
)

gamma_van_der_waals = (
Expand All @@ -454,6 +462,86 @@ def calc_gamma_van_der_waals(
return gamma_van_der_waals


@numba.vectorize(nopython=True)
def calc_gamma_van_der_waals(
ion_number,
n_eff_upper,
n_eff_lower,
temperature,
h_density,
h_mass,
):
return _calc_gamma_van_der_waals(
ion_number,
n_eff_upper,
n_eff_lower,
temperature,
h_density,
h_mass,
)


@cuda.jit
def _calc_gamma_van_der_waals_cuda(
res,
ion_number,
n_eff_upper,
n_eff_lower,
temperature,
h_density,
h_mass,
):
tid = cuda.grid(1)
size = len(res)

if tid < size:
res[tid] = _calc_gamma_van_der_waals(
ion_number[tid],
n_eff_upper[tid],
n_eff_lower[tid],
temperature[tid],
h_density[tid],
h_mass[tid],
)


def calc_gamma_van_der_waals_cuda(
ion_number,
n_eff_upper,
n_eff_lower,
temperature,
h_density,
h_mass,
nthreads=256,
ret_np_ndarray=False,
dtype=float,
):
arg_list = (
ion_number,
n_eff_upper,
n_eff_lower,
temperature,
h_density,
h_mass,
)

shortest_arg_idx = np.argmin(map(len, arg_list))
size = len(arg_list[shortest_arg_idx])

nblocks = 1 + (size // nthreads)

arg_list = tuple(map(lambda v: cp.array(v, dtype=dtype), arg_list))

res = cp.empty_like(arg_list[shortest_arg_idx], dtype=dtype)

_calc_gamma_van_der_waals_cuda[nblocks, nthreads](
res,
*arg_list,
)

return cp.asnumpy(res) if ret_np_ndarray else res


@numba.njit
def calc_gamma(
atomic_number,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
calc_gamma_quadratic_stark,
_calc_gamma_quadratic_stark_cuda,
calc_gamma_quadratic_stark_cuda,
calc_gamma_van_der_waals,
_calc_gamma_van_der_waals_cuda,
calc_gamma_van_der_waals_cuda,
)

GPUs_available = cuda.is_available()
Expand Down Expand Up @@ -485,3 +488,148 @@ def test_calc_gamma_quadratic_stark_cuda_wrapped_sample_cuda_values(
calc_gamma_quadratic_stark_cuda(*map(cp.asarray, arg_list)),
calc_gamma_quadratic_stark_sample_values_expected_result,
)


@pytest.mark.parametrize(
"calc_gamma_van_der_waals_sample_values_input_ion_number,calc_gamma_van_der_waals_sample_values_input_n_eff_upper,calc_gamma_van_der_waals_sample_values_input_n_eff_lower, calc_gamma_van_der_waals_sample_values_input_temperature, calc_gamma_van_der_waals_sample_values_input_h_density,calc_gamma_van_der_waals_sample_values_input_h_mass,calc_gamma_van_der_waals_sample_values_expected_result",
[
(
1, # ion_number
1.0, # n_eff_upper
0.0, # n_eff_lower
np.pi / 8 / BOLTZMANN_CONSTANT / 17 ** (1.0 / 0.3), # temperature
(3.0 * 6.46e-34) ** (-0.4), # h_density
1.0, # h_mass
1.0, # Expected output
),
(
np.array(2 * [1], dtype=int),
np.array(2 * [1.0]),
np.array(2 * [0.0]),
np.array(2 * [np.pi / 8 / BOLTZMANN_CONSTANT / 17 ** (1.0 / 0.3)]),
np.array(2 * [(3.0 * 6.46e-34) ** (-0.4)]),
np.array(2 * [1.0]),
np.array(2 * [1.0]),
),
],
)
def test_calc_gamma_van_der_waals_sample_values(
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
calc_gamma_van_der_waals_sample_values_expected_result,
):
print(
calc_gamma_van_der_waals(
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
)
)
assert np.allclose(
calc_gamma_van_der_waals(
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
),
calc_gamma_van_der_waals_sample_values_expected_result,
)


@pytest.mark.skipif(
not GPUs_available, reason="No GPU is available to test CUDA function"
)
@pytest.mark.parametrize(
"calc_gamma_van_der_waals_sample_values_input_ion_number,calc_gamma_van_der_waals_sample_values_input_n_eff_upper,calc_gamma_van_der_waals_sample_values_input_n_eff_lower, calc_gamma_van_der_waals_sample_values_input_temperature, calc_gamma_van_der_waals_sample_values_input_h_density,calc_gamma_van_der_waals_sample_values_input_h_mass,calc_gamma_van_der_waals_sample_values_expected_result",
[
(
np.array(2 * [1], dtype=int),
np.array(2 * [1.0]),
np.array(2 * [0.0]),
np.array(2 * [np.pi / 8 / BOLTZMANN_CONSTANT / 17 ** (1.0 / 0.3)]),
np.array(2 * [(3.0 * 6.46e-34) ** (-0.4)]),
np.array(2 * [1.0]),
np.array(2 * [1.0]),
),
],
)
def test_calc_gamma_van_der_waals_cuda_unwrapped_sample_values(
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
calc_gamma_van_der_waals_sample_values_expected_result,
):
arg_list = (
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
)

arg_list = tuple(map(cp.array, arg_list))
result_values = cp.empty_like(arg_list[0])

nthreads = 256
length = len(calc_gamma_van_der_waals_sample_values_expected_result)
nblocks = 1 + (length // nthreads)

_calc_gamma_van_der_waals_cuda[nblocks, nthreads](result_values, *arg_list)

assert np.allclose(
cp.asnumpy(result_values),
calc_gamma_van_der_waals_sample_values_expected_result,
)


@pytest.mark.skipif(
not GPUs_available, reason="No GPU is available to test CUDA function"
)
@pytest.mark.parametrize(
"calc_gamma_van_der_waals_sample_values_input_ion_number,calc_gamma_van_der_waals_sample_values_input_n_eff_upper,calc_gamma_van_der_waals_sample_values_input_n_eff_lower, calc_gamma_van_der_waals_sample_values_input_temperature, calc_gamma_van_der_waals_sample_values_input_h_density,calc_gamma_van_der_waals_sample_values_input_h_mass,calc_gamma_van_der_waals_sample_values_expected_result",
[
(
np.array(2 * [1], dtype=int),
np.array(2 * [1.0]),
np.array(2 * [0.0]),
np.array(2 * [np.pi / 8 / BOLTZMANN_CONSTANT / 17 ** (1.0 / 0.3)]),
np.array(2 * [(3.0 * 6.46e-34) ** (-0.4)]),
np.array(2 * [1.0]),
np.array(2 * [1.0]),
),
],
)
def test_calc_gamma_van_der_waals_cuda_wrapped_sample_cuda_values(
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
calc_gamma_van_der_waals_sample_values_expected_result,
):
arg_list = (
calc_gamma_van_der_waals_sample_values_input_ion_number,
calc_gamma_van_der_waals_sample_values_input_n_eff_upper,
calc_gamma_van_der_waals_sample_values_input_n_eff_lower,
calc_gamma_van_der_waals_sample_values_input_temperature,
calc_gamma_van_der_waals_sample_values_input_h_density,
calc_gamma_van_der_waals_sample_values_input_h_mass,
)
assert np.allclose(
calc_gamma_van_der_waals_cuda(*map(cp.asarray, arg_list)),
calc_gamma_van_der_waals_sample_values_expected_result,
)

0 comments on commit a20f718

Please sign in to comment.