Fix LM optimizer on CUDA devices (#83)
SimonBoothroyd committed Oct 24, 2023
1 parent 8e2622d commit 56ecdd0
Showing 4 changed files with 105 additions and 78 deletions.
2 changes: 1 addition & 1 deletion devtools/envs/base.yaml
@@ -13,7 +13,7 @@ dependencies:
- openff-toolkit-base >=0.9.2
- openff-interchange-base ==0.3.15

- pytorch-cpu
- pytorch
- pydantic
- nnpops

4 changes: 2 additions & 2 deletions smee/optimizers/__init__.py
@@ -1,5 +1,5 @@
"""Custom parameter optimizers."""

from smee.optimizers._lm import LevenbergMarquardtConfig, levenberg_marquardt
from smee.optimizers._lm import LevenbergMarquardt, LevenbergMarquardtConfig

__all__ = ["LevenbergMarquardtConfig", "levenberg_marquardt"]
__all__ = ["LevenbergMarquardt", "LevenbergMarquardtConfig"]
127 changes: 77 additions & 50 deletions smee/optimizers/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
import pydantic
import torch

import smee.utils

_LOGGER = logging.getLogger(__name__)


LossFunction = typing.Callable[
ClosureFn = typing.Callable[
[torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]
CorrectFn = typing.Callable[[torch.Tensor], torch.Tensor]


class LevenbergMarquardtConfig(pydantic.BaseModel):
@@ -40,10 +43,6 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
1.0, description="Adaptive trust radius adjustment damping.", gt=0.0
)

max_iterations: int = pydantic.Field(
..., description="The maximum number of iterations to perform.", ge=0
)

error_tolerance: float = pydantic.Field(
1.0,
description="Steps where the loss increases more than this amount are rejected.",
@@ -93,7 +92,9 @@ def _solver(
The step with ``shape=(n,)`` and the expected improvement with ``shape=()``.
"""

hessian_regular = hessian + (damping_factor - 1) ** 2 * torch.eye(len(hessian))
hessian_regular = hessian + (damping_factor - 1) ** 2 * torch.eye(
len(hessian), device=hessian.device, dtype=hessian.dtype
)
hessian_inverse = _invert_svd(hessian_regular)

dx = -(hessian_inverse @ gradient)
@@ -135,7 +136,7 @@ def _damping_factor_loss_fn(
def _step(
gradient: torch.Tensor,
hessian: torch.Tensor,
trust_radius: float,
trust_radius: torch.Tensor,
initial_damping_factor: float = 1.0,
min_eigenvalue: float = 1.0e-4,
) -> tuple[torch.Tensor, torch.Tensor, bool]:
@@ -175,7 +176,9 @@ def _step(
f"hessian has a small or negative eigenvalue ({eigenvalue_smallest:.1e}), "
f"mixing in some steepest descent ({adjacency:.1e}) to correct this."
)
hessian += adjacency * torch.eye(hessian.shape[0])
hessian += adjacency * torch.eye(
hessian.shape[0], device=hessian.device, dtype=hessian.dtype
)

damping_factor = torch.tensor(1.0)

@@ -190,7 +193,11 @@
# meaningless steps.
damping_factor = optimize.brent(
_damping_factor_loss_fn,
(gradient, hessian, trust_radius),
(
gradient.detach().cpu(),
hessian.detach().cpu(),
trust_radius.detach().cpu(),
),
brack=(initial_damping_factor, initial_damping_factor * 4),
tol=1e-6,
)
@@ -205,7 +212,7 @@

def _reduce_trust_radius(
dx_norm: torch.Tensor, config: LevenbergMarquardtConfig
) -> float:
) -> torch.Tensor:
"""Reduce the trust radius.
Args:
@@ -220,16 +227,16 @@ def _reduce_trust_radius(
)
_LOGGER.info(f"reducing trust radius to {trust_radius:.4e}")

return trust_radius
return smee.utils.tensor_like(trust_radius, dx_norm)


def _update_trust_radius(
dx_norm: torch.Tensor,
step_quality: float,
trust_radius: float,
trust_radius: torch.Tensor,
damping_adjusted: bool,
config: LevenbergMarquardtConfig,
) -> float:
) -> torch.Tensor:
"""Adjust the trust radius based on the quality of the previous step.
Args:
@@ -246,7 +253,8 @@ def _update_trust_radius(

if step_quality <= config.quality_threshold_low:
trust_radius = max(
dx_norm * (1.0 / (1.0 + config.adaptive_factor)), config.min_trust_radius
dx_norm * (1.0 / (1.0 + config.adaptive_factor)),
smee.utils.tensor_like(config.min_trust_radius, dx_norm),
)
_LOGGER.info(
f"low quality step detected - reducing trust radius to {trust_radius:.4e}"
@@ -265,59 +273,78 @@ def _update_trust_radius(
return trust_radius


@torch.no_grad()
def levenberg_marquardt(
x: torch.Tensor, loss_fn: LossFunction, config: LevenbergMarquardtConfig
) -> torch.Tensor:
"""Optimize a function using the Levenberg-Marquardt algorithm.
Args:
x: The initial guess of the parameters.
loss_fn: The loss function. This should return the loss, gradient (with
``shape=(n,)``), and hessian (with ``shape=(n, n)``).
config: The optimizer config.
class LevenbergMarquardt:
"""A Levenberg-Marquardt optimizer.
Returns:
The optimized parameters.
Notes:
This is a reimplementation of the Levenberg-Marquardt optimizer from the
ForceBalance package, and so may differ from a standard implementation.
"""
x = torch.tensor(x, requires_grad=x.requires_grad)

history = [loss_fn(x)]
iteration = 0
def __init__(self, config: LevenbergMarquardtConfig | None = None):
self.config = config if config is not None else LevenbergMarquardtConfig()

self._closure_prev = None
self._trust_radius = torch.tensor(self.config.trust_radius)

@torch.no_grad()
def step(
self,
x: torch.Tensor,
closure_fn: ClosureFn,
correct_fn: CorrectFn | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Performs a single optimization step.
Args:
x: The initial guess of the parameters.
closure_fn: The closure that computes the loss (``shape=()``), its
gradient (``shape=(n,)``), and hessian (``shape=(n, n)``).
correct_fn: A function that can be used to correct the parameters after
each step is taken and before the new loss is computed. This may
include, for example, ensuring that vdW parameters are all positive.
Returns:
The loss, gradient, and hessian evaluated at the current parameters.
"""

trust_radius = config.trust_radius
correct_fn = correct_fn if correct_fn is not None else lambda x: x
closure_fn = torch.enable_grad()(closure_fn)

while iteration < config.max_iterations:
loss_prev, gradient_prev, hessian_prev = history[-1]
if self._closure_prev is None:
# compute the initial loss, gradient and hessian
self._closure_prev = closure_fn(x)

if self._trust_radius.device != x.device:
self._trust_radius = self._trust_radius.to(x.device)

loss_prev, gradient_prev, hessian_prev = self._closure_prev

dx, expected_improvement, damping_adjusted = _step(
gradient_prev, hessian_prev, trust_radius
gradient_prev, hessian_prev, self._trust_radius
)
dx_norm = torch.linalg.norm(dx)

x_prev = torch.tensor(x, requires_grad=x.requires_grad)
x += dx
x_next = correct_fn(x + dx).requires_grad_(x.requires_grad)

loss, gradient, hessian = loss_fn(x)
loss_delta = loss - loss_prev
loss_next, gradient_next, hessian_next = closure_fn(x_next)
loss_delta = loss_next - loss_prev

step_quality = loss_delta / expected_improvement

if loss > (loss_prev + config.error_tolerance):
trust_radius = _reduce_trust_radius(dx_norm, config)

if loss_next > (loss_prev + self.config.error_tolerance):
# reject the 'bad' step and try again from where we were
x = x_prev
loss, gradient, hessian = (loss_prev, gradient_prev, hessian_prev)

self._trust_radius = _reduce_trust_radius(dx_norm, self.config)
else:
trust_radius = _update_trust_radius(
dx_norm, step_quality, trust_radius, damping_adjusted, config
)

history.append((loss, gradient, hessian))
iteration += 1
# accept the step
loss, gradient, hessian = (loss_next, gradient_next, hessian_next)
x.data.copy_(x_next.data)

_LOGGER.info(f"step={iteration} loss={loss:.4e}")
self._trust_radius = _update_trust_radius(
dx_norm, step_quality, self._trust_radius, damping_adjusted, self.config
)

return x
self._closure_prev = (loss, gradient, hessian)
return self._closure_prev
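
With the trust radius and identity matrices now created on the parameters' device and dtype, the optimizer can be driven directly from a CUDA tensor. A rough, self-contained sketch assuming a CUDA-enabled PyTorch build is available; the quadratic closure and the clamping correct_fn below are illustrative stand-ins for a real fit, not part of the package:

import torch

from smee.optimizers import LevenbergMarquardt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# reference data for an illustrative least-squares fit of y = a*x^2 + b*x + c
x_ref = torch.linspace(-2.0, 2.0, 100, device=device)
y_ref = 4.0 * x_ref**2 + 3.0 * x_ref + 2.0

theta = torch.zeros(3, requires_grad=True, device=device)


def loss_fn(_theta: torch.Tensor) -> torch.Tensor:
    y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2]
    return torch.sum((y - y_ref) ** 2)


def closure_fn(_theta: torch.Tensor):
    # ``step`` re-enables gradients around this call, so autograd works here
    loss = loss_fn(_theta)
    (grad,) = torch.autograd.grad(loss, _theta)
    hess = torch.autograd.functional.hessian(loss_fn, _theta)
    return loss.detach(), grad.detach(), hess.detach()


def correct_fn(_theta: torch.Tensor) -> torch.Tensor:
    # illustrative correction, e.g. keep the leading coefficient non-negative
    return torch.cat([_theta[:1].clamp(min=0.0), _theta[1:]])


optimizer = LevenbergMarquardt()

for _ in range(15):
    loss, grad, hess = optimizer.step(theta, closure_fn, correct_fn)
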
50 changes: 25 additions & 25 deletions smee/tests/optimizers/test_lm.py
@@ -4,11 +4,10 @@
import torch

from smee.optimizers._lm import (
LevenbergMarquardtConfig,
LevenbergMarquardt,
_damping_factor_loss_fn,
_solver,
_step,
levenberg_marquardt,
)


@@ -47,7 +46,7 @@ def test_step():
]
)

expected_trust_radius = 0.123
expected_trust_radius = torch.tensor(0.123)

dx, solution, adjusted = _step(
gradient, hessian, trust_radius=expected_trust_radius
@@ -155,18 +154,17 @@ def test_levenberg_marquardt_adaptive(mocker, caplog):

x_traj = []

def mock_loss_fn(x):
x_traj.append(x.clone())
def mock_loss_fn(_x):
x_traj.append(_x.clone())
return mock_loss_traj.pop(0)

x_0 = torch.tensor([0.0, 0.0])
x = torch.tensor([0.0, 0.0])

with caplog.at_level(logging.INFO):
x_final = levenberg_marquardt(
x_0, mock_loss_fn, LevenbergMarquardtConfig(max_iterations=3)
)
optimizer = LevenbergMarquardt()

assert torch.allclose(x_0, torch.tensor([0.0, 0.0]))
with caplog.at_level(logging.INFO):
for _ in range(3):
optimizer.step(x, mock_loss_fn)

expected_x_traj = [
torch.tensor([0.0, 0.0]),
@@ -175,8 +173,8 @@ def mock_loss_fn(x):
torch.tensor([0.1, 0.2]),
torch.tensor([0.15, 0.21]),
]
assert x_final.shape == expected_x_traj[-1].shape
assert torch.allclose(x_final, expected_x_traj[-1])
assert x.shape == expected_x_traj[-1].shape
assert torch.allclose(x, expected_x_traj[-1])

trust_radius_messages = [m for m in caplog.messages if "trust radius" in m]

@@ -202,24 +200,26 @@ def test_levenberg_marquardt():
x_ref = torch.linspace(-2.0, 2.0, 100)
y_ref = expected[0] * x_ref**2 + expected[1] * x_ref + expected[2]

theta_0 = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
theta = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)

def loss_fn(theta: torch.Tensor) -> torch.Tensor:
y = theta[0] * x_ref**2 + theta[1] * x_ref + theta[2]
def loss_fn(_theta: torch.Tensor) -> torch.Tensor:
y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2]
return torch.sum((y - y_ref) ** 2)

@torch.enable_grad()
def target_fn(
theta: torch.Tensor,
_theta: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
with torch.enable_grad():
loss = loss_fn(theta)
(grad,) = torch.autograd.grad(loss, theta, torch.tensor(1.0))
hess = torch.autograd.functional.hessian(loss_fn, theta)
loss = loss_fn(_theta)
(grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0))
hess = torch.autograd.functional.hessian(loss_fn, _theta)

return loss.detach(), grad.detach(), hess.detach()

config = LevenbergMarquardtConfig(max_iterations=15)
x_final = levenberg_marquardt(theta_0, target_fn, config)
optimizer = LevenbergMarquardt()

for _ in range(15):
optimizer.step(theta, target_fn)

assert x_final.shape == expected.shape
assert torch.allclose(x_final, expected)
assert theta.shape == expected.shape
assert torch.allclose(theta, expected)
