diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index 2f2c3be1..e15b8ecd 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -87,13 +87,15 @@ "metadata": {}, "source": [ "Then we will fit this model using a model similar to the one described in {ref}`modeling-flax`, except our kernel will include a custom {class}`tinygp.kernels.Transform` that will pass the input coordinates through a (small) neural network before passing them into a {class}`tinygp.kernels.Matern32` kernel.\n", - "Otherwise, the model and optimization procedure are similar to the ones used in {ref}`modeling-flax`." + "Otherwise, the model and optimization procedure are similar to the ones used in {ref}`modeling-flax`.\n", + "\n", + "We compare the performance of the Deep Matern-3/2 kernel (a {class}`tinygp.kernels.Matern32` kernel, with custom neural network transform) to the performance of the same kernel without the transform. The untransformed model doesn't have the capacity to capture our simulated step function, but our transformed model does. In our transformed model, the hyperparameters of our kernel now include the weights of our neural network transform, and we learn those simultaneously with the length scale and amplitude of the `Matern32` kernel." ] }, { "cell_type": "code", "execution_count": null, - "id": "e0065dea-379a-4e0d-8cf2-f460c8126a5f", + "id": "94938ebe", "metadata": {}, "outputs": [], "source": [ @@ -102,11 +104,49 @@ "import jax.numpy as jnp\n", "import flax.linen as nn\n", "from flax.linen.initializers import zeros\n", - "from tinygp import kernels, transforms, GaussianProcess\n", - "\n", + "from tinygp import kernels, transforms, GaussianProcess" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21e63e64", + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "class Matern32Loss(nn.Module):\n", + " @nn.compact\n", + " def __call__(self, x, y, t):\n", + " # Set up a typical Matern-3/2 kernel\n", + " log_sigma = self.param(\"log_sigma\", zeros, ())\n", + " log_rho = self.param(\"log_rho\", zeros, ())\n", + " log_jitter = self.param(\"log_jitter\", zeros, ())\n", + " base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(\n", + " jnp.exp(log_rho)\n", + " )\n", "\n", - "# Define a small neural network used to non-linearly transform the input data in our model\n", + " # Evaluate and return the GP negative log likelihood as usual\n", + " gp = GaussianProcess(\n", + " base_kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)\n", + " )\n", + " log_prob, gp_cond = gp.condition(y, t[:, None])\n", + " return -log_prob, (gp_cond.loc, gp_cond.variance)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0065dea-379a-4e0d-8cf2-f460c8126a5f", + "metadata": {}, + "outputs": [], + "source": [ "class Transformer(nn.Module):\n", + " \"\"\"A small neural network used to non-linearly transform the input data\"\"\"\n", + "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Dense(features=15)(x)\n", @@ -117,7 +157,7 @@ " return x\n", "\n", "\n", - "class GPLoss(nn.Module):\n", + "class DeepLoss(nn.Module):\n", " @nn.compact\n", " def __call__(self, x, y, t):\n", " # Set up a typical Matern-3/2 kernel\n", @@ -128,56 +168,152 @@ " jnp.exp(log_rho)\n", " )\n", "\n", - " # Define a custom transform to pass the input coordinates through our `Transformer`\n", - " # network from above\n", + " # Define a custom transform to pass the input coordinates through our\n", + " # `Transformer` network from above\n", " transform = Transformer()\n", " kernel = transforms.Transform(transform, base_kernel)\n", "\n", - " # Evaluate and return the GP negative log likelihood as usual\n", + " # Evaluate and return the GP negative log likelihood as usual with the\n", + " # transformed features\n", " gp = GaussianProcess(\n", " kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)\n", " )\n", " log_prob, gp_cond = gp.condition(y, t[:, None])\n", - " return -log_prob, (gp_cond.loc, gp_cond.variance)\n", + "\n", + " # We return the loss, the conditional mean and variance, and the\n", + " # transformed input parameters\n", + " return (\n", + " -log_prob,\n", + " (gp_cond.loc, gp_cond.variance),\n", + " (transform(x[:, None]), transform(t[:, None])),\n", + " )\n", "\n", "\n", "# Define and train the model\n", - "def loss(params):\n", - " return model.apply(params, x, y, t)[0]\n", + "def loss_func(model):\n", + " def loss(params):\n", + " return model.apply(params, x, y, t)[0]\n", "\n", + " return loss\n", "\n", - "model = GPLoss()\n", - "params = model.init(jax.random.PRNGKey(1234), x, y, t)\n", - "tx = optax.sgd(learning_rate=1e-4)\n", - "opt_state = tx.init(params)\n", - "loss_grad_fn = jax.jit(jax.value_and_grad(loss))\n", - "for i in range(1000):\n", - " loss_val, grads = loss_grad_fn(params)\n", - " updates, opt_state = tx.update(grads, opt_state)\n", - " params = optax.apply_updates(params, updates)\n", "\n", + "models_list, params_list = [], []\n", + "loss_vals = {}\n", "# Plot the results and compare to the true model\n", - "plt.figure()\n", - "mu, var = model.apply(params, x, y, t)[1]\n", - "plt.plot(t, 2 * (t > 0) - 1, \"k\", lw=1, label=\"truth\")\n", - "plt.plot(x, y, \".k\", label=\"data\")\n", - "plt.plot(t, mu)\n", - "plt.fill_between(\n", - " t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label=\"model\"\n", - ")\n", - "plt.xlim(-1.5, 1.5)\n", - "plt.ylim(-1.3, 1.3)\n", - "plt.xlabel(\"x\")\n", - "plt.ylabel(\"y\")\n", + "fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))\n", + "for it, (model_name, model) in enumerate(\n", + " zip(\n", + " [\"Deep\", \"Matern32\"],\n", + " [DeepLoss(), Matern32Loss()],\n", + " )\n", + "):\n", + " loss_vals[it] = []\n", + " params = model.init(jax.random.PRNGKey(1234), x, y, t)\n", + " tx = optax.sgd(learning_rate=1e-4)\n", + " opt_state = tx.init(params)\n", + "\n", + " loss = loss_func(model)\n", + " loss_grad_fn = jax.jit(jax.value_and_grad(loss))\n", + " for i in range(1000):\n", + " loss_val, grads = loss_grad_fn(params)\n", + " updates, opt_state = tx.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " loss_vals[it].append(loss_val)\n", + "\n", + " mu, var = model.apply(params, x, y, t)[1]\n", + " ax[it].plot(t, 2 * (t > 0) - 1, \"k\", lw=1, label=\"truth\")\n", + " ax[it].plot(x, y, \".k\", label=\"data\")\n", + " ax[it].plot(t, mu)\n", + " ax[it].fill_between(\n", + " t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label=\"model\"\n", + " )\n", + " ax[it].set_xlim(-1.5, 1.5)\n", + " ax[it].set_ylim(-1.3, 1.3)\n", + " ax[it].set_xlabel(\"x\")\n", + " ax[it].set_ylabel(\"y\")\n", + " ax[it].set_title(model_name)\n", + " _ = ax[it].legend()\n", + "\n", + " models_list.append(model)\n", + " params_list.append(params)" + ] + }, + { + "cell_type": "markdown", + "id": "bb4d5f08", + "metadata": {}, + "source": [ + "The untransformed `Matern32` model suffers from over-smoothing at the discontinuity, and poor extrapolation performance.\n", + "The `Deep` model extrapolates well and captures the discontinuity reliably.\n", + "\n", + "We can compare the training loss (negative log likelihood) traces for these two models:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "feff3a28", + "metadata": {}, + "outputs": [], + "source": [ + "fig = plt.plot()\n", + "plt.plot(loss_vals[0], label=\"Deep\")\n", + "plt.plot(loss_vals[1], label=\"Matern32\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.xlabel(\"Training Iterations\")\n", "_ = plt.legend()" ] }, + { + "cell_type": "markdown", + "id": "5692e918", + "metadata": {}, + "source": [ + "To inspect what the transformed model is doing under the hood, we can plot the functional form of the transformation, as well as the transformed values of our input coordinates: " + ] + }, { "cell_type": "code", "execution_count": null, "id": "a281b035-513a-4215-87fd-1a83b52ebd79", "metadata": {}, "outputs": [], + "source": [ + "x_transform, t_transform = models_list[0].apply(params_list[0], x, y, t)[2]\n", + "\n", + "fig = plt.figure()\n", + "plt.plot(t, t_transform, \"k\")\n", + "plt.xlim(-1.5, 1.5)\n", + "plt.ylim(-1.3, 1.3)\n", + "plt.xlabel(\"input data; x\")\n", + "plt.ylabel(\"transformed data; x'\")\n", + "\n", + "fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))\n", + "for it, (fig_title, feature_input, x_label) in enumerate(\n", + " zip([\"Input Data\", \"Transformed Data\"], [x, x_transform], [\"x\", \"x'\"])\n", + "):\n", + " ax[it].plot(feature_input, y, \".k\")\n", + " ax[it].set_xlim(-1.5, 1.5)\n", + " ax[it].set_ylim(-1.3, 1.3)\n", + " ax[it].set_title(fig_title)\n", + " ax[it].set_xlabel(x_label)\n", + " ax[it].set_ylabel(\"y\")" + ] + }, + { + "cell_type": "markdown", + "id": "673435d3", + "metadata": {}, + "source": [ + "The neural network transforms the input feature into a step function like data (as shown in the figures above) before feeding to the base kernel, making it better suited than the baseline model for this data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d805ca0", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/news/70.doc b/news/70.doc new file mode 100644 index 00000000..163fdd62 --- /dev/null +++ b/news/70.doc @@ -0,0 +1,3 @@ +Add more details to Deep Kernel learning tutorial, +showing comparison with Matern-3/2 kernel +and the transformed features.