Skip to content

Commit

Permalink
add FastWarp and periodic MinAC estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanChain committed Apr 16, 2024
1 parent 3b87260 commit 799d612
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 45 deletions.
25 changes: 21 additions & 4 deletions netobs/adaptors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def call_local_potential_energy(
Local potential energy at this point.
"""

def make_network_grad(self, arg: str):
def make_network_grad(self, arg: str, jaxfun: Callable = jax.grad):
"""Create gradient function of network.
Useful when the gradient handling is different in your network.
Expand All @@ -254,25 +254,42 @@ def make_network_grad(self, arg: str):
Args:
arg: the name of arguments to calculate gradient, e.g. "electrons"/"atoms".
jaxfun: the type of gradient to take, e.g. `jax.grad`.
Returns:
The gradient function.
"""
return grad_with_system(self.call_network, arg) # type: ignore
return grad_with_system(self.call_network, arg, jaxfun=jaxfun) # type: ignore

def make_local_energy_grad(self, arg: str):
def make_signed_network_grad(self, arg: str, jaxfun: Callable = jax.grad):
"""Create gradient function of network, taking sign into consideration.
Useful when the gradient handling is different in your network,
e.g. the network is complex-valued.
Args:
arg: the name of arguments to calculate gradient, e.g. "electrons"/"atoms".
jaxfun: the type of gradient to take, e.g. `jax.grad`.
Returns:
The gradient function.
"""
return grad_with_system(self.call_network, arg, jaxfun=jaxfun) # type: ignore

def make_local_energy_grad(self, arg: str, jaxfun: Callable = jax.grad):
"""Create gradient function of local energy.
Useful when the gradient handling is different in your network,
e.g. the network is complex-valued.
Args:
arg: the name of arguments to calculate gradient, e.g. "electrons"/"atoms".
jaxfun: the type of gradient to take, e.g. `jax.grad`.
Returns:
The gradient function.
"""
return grad_with_system(self.call_local_energy, arg) # type: ignore
return grad_with_system(self.call_local_energy, arg, jaxfun=jaxfun) # type: ignore


# Utility protocols
Expand Down
25 changes: 22 additions & 3 deletions netobs/adaptors/deepsolid_vmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,31 @@ def walk(

return jax.pmap(walk)

def make_local_energy_grad(self, arg: str):
def make_signed_network_grad(self, arg: str, jaxfun: Callable = jax.grad):
def complex_f(params, electrons, system):
sign, slogdet = self.call_signed_network(params, electrons, system)
return jnp.log(sign) + slogdet

grad_f_real = grad_with_system(
lambda *args: complex_f(*args).real, arg, args_before=1, jaxfun=jaxfun
)
grad_f_imag = grad_with_system(
lambda *args: complex_f(*args).imag, arg, args_before=1, jaxfun=jaxfun
)
return lambda *args: grad_f_real(*args) + grad_f_imag(*args) * 1j

def make_local_energy_grad(self, arg: str, jaxfun: Callable = jax.grad):
grad_local_energy_real = grad_with_system(
lambda *args: self.call_local_energy(*args).real, arg, args_before=2
lambda *args: self.call_local_energy(*args).real,
arg,
args_before=2,
jaxfun=jaxfun,
)
grad_local_energy_imag = grad_with_system(
lambda *args: self.call_local_energy(*args).imag, arg, args_before=2
lambda *args: self.call_local_energy(*args).imag,
arg,
args_before=2,
jaxfun=jaxfun,
)
return (
lambda *args: grad_local_energy_real(*args)
Expand Down
12 changes: 9 additions & 3 deletions netobs/helpers/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def wrapped_kinetic_energy(params: Any, x: jnp.ndarray, atoms: Any) -> jnp.ndarr
return wrapped_kinetic_energy


def grad_with_system(f: Callable[..., jnp.ndarray], arg: str, args_before: int | None = None):
def grad_with_system(
f: Callable[..., jnp.ndarray],
arg: str,
args_before: int | None = None,
jaxfun: Callable = jax.grad,
):
"""Make grad of functions like f(*args, electrons, system).
The last two args must be `electrons as `system`.
Expand All @@ -60,6 +65,7 @@ def grad_with_system(f: Callable[..., jnp.ndarray], arg: str, args_before: int |
To grad with things inside `system`, use the key, e.g. "atoms".
args_before: number of arguments before "electrons".
Leaving it empty to automatically detect.
jaxfun: the type of gradient to take, e.g. `jax.grad`.
Raises:
ValueError: failing to detect the function signature
Expand All @@ -73,13 +79,13 @@ def grad_with_system(f: Callable[..., jnp.ndarray], arg: str, args_before: int |
raise ValueError("Unable to determine function signature")

if arg == "electrons":
return jax.grad(f, argnums=args_before)
return jaxfun(f, argnums=args_before)

def wrap_f(*args):
*args, x, system = args
return f(*args, {**system, arg: x})

grad_local_energy = jax.grad(wrap_f, argnums=args_before + 1)
grad_local_energy = jaxfun(wrap_f, argnums=args_before + 1)

def wrap_grad(*args):
*args, system = args
Expand Down
Loading

0 comments on commit 799d612

Please sign in to comment.