Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 2, 2024
1 parent 64a2d26 commit e236d9e
Show file tree
Hide file tree
Showing 14 changed files with 45 additions and 43 deletions.
9 changes: 4 additions & 5 deletions docs/benchmarks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,12 @@
"source": [
"from functools import partial\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import celerite2\n",
"import george\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"import george\n",
"import celerite2\n",
"import tinygp\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
Expand Down
14 changes: 7 additions & 7 deletions docs/tutorials/derivative.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"\n",
"X = np.linspace(0.0, 5 * np.pi, 50)\n",
"y = np.concatenate(\n",
Expand Down Expand Up @@ -97,10 +96,11 @@
"metadata": {},
"outputs": [],
"source": [
"import tinygp\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import tinygp\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"\n",
Expand Down Expand Up @@ -169,10 +169,10 @@
" )[0]\n",
"\n",
" plt.figure()\n",
" plt.plot(dt, k00, label=\"$\\mathrm{cov}(f,\\,f)$\", lw=1)\n",
" plt.plot(dt, k01, label=\"$\\mathrm{cov}(f,\\,\\dot{f})$\", lw=1)\n",
" plt.plot(dt, k10, label=\"$\\mathrm{cov}(\\dot{f},\\,f)$\", lw=1)\n",
" plt.plot(dt, k11, label=\"$\\mathrm{cov}(\\dot{f},\\,\\dot{f})$\", lw=1)\n",
" plt.plot(dt, k00, label=r\"$\\mathrm{cov}(f,\\,f)$\", lw=1)\n",
" plt.plot(dt, k01, label=r\"$\\mathrm{cov}(f,\\,\\dot{f})$\", lw=1)\n",
" plt.plot(dt, k10, label=r\"$\\mathrm{cov}(\\dot{f},\\,f)$\", lw=1)\n",
" plt.plot(dt, k11, label=r\"$\\mathrm{cov}(\\dot{f},\\,\\dot{f})$\", lw=1)\n",
" plt.legend()\n",
" plt.xlabel(r\"$\\Delta t$\")\n",
" plt.xlim(dt.min(), dt.max())\n",
Expand Down
7 changes: 4 additions & 3 deletions docs/tutorials/geometry.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from tinygp import kernels, GaussianProcess\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"\n",
"def plot_kernel(kernel, **kwargs):\n",
Expand Down
5 changes: 3 additions & 2 deletions docs/tutorials/kernels.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@
"metadata": {},
"outputs": [],
"source": [
"import tinygp\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import tinygp\n",
"\n",
"\n",
"class SpectralMixture(tinygp.kernels.Kernel):\n",
" weight: jax.Array\n",
Expand Down Expand Up @@ -85,8 +86,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"\n",
"def build_gp(theta):\n",
Expand Down
5 changes: 3 additions & 2 deletions docs/tutorials/likelihoods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"random = np.random.default_rng(203618)\n",
"x = np.linspace(-3, 3, 20)\n",
Expand Down Expand Up @@ -90,7 +90,8 @@
"import jax.numpy as jnp\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from tinygp import kernels, GaussianProcess\n",
"\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down
7 changes: 4 additions & 3 deletions docs/tutorials/means.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@
"outputs": [],
"source": [
"from functools import partial\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down Expand Up @@ -172,7 +173,7 @@
"metadata": {},
"outputs": [],
"source": [
"from tinygp import kernels, GaussianProcess\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"\n",
"def build_gp(params):\n",
Expand Down
3 changes: 1 addition & 2 deletions docs/tutorials/mixture.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from tinygp import GaussianProcess, kernels, transforms\n",
"\n",
Expand Down
10 changes: 4 additions & 6 deletions docs/tutorials/modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"random = np.random.default_rng(42)\n",
"\n",
Expand Down Expand Up @@ -120,15 +120,13 @@
"metadata": {},
"outputs": [],
"source": [
"from tinygp import kernels, GaussianProcess\n",
"\n",
"import flax.linen as nn\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import flax.linen as nn\n",
"import optax\n",
"from flax.linen.initializers import zeros\n",
"\n",
"import optax\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"\n",
"class GPModule(nn.Module):\n",
Expand Down
6 changes: 4 additions & 2 deletions docs/tutorials/multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@
"outputs": [],
"source": [
"import jax\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"from tinygp import kernels\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
Expand Down Expand Up @@ -117,8 +118,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"random = np.random.default_rng(48392)\n",
"X = random.uniform(-5, 5, (100, 2))\n",
Expand Down Expand Up @@ -163,6 +164,7 @@
"outputs": [],
"source": [
"import jaxopt\n",
"\n",
"from tinygp import GaussianProcess, kernels, transforms\n",
"\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/quasisep-custom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"random = np.random.default_rng(394)\n",
"t = np.sort(random.uniform(0, 10, 700))\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/quasisep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"random = np.random.default_rng(42)\n",
"\n",
Expand Down Expand Up @@ -108,7 +108,7 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from tinygp import kernels, GaussianProcess\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down
5 changes: 2 additions & 3 deletions docs/tutorials/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from statsmodels.datasets import co2\n",
"\n",
"data = co2.load_pandas().data\n",
Expand Down Expand Up @@ -102,8 +102,7 @@
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from tinygp import kernels, GaussianProcess\n",
"\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
Expand Down
9 changes: 5 additions & 4 deletions docs/tutorials/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"random = np.random.default_rng(567)\n",
"\n",
Expand Down Expand Up @@ -99,12 +99,13 @@
"metadata": {},
"outputs": [],
"source": [
"import flax.linen as nn\n",
"import jax\n",
"import optax\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"import optax\n",
"from flax.linen.initializers import zeros\n",
"from tinygp import kernels, transforms, GaussianProcess"
"\n",
"from tinygp import GaussianProcess, kernels, transforms"
]
},
{
Expand Down

0 comments on commit e236d9e

Please sign in to comment.