feat: neural tangents support (#117)
Adds support for the neural tangents library, i.e., uncertainty estimates using the neural tangent kernel or NNGP.
At the same time, we use the newly introduced `jax` dependency to build ensembles of finite-width neural networks.
kjappelbaum committed Nov 28, 2020
1 parent 230f7cd commit 5911cc5
Showing 14 changed files with 826 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python_package.yml
@@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov prospector GPy matplotlib lightgbm && pip install -e .
pip install -e .[all,testing,pre-commit]
- name: Test with pytest (numba activated)
run: |
pytest
2 changes: 2 additions & 0 deletions docs/getting_started.rst
@@ -30,6 +30,8 @@ Which class do I use?
- For Gaussian processes built with :code:`GPy` use :py:class:`~pyepal.pal.pal_gpy.PALGPy`
- For coregionalized Gaussian processes (built with :code:`GPy`) use :py:class:`~pyepal.pal.pal_coregionalized.PALCoregionalized`
- For quantile regression using :code:`LightGBM` gradient boosted decision trees use :py:class:`~pyepal.pal.pal_gbdt.PALGBDT`
- For `infinitely wide neural networks with the neural tangent kernel or exact Bayesian inference (Novak et al., 2019) <https://arxiv.org/pdf/1912.02803.pdf>`_ use :py:class:`~pyepal.pal.pal_neural_tangent.PALNT`
- For an `ensemble of finite-width neural networks (Lakshminarayanan et al., 2017) <https://proceedings.neurips.cc/paper/2017/file/9ef2ed4b7fd2c810847ffa5fa85bce38-Paper.pdf>`_ (built with JAX) use :py:class:`~pyepal.pal.pal_finite_ensemble.PALJaxEnsemble` (see the construction sketch below)

If your favorite model is not listed, you can easily implement it yourself (see :ref:`new_pal_class`)!
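
A minimal construction sketch for the new ensemble class, not an authoritative recipe: the positional `(X_design, models)` order is assumed from the other PAL classes, `ndim` is passed as a keyword because the `PALJaxEnsemble` constructor reads it from its keyword arguments, and all layer sizes and hyperparameters are placeholders.

# Hypothetical construction sketch; sizes and hyperparameters are illustrative only.
import numpy as np

from pyepal import PALJaxEnsemble  # PALNT is fed with the same NTModel instances
from pyepal.models.nt import build_dense_network, get_optimizer

X_design = np.random.rand(100, 10)  # placeholder design space (feature matrix)

# One NTModel and one optimizer per objective (here: two objectives).
models = [build_dense_network([512, 512]) for _ in range(2)]
optimizers = [get_optimizer(learning_rate=1e-3, optimizer="adam") for _ in range(2)]

palinstance = PALJaxEnsemble(
    X_design,
    models,
    ndim=2,                 # keyword on purpose: the constructor looks up ndim in kwargs
    optimizers=optimizers,
    training_steps=500,
    ensemble_size=100,
)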

7 changes: 7 additions & 0 deletions pyepal/__init__.py
@@ -16,10 +16,13 @@

"""PyePAL"""
from ._version import get_versions
from .models.nt import JaxOptimizer, NTModel
from .pal.pal_base import PALBase
from .pal.pal_coregionalized import PALCoregionalized
from .pal.pal_finite_ensemble import PALJaxEnsemble
from .pal.pal_gbdt import PALGBDT
from .pal.pal_gpy import PALGPy
from .pal.pal_neural_tangent import PALNT
from .pal.pal_sklearn import PALSklearn
from .pal.utils import (
exhaust_loop,
@@ -37,6 +40,10 @@
"PALGBDT",
"PALGPy",
"PALSklearn",
"PALJaxEnsemble",
"PALNT",
"NTModel",
"JaxOptimizer",
"exhaust_loop",
"get_hypervolume",
"get_kmeans_samples",
159 changes: 159 additions & 0 deletions pyepal/models/nt.py
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright 2020 PyePAL authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions to build neutral tangents models for PALNT
Depending on the dataset there might be some issues with these models,
some tricks are listed in https://github.com/google/neural-tangents/issues/76
1. Use Erf as activation
2. Initialize the weights with larger standard deviation
3. Standardize the data
The first two points are done by default in the `build_dense_network` function
Note that following the law of total variance the prior, intialized via
W_std and b_std give an upper bound on the std of the posterior
"""

from dataclasses import dataclass
from typing import Callable, Sequence, Union

from jax import jit
from jax.experimental import optimizers
from neural_tangents import stax


@dataclass
class NTModel:
"""Defining a dataclass for neural tangents models"""

# Initialization functions construct parameters for neural networks
# given a random key and an input shape.
init_fn: Callable
# Apply functions do computations with finite-width neural networks.
apply_fn: Callable
kernel_fn: Callable
predict_fn: Union[Callable, None] = None
scaler: Union[Callable, None] = None # Used to store Standard Scaler objects
params: Union[list, None] = None # Used to store parameters for the ensemble models


@dataclass
class JaxOptimizer:
"""Defining a dataclass for a JAX optimizer"""

opt_init: Callable
opt_update: Callable
get_params: Callable


__all__ = ["NTModel", "JaxOptimizer", "build_dense_network", "get_optimizer"]


def build_dense_network(
hidden_layers: Sequence[int],
activations: Union[Sequence, str] = "erf",
w_std: float = 2.5,
b_std: float = 1.0,
) -> NTModel:
"""Utility function to build a simple feedforward network with the
neural tangents library.
Args:
hidden_layers (Sequence[int]): Iterable with the number of neurons.
For example, [512, 512]
activations (Union[Sequence, str], optional):
Iterable with neural_tangents.stax axtivations or "relu" or "erf".
Defaults to "erf".
w_std (float): Standard deviation of the weight distribution.
b_std (float): Standard deviation of the bias distribution.
Returns:
NTModel: jiited init, apply and
kernel functions, predict_function (None)
"""
assert len(hidden_layers) >= 1, "You must provide at least one hidden layer"
if activations is None:
activations = [stax.Relu() for _ in hidden_layers]
elif isinstance(activations, str):
if activations.lower() == "relu":
activations = [stax.Relu() for _ in hidden_layers]
elif activations.lower() == "erf":
activations = [stax.Erf() for _ in hidden_layers]
else:
for activation in activations:
assert callable(
activation
), "You need to provide `neural_tangents.stax` activations"

assert len(activations) == len(
hidden_layers
), "The number of hidden layers should match the number of nonlinearities"
stack = []

for hidden_layer, activation in zip(hidden_layers, activations):
stack.append(stax.Dense(hidden_layer, W_std=w_std, b_std=b_std))
stack.append(activation)

stack.append(stax.Dense(1, W_std=w_std, b_std=b_std))

init_fn, apply_fn, kernel_fn = stax.serial(*stack)

return NTModel(init_fn, jit(apply_fn), jit(kernel_fn, static_argnums=(2,)), None)


def get_optimizer(
learning_rate: float = 1e-4, optimizer: str = "sgd", optimizer_kwargs: dict = None
) -> JaxOptimizer:
"""Return a `JaxOptimizer` dataclass for a JAX optimizer
Args:
learning_rate (float, optional): Step size. Defaults to 1e-4.
optimizer (str, optional): Optimizer type (allowed types: "adam",
"adamax", "adagrad", "rmsprop", "sgd"). Defaults to "sgd".
optimizer_kwargs (dict, optional): Additional keyword arguments
that are passed to the optimizer. Defaults to None.
Returns:
JaxOptimizer
"""
if optimizer_kwargs is None:
optimizer_kwargs = {}
optimizer = optimizer.lower()
if optimizer == "adam":
opt_init, opt_update, get_params = optimizers.adam(
learning_rate, **optimizer_kwargs
)
elif optimizer == "adagrad":
opt_init, opt_update, get_params = optimizers.adagrad(
learning_rate, **optimizer_kwargs
)
elif optimizer == "adamax":
opt_init, opt_update, get_params = optimizers.adamax(
learning_rate, **optimizer_kwargs
)
elif optimizer == "rmsprop":
opt_init, opt_update, get_params = optimizers.rmsprop(
learning_rate, **optimizer_kwargs
)
else:
opt_init, opt_update, get_params = optimizers.sgd(
learning_rate, **optimizer_kwargs
)

opt_update = jit(opt_update)

return JaxOptimizer(opt_init, opt_update, get_params)
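
A short, non-authoritative usage sketch of how the helpers in this file fit together; the learning rate, layer sizes, and random targets are placeholders, and the standardization step follows the advice in the module docstring above.

# Illustrative only: wiring build_dense_network and get_optimizer together.
import numpy as np
from sklearn.preprocessing import StandardScaler

from pyepal.models.nt import build_dense_network, get_optimizer

# 1. Standardize the targets (trick 3 from the module docstring).
scaler = StandardScaler()
y_scaled = scaler.fit_transform(np.random.rand(50, 1))  # placeholder targets

# 2. Two hidden layers with Erf activations and wide priors (the defaults above).
nt_model = build_dense_network([512, 512])

# 3. An optimizer for the finite-width (ensemble) case.
opt = get_optimizer(learning_rate=1e-3, optimizer="adam")

# nt_model.kernel_fn is the jitted NTK/NNGP kernel used in the infinite-width case;
# nt_model.init_fn and nt_model.apply_fn together with `opt` drive the ensemble training.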
2 changes: 1 addition & 1 deletion pyepal/pal/pal_base.py
@@ -428,7 +428,7 @@ def _compare_mae_variance(self):
"""The mean absolute error in crossvalidation is {:.2f},
the mean standard deviation is {:.2f}.
Your model might not be predictive and/or overconfident.
In the docs, you find hints on how to make GPRs more robust.""".format(
In the docs, you find hints on how to make models more robust.""".format(
mae, mean_std
),
UserWarning,
182 changes: 182 additions & 0 deletions pyepal/pal/pal_finite_ensemble.py
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
# Copyright 2020 PyePAL authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Run PAL with the same models for finite ensemble models
and infinite width models (`PALNT`)
"""

from typing import Sequence

import jax.numpy as jnp
import numpy as np
from jax import grad, jit, random, vmap
from sklearn.preprocessing import StandardScaler

from ..models.nt import JaxOptimizer, NTModel
from .pal_base import PALBase
from .validate_inputs import (
validate_nt_models,
validate_optimizers,
validate_positive_integer_list,
)


# Again, the idea of keeping the core as pure functions outside of the class
# is that this makes it easier to parallelize.
def _ensemble_train_one_finite_width( # pylint:disable=too-many-arguments, too-many-locals
i: int,
models: Sequence[NTModel],
design_space: np.ndarray,
objectives: np.ndarray,
sampled: np.ndarray,
optimizers: Sequence[JaxOptimizer],
key: random.PRNGKey,
training_steps: Sequence[int],
ensemble_size: Sequence[int],
):
model = models[i]
optimizer = optimizers[i]
loss = jit(lambda params, x, y: 0.5 * jnp.mean((model.apply_fn(params, x) - y) ** 2))
grad_loss = jit(lambda state, x, y: grad(loss)(optimizer.get_params(state), x, y))

x_train = design_space[sampled[:, i]]

scaler = StandardScaler()
y_train = scaler.fit_transform(objectives[sampled[:, i], i].reshape(-1, 1))

def train_network(key):
_, params = model.init_fn(key, (-1, x_train.shape[1]))
opt_state = optimizer.opt_init(params)

for j in range(training_steps[i]):
opt_state = optimizer.opt_update(
j, grad_loss(opt_state, x_train, y_train), opt_state
)

return optimizer.get_params(opt_state)

ensemble_key = random.split(key, ensemble_size[i])
params = vmap(train_network)(ensemble_key)

return params, scaler


def _ensemble_predict_one_finite_width(i: int, models: Sequence[NTModel], design_space):
model = models[i]

ensemble_func = vmap(model.apply_fn, (0, None))(model.params, design_space)

mean_func = np.reshape(np.mean(ensemble_func, axis=0), (-1,))
std_func = np.reshape(np.std(ensemble_func, axis=0), (-1,))

return mean_func, std_func


__all__ = ["PALJaxEnsemble", "NTModel", "JaxOptimizer"]


class PALJaxEnsemble(PALBase): # pylint:disable=too-many-instance-attributes
"""Use PAL with and ensemble of finite-width neural networks.
Note that we current assume that there is one model per output,
i.e., we did not yet implement multihead support.
"""

def __init__(self, *args, **kwargs):
"""Construct the PALJaxEnsemble instance
Args:
X_design (np.array): Design space (feature matrix)
models (Sequence[NTModel]): You need to provide a sequence of
NTModel (`pyepal.models.nt.NTModel`).
The elements of this dataclass are the `apply_fn`, `init_fn`,
`kernel_fn` and `predict_fn` (for the latter you can typically
provide `None`).
Can be constructed with
:py:func:`pyepal.models.nt.build_dense_network`.
optimizers (Union[JaxOptimizer, Sequence[JaxOptimizer]]):
Sequence of dataclasses with functions for a JAX optimizer,
can be constructed with :py:func:`pyepal.models.nt.get_optimizer`.
ndim (int): Number of objectives
epsilon (Union[list, float], optional): Epsilon hyperparameter.
Defaults to 0.01.
delta (float, optional): Delta hyperparameter. Defaults to 0.05.
beta_scale (float, optional): Scaling parameter for beta.
If not equal to 1, the theoretical guarantees do not necessarily hold.
Also note that the parametrization depends on the kernel type.
Defaults to 1/9.
goals (List[str], optional): If a list, provide "min" for every objective
that shall be minimized and "max" for every objective
that shall be maximized. Defaults to None, which means
that the code maximizes all objectives.
coef_var_threshold (float, optional): Use only points with
a coefficient of variation below this threshold
in the classification step. Defaults to 3.
key (int): Seed to generate the key for the JAX
pseudo-random number generator. Defaults to 10.
training_steps (Union[int, Sequence[int]]): Number of epochs
for which the networks are trained. Defaults to 500.
ensemble_size (Union[int, Sequence[int]]): Size of the ensemble, i.e.,
over how many randomly initialized neural networks we average
to obtain estimates of mean and standard deviation.
Automatically vectorized using `vmap`.
Defaults to 100.
"""
self.optimizers = validate_optimizers(
kwargs.pop("optimizers"), kwargs.get("ndim")
)

self.training_steps = validate_positive_integer_list(
kwargs.pop("training_steps", 500), kwargs.get("ndim")
)
self.ensemble_size = validate_positive_integer_list(
kwargs.pop("ensemble_size", 100), kwargs.get("ndim")
)
self.key = random.PRNGKey(kwargs.pop("key", 10))
self.design_space_scaler = StandardScaler()
super().__init__(*args, **kwargs)
self.models = validate_nt_models(self.models, self.ndim)

def _set_data(self):
self.design_space = self.design_space_scaler.fit_transform(self.design_space)

def _train(self):
for i in range(len(self.models)):
params, scaler = _ensemble_train_one_finite_width(
i,
self.models,
self.design_space,
self.y,
self.sampled,
self.optimizers,
self.key,
self.training_steps,
self.ensemble_size,
)
self.models[i].params = params
self.models[i].scaler = scaler
self.y[:, i] = scaler.transform(self.y[:, i].reshape(-1, 1)).flatten()

def _predict(self):
means, stds = [], []
for i in range(len(self.models)):
mean, std = _ensemble_predict_one_finite_width(
i, self.models, self.design_space
)
means.append(mean.reshape(-1, 1))
stds.append(std.reshape(-1, 1))

self.means = np.hstack(means)
self.std = np.hstack(stds)
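
The uncertainty estimate above comes from a deep ensemble: `vmap` stacks the parameters of several randomly initialized copies of the same network, and the mean and standard deviation over their predictions serve as prediction and uncertainty. A standalone sketch of that idea follows (training omitted for brevity; the module above trains each copy before averaging, and all shapes and sizes are placeholders).

# Standalone sketch of the ensemble trick used by _ensemble_predict_one_finite_width.
import jax.numpy as jnp
from jax import random, vmap
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(
    stax.Dense(64, W_std=2.5, b_std=1.0), stax.Erf(), stax.Dense(1, W_std=2.5, b_std=1.0)
)

x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)  # placeholder design space


def init_one(key):
    _, params = init_fn(key, (-1, 1))  # one set of random parameters per ensemble member
    return params


keys = random.split(random.PRNGKey(0), 16)    # ensemble of 16 (here untrained) networks
params = vmap(init_one)(keys)                 # stacked parameter pytree
preds = vmap(apply_fn, (0, None))(params, x)  # shape (16, 32, 1)

mean = jnp.mean(preds, axis=0).ravel()        # prediction
std = jnp.std(preds, axis=0).ravel()          # uncertainty estimate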