feat: neural tangents support (#117)
Adds support for the neural tangents library, i.e., uncertainty estimates using the neural tangent kernel (NTK) or the neural network Gaussian process (NNGP). At the same time, we use the newly introduced `jax` dependency to build ensembles of finite-width neural networks.
1 parent 230f7cd · commit 5911cc5
Showing 14 changed files with 826 additions and 6 deletions.
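As a minimal sketch (not the PyePAL implementation itself), this is the kind of closed-form NNGP uncertainty estimate that neural_tangents provides; the data arrays here are placeholders and real inputs should be standardized, as discussed in the module docstring below.

import jax.numpy as jnp
import neural_tangents as nt
from jax import random
from neural_tangents import stax

# Placeholder data; standardize real inputs and targets first.
key = random.PRNGKey(0)
x_train = random.normal(key, (10, 4))
y_train = random.normal(key, (10, 1))
x_test = random.normal(key, (5, 4))

# One stax stack yields init/apply functions for finite-width networks
# and an analytic kernel_fn for the infinite-width limit.
_, _, kernel_fn = stax.serial(
    stax.Dense(512, W_std=2.5, b_std=1), stax.Erf(), stax.Dense(1)
)

# Closed-form posterior of the infinitely wide network trained to convergence
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-4
)
mean, cov = predict_fn(x_test=x_test, get="nngp", compute_cov=True)
std = jnp.sqrt(jnp.diag(cov))  # per-point uncertainty estimate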
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright 2020 PyePAL authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions to build neural tangents models for PALNT.

Depending on the dataset, there might be some issues with these models;
some tricks are listed in https://github.com/google/neural-tangents/issues/76

1. Use Erf as activation
2. Initialize the weights with a larger standard deviation
3. Standardize the data

The first two points are done by default in the `build_dense_network` function.
Note that, following the law of total variance, the prior initialized via
W_std and b_std gives an upper bound on the std of the posterior.
"""

from dataclasses import dataclass
from typing import Callable, Sequence, Union

from jax import jit
from jax.experimental import optimizers
from neural_tangents import stax


@dataclass
class NTModel:
    """Dataclass for neural tangents models"""

    # Initialization functions construct parameters for neural networks
    # given a random key and an input shape.
    init_fn: Callable
    # Apply functions do computations with finite-width neural networks.
    apply_fn: Callable
    kernel_fn: Callable
    predict_fn: Union[Callable, None] = None
    scaler: Union[Callable, None] = None  # Used to store StandardScaler objects
    params: Union[list, None] = None  # Used to store parameters of the ensemble models


@dataclass
class JaxOptimizer:
    """Dataclass for a JAX optimizer"""

    opt_init: Callable
    opt_update: Callable
    get_params: Callable


__all__ = ["NTModel", "build_dense_network", "get_optimizer", "JaxOptimizer"]


def build_dense_network(
    hidden_layers: Sequence[int],
    activations: Union[Sequence, str] = "erf",
    w_std: float = 2.5,
    b_std: float = 1,
) -> NTModel:
    """Utility function to build a simple feedforward network with the
    neural tangents library.

    Args:
        hidden_layers (Sequence[int]): Iterable with the number of neurons.
            For example, [512, 512]
        activations (Union[Sequence, str], optional):
            Iterable with neural_tangents.stax activations or "relu" or "erf".
            Defaults to "erf".
        w_std (float): Standard deviation of the weight distribution.
        b_std (float): Standard deviation of the bias distribution.

    Returns:
        NTModel: jitted init, apply and kernel functions;
            predict_fn (None)
    """
    assert len(hidden_layers) >= 1, "You must provide at least one hidden layer"
    if activations is None:
        activations = [stax.Relu() for _ in hidden_layers]
    elif isinstance(activations, str):
        if activations.lower() == "relu":
            activations = [stax.Relu() for _ in hidden_layers]
        elif activations.lower() == "erf":
            activations = [stax.Erf() for _ in hidden_layers]
    else:
        for activation in activations:
            assert callable(
                activation
            ), "You need to provide `neural_tangents.stax` activations"

    assert len(activations) == len(
        hidden_layers
    ), "The number of hidden layers should match the number of nonlinearities"

    stack = []

    for hidden_layer, activation in zip(hidden_layers, activations):
        stack.append(stax.Dense(hidden_layer, W_std=w_std, b_std=b_std))
        stack.append(activation)

    stack.append(stax.Dense(1, W_std=w_std, b_std=b_std))

    init_fn, apply_fn, kernel_fn = stax.serial(*stack)

    return NTModel(init_fn, jit(apply_fn), jit(kernel_fn, static_argnums=(2,)), None)


def get_optimizer(
    learning_rate: float = 1e-4, optimizer: str = "sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
    """Return a `JaxOptimizer` dataclass for a JAX optimizer.

    Args:
        learning_rate (float, optional): Step size. Defaults to 1e-4.
        optimizer (str, optional): Optimizer type (allowed types: "adam",
            "adamax", "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
        optimizer_kwargs (dict, optional): Additional keyword arguments
            that are passed to the optimizer. Defaults to None.

    Returns:
        JaxOptimizer
    """
    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    optimizer = optimizer.lower()
    if optimizer == "adam":
        opt_init, opt_update, get_params = optimizers.adam(
            learning_rate, **optimizer_kwargs
        )
    elif optimizer == "adagrad":
        opt_init, opt_update, get_params = optimizers.adagrad(
            learning_rate, **optimizer_kwargs
        )
    elif optimizer == "adamax":
        opt_init, opt_update, get_params = optimizers.adamax(
            learning_rate, **optimizer_kwargs
        )
    elif optimizer == "rmsprop":
        opt_init, opt_update, get_params = optimizers.rmsprop(
            learning_rate, **optimizer_kwargs
        )
    else:
        opt_init, opt_update, get_params = optimizers.sgd(
            learning_rate, **optimizer_kwargs
        )

    opt_update = jit(opt_update)

    return JaxOptimizer(opt_init, opt_update, get_params)
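For illustration, a short usage sketch of the two helpers above; the input dimension and hyperparameter values are made up.

import jax.numpy as jnp
from jax import random

from pyepal.models.nt import build_dense_network, get_optimizer

model = build_dense_network(hidden_layers=[512, 512])  # Erf activations by default
opt = get_optimizer(learning_rate=1e-3, optimizer="adam")

# Initialize finite-width parameters for, e.g., 8-dimensional inputs
key = random.PRNGKey(0)
_, params = model.init_fn(key, (-1, 8))
opt_state = opt.opt_init(params)

# Evaluate the analytic NNGP kernel between two (dummy) design sets;
# the third argument of the jitted kernel_fn is static (static_argnums=(2,))
x1, x2 = jnp.zeros((10, 8)), jnp.ones((5, 8))
k_nngp = model.kernel_fn(x1, x2, "nngp")  # kernel matrix of shape (10, 5)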
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
# Copyright 2020 PyePAL authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Run PAL with the same models for finite ensemble models
and infinite-width models (`PALNT`)
"""

from typing import Sequence

import numpy as np
from jax import grad, jit, random, vmap
from sklearn.preprocessing import StandardScaler

from ..models.nt import JaxOptimizer, NTModel
from .pal_base import PALBase
from .validate_inputs import (
    validate_nt_models,
    validate_optimizers,
    validate_positive_integer_list,
)


# Again, the idea of having the core as pure functions outside of the class
# is that we can parallelize it more easily this way
def _ensemble_train_one_finite_width(  # pylint:disable=too-many-arguments, too-many-locals
    i: int,
    models: Sequence[NTModel],
    design_space: np.ndarray,
    objectives: np.ndarray,
    sampled: np.ndarray,
    optimizers: Sequence[JaxOptimizer],
    key: random.PRNGKey,
    training_steps: Sequence[int],
    ensemble_size: Sequence[int],
):
    model = models[i]
    optimizer = optimizers[i]

    # Jitted 0.5 * MSE loss and its gradient w.r.t. the network parameters
    loss = jit(lambda params, x, y: 0.5 * np.mean((model.apply_fn(params, x) - y) ** 2))
    grad_loss = jit(lambda state, x, y: grad(loss)(optimizer.get_params(state), x, y))

    x_train = design_space[sampled[:, i]]

    scaler = StandardScaler()
    y_train = scaler.fit_transform(objectives[sampled[:, i], i].reshape(-1, 1))

    def train_network(key):
        _, params = model.init_fn(key, (-1, x_train.shape[1]))
        opt_state = optimizer.opt_init(params)

        for j in range(training_steps[i]):
            opt_state = optimizer.opt_update(
                j, grad_loss(opt_state, x_train, y_train), opt_state
            )

        return optimizer.get_params(opt_state)

    # Train all ensemble members in parallel, one per random key
    ensemble_key = random.split(key, ensemble_size[i])
    params = vmap(train_network)(ensemble_key)

    return params, scaler


def _ensemble_predict_one_finite_width(i: int, models: Sequence[NTModel], design_space):
    model = models[i]

    # Evaluate every ensemble member on the design space in one vmapped call
    ensemble_func = vmap(model.apply_fn, (0, None))(model.params, design_space)

    # Mean and standard deviation over the ensemble dimension
    mean_func = np.reshape(np.mean(ensemble_func, axis=0), (-1,))
    std_func = np.reshape(np.std(ensemble_func, axis=0), (-1,))

    return mean_func, std_func


__all__ = ["PALJaxEnsemble", "NTModel", "JaxOptimizer"]


class PALJaxEnsemble(PALBase):  # pylint:disable=too-many-instance-attributes
    """Use PAL with an ensemble of finite-width neural networks.

    Note that we currently assume one model per output,
    i.e., multihead support is not yet implemented.
    """

    def __init__(self, *args, **kwargs):
        """Construct the PALJaxEnsemble instance

        Args:
            X_design (np.array): Design space (feature matrix)
            models (Sequence[NTModel]): You need to provide a sequence of
                NTModel (`pyepal.models.nt.NTModel`).
                The elements of this dataclass are the `apply_fn`, `init_fn`,
                `kernel_fn` and `predict_fn` (for the latter you can typically
                provide `None`).
                Can be constructed with
                :py:func:`pyepal.models.nt.build_dense_network`.
            optimizers (Union[JaxOptimizer, Sequence[JaxOptimizer]]):
                Sequence of dataclasses with functions for a JAX optimizer,
                can be constructed with :py:func:`pyepal.models.nt.get_optimizer`.
            ndim (int): Number of objectives
            epsilon (Union[list, float], optional): Epsilon hyperparameter.
                Defaults to 0.01.
            delta (float, optional): Delta hyperparameter. Defaults to 0.05.
            beta_scale (float, optional): Scaling parameter for beta.
                If not equal to 1, the theoretical guarantees do not necessarily hold.
                Also note that the parametrization depends on the kernel type.
                Defaults to 1/9.
            goals (List[str], optional): If a list, provide "min" for every objective
                that shall be minimized and "max" for every objective
                that shall be maximized. Defaults to None, which means
                that the code maximizes all objectives.
            coef_var_threshold (float, optional): Use only points with
                a coefficient of variation below this threshold
                in the classification step. Defaults to 3.
            key (int): Seed to generate the key for the JAX
                pseudo-random number generator. Defaults to 10.
            training_steps (Union[int, Sequence[int]]): Number of epochs
                the networks are trained for. Defaults to 500.
            ensemble_size (Union[int, Sequence[int]]): Size of the ensemble, i.e.,
                over how many randomly initialized neural networks we average
                to obtain estimates of mean and standard deviation.
                Automatically vectorized using `vmap`.
                Defaults to 100.
        """
        self.optimizers = validate_optimizers(
            kwargs.pop("optimizers"), kwargs.get("ndim")
        )

        self.training_steps = validate_positive_integer_list(
            kwargs.pop("training_steps", 500), kwargs.get("ndim")
        )
        self.ensemble_size = validate_positive_integer_list(
            kwargs.pop("ensemble_size", 100), kwargs.get("ndim")
        )
        self.key = random.PRNGKey(kwargs.pop("key", 10))
        self.design_space_scaler = StandardScaler()
        super().__init__(*args, **kwargs)
        self.models = validate_nt_models(self.models, self.ndim)

    def _set_data(self):
        self.design_space = self.design_space_scaler.fit_transform(self.design_space)

    def _train(self):
        for i in range(len(self.models)):
            params, scaler = _ensemble_train_one_finite_width(
                i,
                self.models,
                self.design_space,
                self.y,
                self.sampled,
                self.optimizers,
                self.key,
                self.training_steps,
                self.ensemble_size,
            )
            self.models[i].params = params
            self.models[i].scaler = scaler
            self.y[:, i] = scaler.transform(self.y[:, i].reshape(-1, 1)).flatten()

    def _predict(self):
        means, stds = [], []
        for i in range(len(self.models)):
            mean, std = _ensemble_predict_one_finite_width(
                i, self.models, self.design_space
            )
            means.append(mean.reshape(-1, 1))
            stds.append(std.reshape(-1, 1))

        self.means = np.hstack(means)
        self.std = np.hstack(stds)
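Finally, a usage sketch for the class above. The data are placeholders; `update_train_set` and `run_one_step` are assumed here to be inherited from `PALBase` as for PyePAL's other PAL classes, and the top-level import is likewise an assumption.

import numpy as np

from pyepal import PALJaxEnsemble  # assumed top-level re-export
from pyepal.models.nt import build_dense_network, get_optimizer

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))  # hypothetical design space, two objectives below

pal = PALJaxEnsemble(
    X,
    models=[build_dense_network([128]), build_dense_network([128])],
    ndim=2,
    optimizers=[get_optimizer(optimizer="adam"), get_optimizer(optimizer="adam")],
    training_steps=500,
    ensemble_size=100,
)

# Seed PAL with a few measured points, then ask which point to measure next
sampled_idx = np.array([0, 1, 2, 3])
measurements = rng.normal(size=(4, 2))
pal.update_train_set(sampled_idx, measurements)
next_idx = pal.run_one_step()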