Skip to content

Commit

Permalink
[Tests] Execute tutorial notebook in pytest suite
Browse files Browse the repository at this point in the history
  • Loading branch information
kaiserls committed May 28, 2024
1 parent 89ba1f5 commit f872eb2
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 4 deletions.
35 changes: 35 additions & 0 deletions docs/source/tutorial_material/temperature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

import jax.numpy as jnp
import numpy as np

from eulerpi.core.model import Model


class Temperature(Model):

param_dim = 1
data_dim = 1

PARAM_LIMITS = np.array([[0, np.pi / 2]])
CENTRAL_PARAM = np.array([np.pi / 4.0])

def __init__(
self,
central_param: np.ndarray = CENTRAL_PARAM,
param_limits: np.ndarray = PARAM_LIMITS,
name: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(central_param, param_limits, name=name, **kwargs)

def forward(self, param):
low_T = -30.0
high_T = 30.0
res = jnp.array(
[low_T + (high_T - low_T) * jnp.cos(jnp.abs(param[0]))]
)
return res

def jacobian(self, param):
return jnp.array([60.0 * jnp.sin(jnp.abs(param[0]))])
9 changes: 6 additions & 3 deletions docs/source/tutorial_material/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@
"outputs": [],
"source": [
"from typing import Optional\n",
"from eulerpi.core.model import Model\n",
"import numpy as np\n",
"\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"from eulerpi.core.model import Model\n",
"\n",
"\n",
"class Temperature(Model):\n",
"\n",
Expand Down Expand Up @@ -159,7 +162,7 @@
" return res\n",
"\n",
" def jacobian(self, param):\n",
" return jnp.array([60.0 * jnp.sin(jnp.abs(param[0]))])\n"
" return jnp.array([60.0 * jnp.sin(jnp.abs(param[0]))])"
]
},
{
Expand Down
Loading

0 comments on commit f872eb2

Please sign in to comment.