From 799d612ca1f8d73794fe774a5a1fe9698547df6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=B1=E6=98=B1=E5=86=B0?= Date: Fri, 22 Mar 2024 17:00:22 +0800 Subject: [PATCH] add FastWarp and periodic MinAC estimator --- netobs/adaptors/__init__.py | 25 +- netobs/adaptors/deepsolid_vmc.py | 25 +- netobs/helpers/grad.py | 12 +- netobs/observables/force.py | 300 +++++++++++++++++++++--- tests/__snapshots__/numerical_test.ambr | 2 +- tests/checkpoint_test.py | 6 +- tests/numerical_test.py | 6 +- 7 files changed, 331 insertions(+), 45 deletions(-) diff --git a/netobs/adaptors/__init__.py b/netobs/adaptors/__init__.py index f89217e..5b4d2d3 100644 --- a/netobs/adaptors/__init__.py +++ b/netobs/adaptors/__init__.py @@ -245,7 +245,7 @@ def call_local_potential_energy( Local potential energy at this point. """ - def make_network_grad(self, arg: str): + def make_network_grad(self, arg: str, jaxfun: Callable = jax.grad): """Create gradient function of network. Useful when the gradient handling is different in your network. @@ -254,13 +254,29 @@ def make_network_grad(self, arg: str): Args: arg: the name of arguments to calculate gradient, e.g. "electrons"/"atoms". + jaxfun: the type of gradient to take, e.g. `jax.grad`. Returns: The gradient function. """ - return grad_with_system(self.call_network, arg) # type: ignore + return grad_with_system(self.call_network, arg, jaxfun=jaxfun) # type: ignore - def make_local_energy_grad(self, arg: str): + def make_signed_network_grad(self, arg: str, jaxfun: Callable = jax.grad): + """Create gradient function of network, taking sign into consideration. + + Useful when the gradient handling is different in your network, + e.g. the network is complex-valued. + + Args: + arg: the name of arguments to calculate gradient, e.g. "electrons"/"atoms". + jaxfun: the type of gradient to take, e.g. `jax.grad`. + + Returns: + The gradient function. + """ + return grad_with_system(self.call_network, arg, jaxfun=jaxfun) # type: ignore + + def make_local_energy_grad(self, arg: str, jaxfun: Callable = jax.grad): """Create gradient function of local energy. Useful when the gradient handling is different in your network, @@ -268,11 +284,12 @@ def make_local_energy_grad(self, arg: str): Args: arg: the name of arguments to calculate gradient, e.g. "electrons"/"atoms". + jaxfun: the type of gradient to take, e.g. `jax.grad`. Returns: The gradient function. """ - return grad_with_system(self.call_local_energy, arg) # type: ignore + return grad_with_system(self.call_local_energy, arg, jaxfun=jaxfun) # type: ignore # Utility protocols diff --git a/netobs/adaptors/deepsolid_vmc.py b/netobs/adaptors/deepsolid_vmc.py index fd2c6e0..347157c 100644 --- a/netobs/adaptors/deepsolid_vmc.py +++ b/netobs/adaptors/deepsolid_vmc.py @@ -198,12 +198,31 @@ def walk( return jax.pmap(walk) - def make_local_energy_grad(self, arg: str): + def make_signed_network_grad(self, arg: str, jaxfun: Callable = jax.grad): + def complex_f(params, electrons, system): + sign, slogdet = self.call_signed_network(params, electrons, system) + return jnp.log(sign) + slogdet + + grad_f_real = grad_with_system( + lambda *args: complex_f(*args).real, arg, args_before=1, jaxfun=jaxfun + ) + grad_f_imag = grad_with_system( + lambda *args: complex_f(*args).imag, arg, args_before=1, jaxfun=jaxfun + ) + return lambda *args: grad_f_real(*args) + grad_f_imag(*args) * 1j + + def make_local_energy_grad(self, arg: str, jaxfun: Callable = jax.grad): grad_local_energy_real = grad_with_system( - lambda *args: self.call_local_energy(*args).real, arg, args_before=2 + lambda *args: self.call_local_energy(*args).real, + arg, + args_before=2, + jaxfun=jaxfun, ) grad_local_energy_imag = grad_with_system( - lambda *args: self.call_local_energy(*args).imag, arg, args_before=2 + lambda *args: self.call_local_energy(*args).imag, + arg, + args_before=2, + jaxfun=jaxfun, ) return ( lambda *args: grad_local_energy_real(*args) diff --git a/netobs/helpers/grad.py b/netobs/helpers/grad.py index 538a812..45e7e9a 100644 --- a/netobs/helpers/grad.py +++ b/netobs/helpers/grad.py @@ -48,7 +48,12 @@ def wrapped_kinetic_energy(params: Any, x: jnp.ndarray, atoms: Any) -> jnp.ndarr return wrapped_kinetic_energy -def grad_with_system(f: Callable[..., jnp.ndarray], arg: str, args_before: int | None = None): +def grad_with_system( + f: Callable[..., jnp.ndarray], + arg: str, + args_before: int | None = None, + jaxfun: Callable = jax.grad, +): """Make grad of functions like f(*args, electrons, system). The last two args must be `electrons as `system`. @@ -60,6 +65,7 @@ def grad_with_system(f: Callable[..., jnp.ndarray], arg: str, args_before: int | To grad with things inside `system`, use the key, e.g. "atoms". args_before: number of arguments before "electrons". Leaving it empty to automatically detect. + jaxfun: the type of gradient to take, e.g. `jax.grad`. Raises: ValueError: failing to detect the function signature @@ -73,13 +79,13 @@ def grad_with_system(f: Callable[..., jnp.ndarray], arg: str, args_before: int | raise ValueError("Unable to determine function signature") if arg == "electrons": - return jax.grad(f, argnums=args_before) + return jaxfun(f, argnums=args_before) def wrap_f(*args): *args, x, system = args return f(*args, {**system, arg: x}) - grad_local_energy = jax.grad(wrap_f, argnums=args_before + 1) + grad_local_energy = jaxfun(wrap_f, argnums=args_before + 1) def wrap_grad(*args): *args, system = args diff --git a/netobs/observables/force.py b/netobs/observables/force.py index e35b89f..c97ff9c 100644 --- a/netobs/observables/force.py +++ b/netobs/observables/force.py @@ -27,7 +27,7 @@ from netobs.helpers.grad import grad_with_system from netobs.logging import logger from netobs.observables import Estimator, Observable -from netobs.systems import System +from netobs.systems import System, solid from netobs.systems.molecule import MolecularSystem, calculate_r_ae from netobs.systems.solid import MinimalImageDistance, SolidSystem @@ -159,8 +159,22 @@ def digest(self, all_values, state) -> dict[str, jnp.ndarray]: return values +def exp1(x): + """Swamee and Ohija approximation for E1 function. + + See https://doi.org/10.1111%2Fj.1745-6584.2003.tb02608.x + + Default inplementaion in `jax.scipy.special.exp1` is accurate but slow. + See https://github.com/google/jax/issues/13543. + """ + return ( + jnp.log((0.56146 / x + 0.65) * (1 + x)) ** (-7.7) + + x**4 * jnp.exp(7.7 * x) * (2 + x) ** (3.7) + ) ** -0.13 + + @register_pytree_node_class -class AC(Estimator[MolecularSystem]): +class MinAC(Estimator[System]): r"""AC type Zero-variance zero-bias estimator based on \tilde{\psi}_{min}. \tilde{\psi}_{min} is the "minimal" form removing the singular part. @@ -174,13 +188,32 @@ class AC(Estimator[MolecularSystem]): def __init__(self, adaptor, system, estimator_options, observable_options): super().__init__(adaptor, system, estimator_options, observable_options) self.enable_zb = self.options.get("zb", False) + self.r_core = estimator_options.get("r_core", 0) + if self.r_core > 0: + self.batch_mirror = make_antithetic( + system, adaptor.call_network, self.r_core + ) + self.grad_potential = jax.pmap( + jax.vmap( + grad_with_system(adaptor.call_local_potential_energy, "atoms"), + in_axes=(None, None, 0, None), + ), + in_axes=(0, 0, 0, None), + ) self.batch_local_energy = jax.pmap( jax.vmap(adaptor.call_local_energy, in_axes=(None, None, 0, None)), in_axes=(0, 0, 0, None), ) - self.batch_Q = jax.pmap(jax.vmap(self.Q, in_axes=(0, None)), in_axes=(0, None)) - self.grad_Q = jax.jacfwd(self.Q, argnums=0) + self.batch_f_deriv_atom = jax.pmap( + jax.vmap(adaptor.make_network_grad("atoms"), in_axes=(None, 0, None)), + in_axes=(0, 0, None), + ) self.grad_f = adaptor.make_network_grad("electrons") + if "latvec" in system: + # NOTE: Only support default Ewald settings below. + self.dist = MinimalImageDistance(system["latvec"]) + recvec = solid.recvec(system, norm_to=1) + self.alpha = 5.0 / jnp.amin(1 / jnp.linalg.norm(recvec, axis=1)) def empty_val_state(self, steps: int): term_shape = (steps, *self.observable.shape) @@ -200,9 +233,21 @@ def empty_val_state(self, steps: int): def evaluate(self, i, params, key, data, system, state, aux_data): del i, aux_data - values = {"hf_term": self.batch_zv(params, data, system)} + f_bare = -self.grad_potential(params, key, data, system) + zv_term = self.batch_zv(params, data, system) + hfm_term = f_bare + zv_term + + if self.r_core > 0: + data_mirrored, mirrored_weight = self.batch_mirror(params, data, system) + f_bare_mirror = -self.grad_potential(params, key, data_mirrored, system) + zv_mirror = self.batch_zv(params, data_mirrored, system) + hfm_mirror = (f_bare_mirror + zv_mirror) * mirrored_weight[..., None, None] + hfm_term = (hfm_term + hfm_mirror) / 2 + + values = {"hf_term": hfm_term} + if self.enable_zb: - pulay_coeff = 2 * self.batch_Q(data, system) + pulay_coeff = 2 * self.batch_f_deriv_atom(params, data, system) el = self.batch_local_energy(params, key, data, system) values["el"] = el values["el_term"] = -el[..., None, None] * pulay_coeff @@ -223,44 +268,85 @@ def digest(self, all_values, state) -> dict[str, jnp.ndarray]: values["force"] = all_values["hf_term"] return values + def batch_zv(self, params: jnp.ndarray, x: jnp.ndarray, system: System): + if "latvec" in system: + return self.batch_zv_solid(params, x, cast(SolidSystem, system)) + else: + return self.batch_zv_molecular(params, x, cast(MolecularSystem, system)) + @partial(jax.pmap, in_axes=(None, 0, 0, None)) @partial(jax.vmap, in_axes=(None, None, 0, None)) - def batch_zv( - self, - params: jnp.ndarray, - x: jnp.ndarray, - system: MolecularSystem, + def batch_zv_molecular( + self, params: jnp.ndarray, x: jnp.ndarray, system: MolecularSystem ) -> jnp.ndarray: - atoms, charges = system["atoms"], system["charges"] - aa = atoms[None, ...] - atoms[:, None] - # || (1, natom, ndim) - (natom, 1, ndim) || = (natom, natom) - r_aa = jnp.linalg.norm(aa, axis=-1) - # f_aa_matrix[0, 1] points from atom 0 to atom 1, so its force on atom 1 - # Shapes are: charges (natom); aa (natom, natom, 3); r_aa (natom, natom, 1) - f_aa_matrix = jnp.nan_to_num( - (charges[None, ..., None] * charges[..., None, None]) - * aa - / r_aa[..., None] ** 3 + ae, r_ae = calculate_r_ae(x, system) + f_ae = jnp.sum(system["charges"][..., None] * ae / r_ae**3, axis=0) + dot_term = jnp.dot( + self.grad_Q_molecular(x, system), self.grad_f(params, x, system) ) - f_aa = jnp.sum(f_aa_matrix, axis=0) - dot_term = jnp.dot(self.grad_Q(x, system), self.grad_f(params, x, system)) - return f_aa + dot_term + return -f_ae + dot_term - def Q(self, x: jnp.ndarray, system: MolecularSystem) -> jnp.ndarray: - """The Q matrix. Shape (natom, ndim). + @partial(jax.jacfwd, argnums=1) + def grad_Q_molecular(self, x: jnp.ndarray, system: MolecularSystem) -> jnp.ndarray: + """The gradient of the Q matrix in the molecular case. - Based on Eq. (70) in the paper. + Undecorated: the Q matrix in the molecular case. Args: x: Shape (nelec*ndim). Electron positions. system: system containing atomic info. Returns: - The Q matrix. + nabla Q. Shape (natom, ndim, nelec*ndim). + Undecorated: the Q matrix. Shape (natom, ndim). """ ae, r_ae = calculate_r_ae(x, system) return jnp.sum(system["charges"][..., None] * ae / r_ae, axis=0) + @partial(jax.pmap, in_axes=(None, 0, 0, None)) + @partial(jax.vmap, in_axes=(None, None, 0, None)) + def batch_zv_solid( + self, params: jnp.ndarray, x: jnp.ndarray, system: SolidSystem + ) -> jnp.ndarray: + dot_term = jnp.dot(self.grad_Q_solid(x, system), self.grad_f(params, x, system)) + return -self.lap_Q_solid(x, system) / 2 - dot_term + + def grad_Q_solid(self, x: jnp.ndarray, system: SolidSystem) -> jnp.ndarray: + r"""Calculate \nabla Q. + + shape (natom, ndim, nelec*ndim) + """ + # shape (ncell, nelec, natom, 3), (ncell, nelec, natom, 1) + ae, r_ae = self.dist.neighboring_r_ae(x, system["atoms"]) + # (ncell, nelec, natom, 1) -> (ncell, natom, 1, nelec, 1) + r_ae = jnp.transpose(r_ae, (0, 2, 3, 1))[..., None] + _, nelec, natom, ndim = ae.shape + # (ncell, nelec, natom, ndim, ndim) -> (ncell, natom, ndim, nelec, ndim) + quadratic_mat = jnp.transpose(ae[..., None] * ae[..., None, :], (0, 2, 3, 1, 4)) + + ar = self.alpha * r_ae + quadratic_mat *= jax.lax.erfc(ar) / r_ae**3 + diag_mat = -jnp.eye(ndim)[:, None, :] * ( + self.alpha / jnp.sqrt(jnp.pi) * -exp1(ar**2) + jax.lax.erfc(ar) / r_ae + ) + return system["charges"][:, None, None] * jnp.reshape( + jnp.sum(quadratic_mat + diag_mat, axis=0), (natom, ndim, nelec * ndim) + ) + + def lap_Q_solid(self, x: jnp.ndarray, system: SolidSystem) -> jnp.ndarray: + r"""Calculate \nabla^2 Q. + + shape (natom, ndim) + """ + # shape (ncell, nelec, natom, 3), (ncell, nelec, natom, 1) + ae, r_ae = self.dist.neighboring_r_ae(x, system["atoms"]) + ar = self.alpha * r_ae + coeff = self.alpha / jnp.sqrt(jnp.pi) + return system["charges"][:, None] * jnp.sum( + (jax.lax.erfc(ar) / r_ae - coeff * jnp.exp(-(ar**2))) * 2 * ae / r_ae**2, + axis=(0, 1), + ) + def elec_reshaped(f): @wraps(f) @@ -445,3 +531,161 @@ def decay_function(r_ae: jnp.ndarray) -> jnp.ndarray: The value of the decaying function. """ return 1 / r_ae**4 + + +@register_pytree_node_class +class FastWarp(Estimator[System]): + observable_type = Force + + def __init__(self, adaptor, system, estimator_options, observable_options): + super().__init__(adaptor, system, estimator_options, observable_options) + self.r_core = estimator_options.get("r_core", 0) + if self.r_core != 0: + self.batch_mirror = make_antithetic( + system, adaptor.call_network, self.r_core + ) + self.nelec = sum(system["spins"]) + self.grad_potential = jax.pmap( + jax.vmap( + grad_with_system(adaptor.call_local_potential_energy, "atoms"), + in_axes=(None, None, 0, None), + ), + in_axes=(0, 0, 0, None), + ) + self.ep_deriv_elec = grad_with_system( + adaptor.call_local_potential_energy, "electrons", jaxfun=jax.value_and_grad + ) + self.batch_kinetic_energy = jax.pmap( + jax.vmap(adaptor.call_local_kinetic_energy, in_axes=(None, None, 0, None)), + in_axes=(0, 0, 0, None), + ) + if "latvec" in system: + dist = MinimalImageDistance(system["latvec"]) + self.neighboring_r_ae = lambda x, s: dist.neighboring_r_ae(x, s["atoms"])[1] + logger.info("Using periodic version of FastWarp") + else: + self.neighboring_r_ae = lambda x, s: calculate_r_ae(x, s)[1][None, ...] + logger.info("Using molecular version of FastWarp") + self.grad_f = adaptor.make_signed_network_grad("electrons") + self.f_deriv_atom = adaptor.make_signed_network_grad("atoms") + self.hess_f = adaptor.make_signed_network_grad("electrons", jaxfun=jax.hessian) + self.batch_hfm_warp = jax.pmap( + jax.vmap(self.hfm_warp_term, (None, None, 0, None)), in_axes=(0, 0, 0, None) + ) + self.omega_hess = jax.vmap( + jax.hessian(self.omega_single_electron, argnums=1), in_axes=(None, 0) + ) + self.omega_jacfwd = jax.vmap( + jax.jacfwd(self.omega_single_electron, argnums=1), in_axes=(None, 0) + ) + + def empty_val_state(self, steps: int): + term_shape = (steps, *self.observable.shape) + dtype = self.options.get("dtype") + names = ("hfm_term", "pulay_bare", "pulay_warp", "el_term_bare", "el_term_warp") + empty_values = { + "el": jnp.zeros((steps,), dtype), + **{name: jnp.zeros(term_shape, dtype) for name in names}, + } + return empty_values, {} + + def evaluate(self, i, params, key, data, system, state, aux_data): + del i, aux_data + f_bare = -self.grad_potential(params, key, data, system) + hfm_warp, el, pulay_bare, pulay_warp = self.batch_hfm_warp( + params, key, data, system + ) + if self.r_core != 0: # Enable antithetic + data_mirrored, mirrored_weight = self.batch_mirror(params, data, system) + m_f_bare = -self.grad_potential(params, key, data_mirrored, system) + m_hfm_warp, *_ = self.batch_hfm_warp(params, key, data_mirrored, system) + mirrored_weight = mirrored_weight[..., None, None] + f_bare = (f_bare + m_f_bare * mirrored_weight) / 2 + hfm_warp = (hfm_warp + m_hfm_warp * mirrored_weight) / 2 + values = { + # TODO: inspect why NaN happens + "hfm_term": f_bare + hfm_warp, + "el": el, + "el_term_bare": -jnp.real(jnp.conjugate(el[..., None, None]) * pulay_bare), + "el_term_warp": -jnp.real(jnp.conjugate(el[..., None, None]) * pulay_warp), + "pulay_bare": pulay_bare.real, + "pulay_warp": pulay_warp.real, + } + return values, state + + def digest(self, all_values, state) -> dict[str, jnp.ndarray]: + del state + el_term = all_values["el_term_bare"] + all_values["el_term_warp"] + pulay_coeff = all_values["pulay_bare"] + all_values["pulay_warp"] + energy_mean = jnp.mean(all_values["el"]) + return { + "force_biased": all_values["hfm_term"], + "energy": all_values["el"], + "force": all_values["hfm_term"] + el_term + pulay_coeff * energy_mean, + } + + def hfm_warp_term( + self, params: jnp.ndarray, key: jnp.ndarray, x: jnp.ndarray, system: System + ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + omega_mat = self.omega(system, x) + + potential_energy, ep_grad = self.ep_deriv_elec(params, key, x, system) + pot_term = omega_mat * jnp.reshape(ep_grad, (-1, 1, 3)) + + n = self.nelec + omega_jacfwd = self.omega_jacfwd(system, jnp.reshape(x, (n, 3))) + grad_f = self.grad_f(params, x, system) + hess_f = self.hess_f(params, x, system) + kinetic_energy = -0.5 * (jnp.trace(hess_f) + jnp.sum(grad_f**2)) + + grad_f = jnp.reshape(grad_f, (n, 3)) + hess_f = jnp.reshape(hess_f, (n, 3, n, 3)) + + # (n, 3, 3) + grad_grad_f_per_elec = jnp.einsum("ij,ik->ijk", grad_f, grad_f) + # Take out 3x3 block diagonal from nx3xnx3 matrix + hess_f_per_elec = jnp.moveaxis(jnp.diagonal(hess_f, axis1=0, axis2=2), 2, 0) + hess_psi_per_elec = grad_grad_f_per_elec + hess_f_per_elec + o1_term = jnp.sum( + omega_jacfwd[..., None] * hess_psi_per_elec[:, None, ...], axis=-2 + ) + + grad_f = jnp.reshape(grad_f, (n, 1, 3)) + omega_hess = self.omega_hess(system, jnp.reshape(x, (n, 3))) + lap_omega = jnp.trace(omega_hess, axis1=-2, axis2=-1)[..., None] + # Autograd for omega has numerical issue around nucleus, + # but the value there is simply 0. + lap_omega = jnp.nan_to_num(lap_omega, posinf=0.0, neginf=0.0) + o2_term = lap_omega * grad_f / 2 + + local_energy = potential_energy + kinetic_energy + pulay_bare = 2 * self.f_deriv_atom(params, x, system) + pulay_warp = 2 * jnp.sum(omega_mat * grad_f + omega_jacfwd / 2, axis=0) + + hfm_warp_term = -jnp.sum(pot_term + o1_term + o2_term, axis=0) + return hfm_warp_term, local_energy, pulay_bare, pulay_warp + + def omega(self, system: System, x: jnp.ndarray) -> jnp.ndarray: + r_ae = self.neighboring_r_ae(x, system) + # Remind r_ae is in shape (ncell, nelectron, natom, 1) + f_mat = self.decay_function(r_ae) + return jnp.sum(f_mat, axis=0) / f_mat.sum(axis=(0, 2))[:, None, :] + + def omega_single_electron(self, system: System, x: jnp.ndarray) -> jnp.ndarray: + r"""Calculate \omega matrix by single electron postion. + + Args: + system: system containing atomic info. + x: single electron position. Shape: (ndim,) + + Returns: + Derivative of \omega matrix. Shape: (natom,) + """ + r_ae = self.neighboring_r_ae(x, system) + # r_ae has shape (ncell, 1, ntom, 1) + f_mat = self.decay_function(r_ae[:, 0, :, 0]) # shape (ncell, natom) + return jnp.sum(f_mat, axis=0) / f_mat.sum(axis=(0, 1)) + + @staticmethod + def decay_function(r_ae: jnp.ndarray) -> jnp.ndarray: + return 1 / r_ae**4 diff --git a/tests/__snapshots__/numerical_test.ambr b/tests/__snapshots__/numerical_test.ambr index e4062dd..0010f6c 100644 --- a/tests/__snapshots__/numerical_test.ambr +++ b/tests/__snapshots__/numerical_test.ambr @@ -30,7 +30,7 @@ # --- # name: test_zvzb tuple( - Array([[3.4e-06, 8.7e-06, 1.3e-06]], dtype=float32), + Array([[3.4e-06, 8.9e-06, 1.2e-06]], dtype=float32), Array([[0.0, 0.0, 0.0]], dtype=float32), ) # --- diff --git a/tests/checkpoint_test.py b/tests/checkpoint_test.py index 96e7cfc..b2bbeb8 100644 --- a/tests/checkpoint_test.py +++ b/tests/checkpoint_test.py @@ -22,7 +22,7 @@ from netobs.checkpoint import SavingCheckpointManager from netobs.evaluate import evaluate_observable from netobs.observables.energy import EnergyEstimator -from netobs.observables.force import AC +from netobs.observables.force import MinAC from netobs.options import NetObsOptions @@ -72,7 +72,7 @@ def test_restore_zb_evaluate_zv( ): options.estimator["zb"] = True _, values, _ = evaluate_observable( - adaptor, AC, options=options, checkpoint_mgr=ckpt_mgr + adaptor, MinAC, options=options, checkpoint_mgr=ckpt_mgr ) assert "el" in values @@ -82,7 +82,7 @@ def test_restore_zb_evaluate_zv( with tempfile.TemporaryDirectory() as save_dir: ckpt_mgr.save_path = Path(save_dir) _, values, _ = evaluate_observable( - adaptor, AC, options=options, checkpoint_mgr=ckpt_mgr + adaptor, MinAC, options=options, checkpoint_mgr=ckpt_mgr ) assert list(values.keys()) == ["hf_term"] assert len(os.listdir(ckpt_mgr.restore_path)) == 1 diff --git a/tests/numerical_test.py b/tests/numerical_test.py index be430a2..dd05839 100644 --- a/tests/numerical_test.py +++ b/tests/numerical_test.py @@ -20,7 +20,7 @@ from netobs.helpers.digest import robust_mean from netobs.observables.density import DensityEstimator from netobs.observables.energy import EnergyEstimator -from netobs.observables.force import AC, SWCT, Bare +from netobs.observables.force import SWCT, Bare, MinAC from netobs.observables.wf_change import WFChangeEstimator from netobs.options import NetObsOptions @@ -103,14 +103,14 @@ def test_antithetic_zb(adaptor: SimpleHydrogen, options: NetObsOptions, snapshot def test_zv_noerror(adaptor: SimpleHydrogen, options: NetObsOptions): options.mcmc_burn_in = 0 options.steps = 1 - digest, *_ = evaluate_observable(adaptor, AC, options=options) + digest, *_ = evaluate_observable(adaptor, MinAC, options=options) assert digest is not None assert "force" in digest def test_zvzb(adaptor: SimpleHydrogen, options: NetObsOptions, snapshot: str): options.estimator["zb"] = True - digest, *_ = evaluate_observable(adaptor, AC, options=options) + digest, *_ = evaluate_observable(adaptor, MinAC, options=options) assert digest is not None assert "force" in digest assert "force_zv" in digest