Skip to content

Commit

Permalink
working couette flow case and pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasErbesdobler committed Jun 27, 2024
1 parent aa0e93a commit 6626abd
Show file tree
Hide file tree
Showing 6 changed files with 430 additions and 3 deletions.
160 changes: 160 additions & 0 deletions cases/cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Couette flow case setup"""

import jax.numpy as jnp
import numpy as np
from omegaconf import DictConfig

from jax_sph.case_setup import SimulationSetup
from jax_sph.utils import Tag, pos_init_cartesian_2d, pos_init_cartesian_3d


class CF(SimulationSetup):
"""Couette Flow.
Setup based on "Modeling Low Reynolds Number Incompressible [...], Morris 1997,
and similar to PF case.
"""

def __init__(self, cfg: DictConfig):
super().__init__(cfg)

# custom variables related only to this Simulation
if self.case.dim == 2:
self.u_wall = jnp.array([self.special.u_x_wall, 0.0])
elif self.case.dim == 3:
self.u_wall = jnp.array([self.special.u_x_wall, 0.0, 0.0])

# define offset vector
self.offset_vec = self._offset_vec()

# relaxation configurations
if self.case.mode == "rlx":
self._set_default_rlx()

if self.case.r0_type == "relaxed":
self._load_only_fluid = False
self._init_pos2D = self._get_relaxed_r0
self._init_pos3D = self._get_relaxed_r0

def _box_size2D(self, n_walls):
dx2n = self.case.dx * n_walls * 2
sp = self.special
return np.array([sp.L, sp.H + dx2n])

def _box_size3D(self, n_walls):
dx2n = self.case.dx * n_walls * 2
sp = self.special
return np.array([sp.L, sp.H + dx2n, 0.4])

def _init_walls_2d(self, dx, n_walls):
sp = self.special

# thickness of wall particles
dxn = dx * n_walls

# horizontal and vertical blocks
horiz = pos_init_cartesian_2d(np.array([sp.L, dxn]), dx)

# wall: bottom, top
wall_b = horiz.copy()
wall_t = horiz.copy() + np.array([0.0, sp.H + dxn])

rw = np.concatenate([wall_b, wall_t])
return rw

def _init_walls_3d(self, dx, n_walls):
sp = self.special

# thickness of wall particles
dxn = dx * n_walls

# horizontal and vertical blocks
horiz = pos_init_cartesian_3d(np.array([sp.L, dxn, 0.4]), dx)

# wall: bottom, top
wall_b = horiz.copy()
wall_t = horiz.copy() + np.array([0.0, sp.H + dxn, 0.0])

rw = np.concatenate([wall_b, wall_t])
return rw

def _init_pos2D(self, box_size, dx, n_walls):
sp = self.special

# initialize fluid phase
r_f = np.array([0.0, 1.0]) * n_walls * dx + pos_init_cartesian_2d(
np.array([sp.L, sp.H]), dx
)

# initialize walls
r_w = self._init_walls_2d(dx, n_walls)

# set tags
tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)
tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)

r = np.concatenate([r_w, r_f])
tag = np.concatenate([tag_w, tag_f])

# set velocity wall tag
box_size = self._box_size2D(n_walls)
mask_lid = r[:, 1] > (box_size[1] - n_walls * self.case.dx)
tag = jnp.where(mask_lid, Tag.MOVING_WALL, tag)
return r, tag

def _init_pos3D(self, box_size, dx, n_walls):
sp = self.special

# initialize fluid phase
r_f = np.array([0.0, 1.0, 0.0]) * n_walls * dx + pos_init_cartesian_3d(
np.array([sp.L, sp.H, 0.4]), dx
)

# initialize walls
r_w = self._init_walls_3d(dx, n_walls)

# set tags
tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)
tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)

r = np.concatenate([r_w, r_f])
tag = np.concatenate([tag_w, tag_f])

# set velocity wall tag
box_size = self._box_size3D(n_walls)
mask_lid = r[:, 1] > (box_size[1] - n_walls * self.case.dx)
tag = jnp.where(mask_lid, Tag.MOVING_WALL, tag)
return r, tag

def _offset_vec(self):
dim = self.cfg.case.dim
if dim == 2:
res = np.array([0.0, 1.0]) * self.cfg.solver.n_walls * self.cfg.case.dx
elif dim == 3:
res = np.array([0.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx
return res

def _init_velocity2D(self, r):
return jnp.zeros_like(r)

def _init_velocity3D(self, r):
return jnp.zeros_like(r)

def _external_acceleration_fn(self, r):
return jnp.zeros_like(r)

def _boundary_conditions_fn(self, state):
mask1 = state["tag"][:, None] == Tag.SOLID_WALL
mask2 = state["tag"][:, None] == Tag.MOVING_WALL

state["u"] = jnp.where(mask1, 0.0, state["u"])
state["v"] = jnp.where(mask1, 0.0, state["v"])
state["u"] = jnp.where(mask2, self.u_wall, state["u"])
state["v"] = jnp.where(mask2, self.u_wall, state["v"])

state["dudt"] = jnp.where(mask1, 0.0, state["dudt"])
state["dvdt"] = jnp.where(mask1, 0.0, state["dvdt"])
state["dudt"] = jnp.where(mask2, 0.0, state["dudt"])
state["dvdt"] = jnp.where(mask2, 0.0, state["dvdt"])

return state
24 changes: 24 additions & 0 deletions cases/cf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
extends: JAX_SPH_DEFAULTS

seed: 123

case:
source: "cf.py"
dim: 2
dx: 0.0166666
viscosity: 100.0
u_ref: 1.25
special:
L: 0.4 # water column length
H: 1.0 # water column height
u_x_wall: 1.25

solver:
dt: 0.0000005
t_end: 0.01
is_bc_trick: True

io:
write_type: ["h5"]
write_every: 200
data_path: "data/debug"
2 changes: 1 addition & 1 deletion jax_sph/case_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def initialize(self):
if k not in cfg.case.state0_keys:
continue
assert k in _state, ValueError(f"Key {k} not found in state0 file.")
mask, _mask = state["tag"]==Tag.FLUID, _state["tag"]==Tag.FLUID
mask, _mask = state["tag"] == Tag.FLUID, _state["tag"] == Tag.FLUID
assert state[k][mask].shape == _state[k][_mask].shape, ValueError(
f"Shape mismatch for key {k} in state0 file."
)
Expand Down
117 changes: 117 additions & 0 deletions tests/test_cf2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Test a full run of the solver on the Coette flow case from the validations."""

import os

import jax.numpy as jnp
import numpy as np
import pytest
from jax import config
from omegaconf import OmegaConf

from main import load_embedded_configs


def u_series_cf_exp(y, t, n_max=10):
"""Analytical solution to unsteady Couette flow (low Re)
Based on Series expansion as shown in:
"Modeling Low Reynolds Number Incompressible Flows Using SPH"
ba Morris et al. 1997
"""

eta = 100.0 # dynamic viscosity
rho = 1.0 # denstiy
nu = eta / rho # kinematic viscosity
u_max = 1.25 # max velocity in middle of channel
d = 1.0 # channel width

Re = u_max * d / nu
print(f"Couette flow at Re={Re}")

offset = u_max * y / d

def term(n):
base = np.pi * n / d

prefactor = 2 * u_max / (n * np.pi) * (-1)**n
sin_term = np.sin(base * y)
exp_term = np.exp(-(base**2) * nu * t)
return prefactor * sin_term * exp_term

res = offset
for i in range(1, n_max):
res += term(i)

return res


@pytest.fixture
def setup_simulation():
y_axis = np.linspace(0, 1, 21)
t_dimless = [0.0005, 0.001, 0.005]
# get analytical solution
ref_solutions = []
for t_val in t_dimless:
ref_solutions.append(u_series_cf_exp(y_axis, t_val))
return y_axis, t_dimless, ref_solutions


def run_simulation(tmp_path, tvf, solver):
"""Emulate `main.py`."""
data_path = tmp_path / f"cf_test_{tvf}"

cli_args = OmegaConf.create(
{
"config": "cases/cf.yaml",
"case": {"dx": 0.0333333},
"solver": {"name": solver, "tvf": tvf, "dt": 0.000002, "t_end": 0.005},
"io": {"write_every": 250, "data_path": str(data_path)},
}
)
cfg = load_embedded_configs(cli_args)

# Specify cuda device. These setting must be done before importing jax-md.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.xla_mem_fraction)

if cfg.dtype == "float64":
config.update("jax_enable_x64", True)

from jax_sph.simulate import simulate

simulate(cfg)

return data_path


def get_solution(data_path, t_dimless, y_axis):
from jax_sph.utils import sph_interpolator

dir = os.listdir(data_path)[0]
cfg = OmegaConf.load(data_path / dir / "config.yaml")
step_max = np.array(np.rint(cfg.solver.t_end / cfg.solver.dt), dtype=int)
digits = len(str(step_max))

y_axis += 3 * cfg.case.dx
rs = 0.2 * jnp.ones([y_axis.shape[0], 2])
rs = rs.at[:, 1].set(y_axis)
solutions = []
for i in range(len(t_dimless)):
file_name = (
"traj_" + str(int(t_dimless[i] / cfg.solver.dt)).zfill(digits) + ".h5"
)
src_path = data_path / dir / file_name
interp_vel_fn = sph_interpolator(cfg, src_path)
solutions.append(interp_vel_fn(src_path, rs, prop="u", dim_ind=0))
return solutions


@pytest.mark.parametrize("tvf, solver", [(0.0, "SPH"), (1.0, "SPH"), (0.0, "RIE")])
def test_cf2d(tvf, solver, tmp_path, setup_simulation):
"""Test whether the couette flow simulation matches the analytical solution"""
y_axis, t_dimless, ref_solutions = setup_simulation
data_path = run_simulation(tmp_path, tvf, solver)
solutions = get_solution(data_path, t_dimless, y_axis)
for sol, ref_sol in zip(solutions, ref_solutions):
assert np.allclose(sol, ref_sol, atol=1e-2), "Velocity profile does not match."
14 changes: 14 additions & 0 deletions validation/cf2d.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
# Validation of the 2D Couette Flow
# Reference result from:
# "Modeling Low Reynolds Number Incompressible Flows Using SPH", Morris 1997

# Generate data
python main.py config=cases/cf.yaml solver.tvf=1.0 io.data_path=data_valid/cf2d_tvf/
python main.py config=cases/cf.yaml solver.tvf=0.0 io.data_path=data_valid/cf2d_notvf/
python main.py config=cases/cf.yaml solver.tvf=0.0 solver.name=RIE solver.density_evolution=True io.data_path=data_valid/cf2d_Rie/

# Run validation script
python validation/validate.py --case=2D_CF --src_dir=data_valid/cf2d_tvf/
python validation/validate.py --case=2D_CF --src_dir=data_valid/cf2d_notvf/
python validation/validate.py --case=2D_CF --src_dir=data_valid/cf2d_Rie/
Loading

0 comments on commit 6626abd

Please sign in to comment.