
v0.4.0: dict-based parameters!

@phinate released this on 04 Aug 10:39 (commit 783dec3)

What's Changed

This refactor was largely motivated by wanting to keep better track of which parameter is doing what! Because many operations in JAX work on arbitrary pytrees, I'm now tracking parameters as dictionaries, and the fit-level logic has been adjusted to assume this. CI has also been updated to include a different kind of HistFactory model structure, one that looks much more like what I'm building out elsewhere!
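To illustrate why pytree support makes dict-based parameters convenient, here is a minimal sketch using only standard JAX calls (`jax.grad`, `jax.tree_util.tree_map`, `jax.flatten_util.ravel_pytree`). The parameter names `mu` and `bkg_norm` and the quadratic toy objective are hypothetical, chosen for illustration; this is not the library's own fit code.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters, keyed by name.
pars = {"mu": jnp.array(1.0), "bkg_norm": jnp.array([1.0, 1.0])}


def nll(pars):
    # Hypothetical toy objective: quadratic pull of each parameter towards a target.
    return (pars["mu"] - 2.0) ** 2 + jnp.sum((pars["bkg_norm"] - 0.5) ** 2)


# jax.grad returns gradients with the same dict structure as `pars`,
# so every gradient stays labelled by its parameter name.
grads = jax.grad(nll)(pars)

# tree_map applies one plain gradient-descent step to every leaf of the dict.
step = 0.1
new_pars = jax.tree_util.tree_map(lambda p, g: p - step * g, pars, grads)

# ravel_pytree flattens the dict into one vector (useful for optimisers that
# expect a flat array) and returns a function that rebuilds the dict.
flat, unravel = ravel_pytree(pars)
restored = unravel(flat)
```

Here `grads["mu"]` is the gradient with respect to `mu` specifically, so there is no need to remember which index in a flat vector belongs to which parameter.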

For this library to work, a model now only needs to satisfy this interface:

```python
from __future__ import annotations  # allows `float | Array` annotations on Python 3.9

import equinox as eqx  # turn our class into a PyTree
from jax import Array
from jax.typing import ArrayLike


class Model(eqx.Module):
    # any attributes here that are not valid jax types (e.g. str) need to be declared like:
    name: str = eqx.field(static=True)

    def logpdf(self, pars: dict[str, ArrayLike], data: Array) -> float | Array: ...
    def expected_data(self, pars: dict[str, ArrayLike]) -> Array: ...
```

In particular, note that `logpdf` returns a float (or a scalar array). This means that pyhf models, whose `logpdf` returns a length-1 array, would need to be patched with `lambda pars, data: model.logpdf(pars, data)[0]` if they are ever made compatible again.
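The scalar requirement matters because `jax.grad` only accepts scalar-output functions. The sketch below shows the indexing patch on a hypothetical `LegacyModel` that mimics a `logpdf` returning a shape-`(1,)` array; it is a stand-in for illustration, not pyhf itself, and it already takes dict parameters for simplicity.

```python
import jax
import jax.numpy as jnp


class LegacyModel:
    """Hypothetical stand-in whose logpdf returns a shape-(1,) array (NOT pyhf itself)."""

    def logpdf(self, pars, data):
        # Toy Gaussian log-likelihood, returned with a length-1 leading axis.
        return jnp.atleast_1d(-0.5 * jnp.sum((data - pars["mu"]) ** 2))


legacy = LegacyModel()


def logpdf(pars, data):
    # Equivalent to the lambda patch: take the single element to get a scalar.
    return legacy.logpdf(pars, data)[0]


pars = {"mu": jnp.array(1.0)}
data = jnp.array([1.5, 2.5])

# jax.grad works on the patched function; calling it on legacy.logpdf directly
# would fail, since that output has shape (1,) rather than being a scalar.
grads = jax.grad(logpdf)(pars, data)
```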

Support for Python 3.8 has been officially dropped, keeping in step with the libraries this depends on (e.g. equinox).

  • Remove Intel Macs from CI, since jaxopt's L-BFGS-B does not converge correctly there, by @phinate in #59
  • Refactor to assume parameters are stored in a key-value mapping by @phinate in #61

Full Changelog: v0.3.0...v0.4.0