Adding more details to deep kernel learning example (#70)
* add details to deep kernel example

* remove blackcellmagic cell

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add news fragment

* add more description

* show matern-3/2 on transformed features

* add more description

* [pre-commit.ci] pre-commit autoupdate (#71)

updates:
- [github.com/hadialqattan/pycln: v1.2.4 → v1.2.5](hadialqattan/pycln@v1.2.4...v1.2.5)
- [github.com/pre-commit/mirrors-mypy: v0.931 → v0.940](pre-commit/mirrors-mypy@v0.931...v0.940)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* try simplify intuition

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add smoothing assumption

* some edits

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dan F-M <[email protected]>
3 people committed Mar 17, 2022
1 parent 1861232 commit d6b7c3c
Showing 2 changed files with 172 additions and 33 deletions.
202 changes: 169 additions & 33 deletions docs/tutorials/transforms.ipynb
@@ -87,13 +87,15 @@
"metadata": {},
"source": [
"Then we will fit this model using a model similar to the one described in {ref}`modeling-flax`, except our kernel will include a custom {class}`tinygp.kernels.Transform` that will pass the input coordinates through a (small) neural network before passing them into a {class}`tinygp.kernels.Matern32` kernel.\n",
"Otherwise, the model and optimization procedure are similar to the ones used in {ref}`modeling-flax`."
"Otherwise, the model and optimization procedure are similar to the ones used in {ref}`modeling-flax`.\n",
"\n",
"We compare the performance of the Deep Matern-3/2 kernel (a {class}`tinygp.kernels.Matern32` kernel, with custom neural network transform) to the performance of the same kernel without the transform. The untransformed model doesn't have the capacity to capture our simulated step function, but our transformed model does. In our transformed model, the hyperparameters of our kernel now include the weights of our neural network transform, and we learn those simultaneously with the length scale and amplitude of the `Matern32` kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0065dea-379a-4e0d-8cf2-f460c8126a5f",
"id": "94938ebe",
"metadata": {},
"outputs": [],
"source": [
@@ -102,11 +104,49 @@
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"from flax.linen.initializers import zeros\n",
"from tinygp import kernels, transforms, GaussianProcess\n",
"\n",
"from tinygp import kernels, transforms, GaussianProcess"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21e63e64",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"class Matern32Loss(nn.Module):\n",
" @nn.compact\n",
" def __call__(self, x, y, t):\n",
" # Set up a typical Matern-3/2 kernel\n",
" log_sigma = self.param(\"log_sigma\", zeros, ())\n",
" log_rho = self.param(\"log_rho\", zeros, ())\n",
" log_jitter = self.param(\"log_jitter\", zeros, ())\n",
" base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(\n",
" jnp.exp(log_rho)\n",
" )\n",
"\n",
"# Define a small neural network used to non-linearly transform the input data in our model\n",
" # Evaluate and return the GP negative log likelihood as usual\n",
" gp = GaussianProcess(\n",
" base_kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)\n",
" )\n",
" log_prob, gp_cond = gp.condition(y, t[:, None])\n",
" return -log_prob, (gp_cond.loc, gp_cond.variance)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0065dea-379a-4e0d-8cf2-f460c8126a5f",
"metadata": {},
"outputs": [],
"source": [
"class Transformer(nn.Module):\n",
" \"\"\"A small neural network used to non-linearly transform the input data\"\"\"\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" x = nn.Dense(features=15)(x)\n",
@@ -117,7 +157,7 @@
" return x\n",
"\n",
"\n",
"class GPLoss(nn.Module):\n",
"class DeepLoss(nn.Module):\n",
" @nn.compact\n",
" def __call__(self, x, y, t):\n",
" # Set up a typical Matern-3/2 kernel\n",
@@ -128,56 +168,152 @@
" jnp.exp(log_rho)\n",
" )\n",
"\n",
" # Define a custom transform to pass the input coordinates through our `Transformer`\n",
" # network from above\n",
" # Define a custom transform to pass the input coordinates through our\n",
" # `Transformer` network from above\n",
" transform = Transformer()\n",
" kernel = transforms.Transform(transform, base_kernel)\n",
"\n",
" # Evaluate and return the GP negative log likelihood as usual\n",
" # Evaluate and return the GP negative log likelihood as usual with the\n",
" # transformed features\n",
" gp = GaussianProcess(\n",
" kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)\n",
" )\n",
" log_prob, gp_cond = gp.condition(y, t[:, None])\n",
" return -log_prob, (gp_cond.loc, gp_cond.variance)\n",
"\n",
" # We return the loss, the conditional mean and variance, and the\n",
" # transformed input parameters\n",
" return (\n",
" -log_prob,\n",
" (gp_cond.loc, gp_cond.variance),\n",
" (transform(x[:, None]), transform(t[:, None])),\n",
" )\n",
"\n",
"\n",
"# Define and train the model\n",
"def loss(params):\n",
" return model.apply(params, x, y, t)[0]\n",
"def loss_func(model):\n",
" def loss(params):\n",
" return model.apply(params, x, y, t)[0]\n",
"\n",
" return loss\n",
"\n",
"model = GPLoss()\n",
"params = model.init(jax.random.PRNGKey(1234), x, y, t)\n",
"tx = optax.sgd(learning_rate=1e-4)\n",
"opt_state = tx.init(params)\n",
"loss_grad_fn = jax.jit(jax.value_and_grad(loss))\n",
"for i in range(1000):\n",
" loss_val, grads = loss_grad_fn(params)\n",
" updates, opt_state = tx.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
"\n",
"models_list, params_list = [], []\n",
"loss_vals = {}\n",
"# Plot the results and compare to the true model\n",
"plt.figure()\n",
"mu, var = model.apply(params, x, y, t)[1]\n",
"plt.plot(t, 2 * (t > 0) - 1, \"k\", lw=1, label=\"truth\")\n",
"plt.plot(x, y, \".k\", label=\"data\")\n",
"plt.plot(t, mu)\n",
"plt.fill_between(\n",
" t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label=\"model\"\n",
")\n",
"plt.xlim(-1.5, 1.5)\n",
"plt.ylim(-1.3, 1.3)\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"y\")\n",
"fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))\n",
"for it, (model_name, model) in enumerate(\n",
" zip(\n",
" [\"Deep\", \"Matern32\"],\n",
" [DeepLoss(), Matern32Loss()],\n",
" )\n",
"):\n",
" loss_vals[it] = []\n",
" params = model.init(jax.random.PRNGKey(1234), x, y, t)\n",
" tx = optax.sgd(learning_rate=1e-4)\n",
" opt_state = tx.init(params)\n",
"\n",
" loss = loss_func(model)\n",
" loss_grad_fn = jax.jit(jax.value_and_grad(loss))\n",
" for i in range(1000):\n",
" loss_val, grads = loss_grad_fn(params)\n",
" updates, opt_state = tx.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" loss_vals[it].append(loss_val)\n",
"\n",
" mu, var = model.apply(params, x, y, t)[1]\n",
" ax[it].plot(t, 2 * (t > 0) - 1, \"k\", lw=1, label=\"truth\")\n",
" ax[it].plot(x, y, \".k\", label=\"data\")\n",
" ax[it].plot(t, mu)\n",
" ax[it].fill_between(\n",
" t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label=\"model\"\n",
" )\n",
" ax[it].set_xlim(-1.5, 1.5)\n",
" ax[it].set_ylim(-1.3, 1.3)\n",
" ax[it].set_xlabel(\"x\")\n",
" ax[it].set_ylabel(\"y\")\n",
" ax[it].set_title(model_name)\n",
" _ = ax[it].legend()\n",
"\n",
" models_list.append(model)\n",
" params_list.append(params)"
]
},
{
"cell_type": "markdown",
"id": "bb4d5f08",
"metadata": {},
"source": [
"The untransformed `Matern32` model suffers from over-smoothing at the discontinuity, and poor extrapolation performance.\n",
"The `Deep` model extrapolates well and captures the discontinuity reliably.\n",
"\n",
"We can compare the training loss (negative log likelihood) traces for these two models:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feff3a28",
"metadata": {},
"outputs": [],
"source": [
"fig = plt.plot()\n",
"plt.plot(loss_vals[0], label=\"Deep\")\n",
"plt.plot(loss_vals[1], label=\"Matern32\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.xlabel(\"Training Iterations\")\n",
"_ = plt.legend()"
]
},
{
"cell_type": "markdown",
"id": "5692e918",
"metadata": {},
"source": [
"To inspect what the transformed model is doing under the hood, we can plot the functional form of the transformation, as well as the transformed values of our input coordinates: "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a281b035-513a-4215-87fd-1a83b52ebd79",
"metadata": {},
"outputs": [],
"source": [
"x_transform, t_transform = models_list[0].apply(params_list[0], x, y, t)[2]\n",
"\n",
"fig = plt.figure()\n",
"plt.plot(t, t_transform, \"k\")\n",
"plt.xlim(-1.5, 1.5)\n",
"plt.ylim(-1.3, 1.3)\n",
"plt.xlabel(\"input data; x\")\n",
"plt.ylabel(\"transformed data; x'\")\n",
"\n",
"fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))\n",
"for it, (fig_title, feature_input, x_label) in enumerate(\n",
" zip([\"Input Data\", \"Transformed Data\"], [x, x_transform], [\"x\", \"x'\"])\n",
"):\n",
" ax[it].plot(feature_input, y, \".k\")\n",
" ax[it].set_xlim(-1.5, 1.5)\n",
" ax[it].set_ylim(-1.3, 1.3)\n",
" ax[it].set_title(fig_title)\n",
" ax[it].set_xlabel(x_label)\n",
" ax[it].set_ylabel(\"y\")"
]
},
{
"cell_type": "markdown",
"id": "673435d3",
"metadata": {},
"source": [
"The neural network transforms the input feature into a step function like data (as shown in the figures above) before feeding to the base kernel, making it better suited than the baseline model for this data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d805ca0",
"metadata": {},
"outputs": [],
"source": []
}
],
3 changes: 3 additions & 0 deletions news/70.doc
@@ -0,0 +1,3 @@
Add more details to the deep kernel learning tutorial,
showing a comparison with the Matern-3/2 kernel
and a plot of the transformed features.
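
For readers skimming the diff above, the pattern the updated tutorial demonstrates can be condensed into a short standalone sketch. This is a minimal, hypothetical example rather than the notebook's exact code: `TinyTransformer`, `DeepKernelSketch`, the toy step-function data, and the fixed `diag` jitter are stand-ins, while `kernels.Matern32`, `transforms.Transform`, and `GaussianProcess` are the tinygp APIs used in the notebook.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess


class TinyTransformer(nn.Module):
    """Hypothetical stand-in for the tutorial's `Transformer` network."""

    @nn.compact
    def __call__(self, x):
        # One dense layer with a tanh non-linearity
        return nn.tanh(nn.Dense(features=1)(x))


class DeepKernelSketch(nn.Module):
    @nn.compact
    def __call__(self, x, y):
        # Base Matern-3/2 kernel with learnable amplitude and length scale
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # The deep kernel trick: inputs pass through the network before the
        # kernel sees them, so the network weights become kernel hyperparameters
        kernel = transforms.Transform(TinyTransformer(), base_kernel)

        gp = GaussianProcess(kernel, x[:, None], diag=1e-4)
        return -gp.log_probability(y)


# Usage: initialize parameters and evaluate the loss on toy step-function data
x = jnp.linspace(-1.0, 1.0, 50)
y = jnp.where(x > 0, 1.0, -1.0)
model = DeepKernelSketch()
params = model.init(jax.random.PRNGKey(0), x, y)
print(model.apply(params, x, y))

Optimizing this loss with optax, as the notebook does, trains the network weights and the kernel hyperparameters jointly.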
