From 92afafcf42751be87ff1c286a0ccbda63e136a25 Mon Sep 17 00:00:00 2001 From: Marcel Arpogaus <38564291+MArpogaus@users.noreply.github.com> Date: Fri, 12 Jul 2024 15:33:22 +0200 Subject: [PATCH] doc: add / update type annotations and docstrings --- src/bernstein_flow/__init__.py | 2 +- src/bernstein_flow/activations/__init__.py | 65 ++++++-- src/bernstein_flow/bijectors/bernstein.py | 83 ++++++----- .../distributions/bernstein_flow.py | 139 ++++++++++-------- .../losses/bernstein_flow_loss.py | 6 +- src/bernstein_flow/math/__init__.py | 1 + src/bernstein_flow/math/bernstein.py | 15 +- src/bernstein_flow/util/__init__.py | 18 ++- .../util/visualization/__init__.py | 121 ++++++++++++--- .../util/visualization/plot_flow.py | 27 ++-- 10 files changed, 313 insertions(+), 164 deletions(-) diff --git a/src/bernstein_flow/__init__.py b/src/bernstein_flow/__init__.py index 7c19b9d..a0a8a1d 100644 --- a/src/bernstein_flow/__init__.py +++ b/src/bernstein_flow/__init__.py @@ -1 +1 @@ -""".. include:: ../../README.md""" +""".. include:: ../../README.md""" # noqa: D415,D400 diff --git a/src/bernstein_flow/activations/__init__.py b/src/bernstein_flow/activations/__init__.py index 0e724e3..082a7c0 100644 --- a/src/bernstein_flow/activations/__init__.py +++ b/src/bernstein_flow/activations/__init__.py @@ -1,15 +1,30 @@ # -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- -# AUTHOR INFORMATION ########################################################## +# %% Author #################################################################### # file : __init__.py -# author : Marcel Arpogaus +# author : Marcel Arpogaus # -# created : 2022-03-10 15:39:04 (Marcel Arpogaus) -# changed : 2024-07-09 18:08:52 (Marcel Arpogaus) -# DESCRIPTION ################################################################# -# ... -# LICENSE ##################################################################### -# ... -############################################################################### +# created : 2024-07-12 15:12:18 (Marcel Arpogaus) +# changed : 2024-07-12 15:20:43 (Marcel Arpogaus) + +# %% License ################################################################### +# Copyright 2024 Marcel Arpogaus +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% Description ############################################################### +"""Activation functions applied to unconstrained outputs of neural networks.""" + +# %% imports ################################################################### import tensorflow as tf from tensorflow_probability.python.internal import ( dtype_util, @@ -18,17 +33,39 @@ ) +# %% functions ################################################################# def get_thetas_constrain_fn( - bounds=(None, None), - smooth_bounds=True, - allow_flexible_bounds=False, + bounds: tuple[float | None, float | None] = (None, None), + smooth_bounds: bool = True, + allow_flexible_bounds: bool = False, fn=tf.math.softplus, - eps=1e-5, + eps: float = 1e-5, ): + """Return a function that constrains the output of a neural network. + + Parameters + ---------- + bounds + The lower and upper bounds of the output. + smooth_bounds + Whether to ensure smooth the bounds by enforcing `Be''(0)==Be(1)==0`. + allow_flexible_bounds + Whether to allow the bounds to be flexible, i.e. to depend on the input. + fn + The positive definite function to apply to the unconstrained parameters. + eps + A small number to add to the output to avoid numerical issues. + + Returns + ------- + callable + A function that constrains the output of a neural network. + + """ low, high = bounds # @tf.function - def constrain_fn(diff): + def constrain_fn(diff: tf.Tensor): dtype = dtype_util.common_dtype([diff], dtype_hint=tf.float32) diff = tensor_util.convert_nonref_to_tensor(diff, name="diff", dtype=dtype) diff --git a/src/bernstein_flow/bijectors/bernstein.py b/src/bernstein_flow/bijectors/bernstein.py index 598bd5e..fea2f52 100644 --- a/src/bernstein_flow/bijectors/bernstein.py +++ b/src/bernstein_flow/bijectors/bernstein.py @@ -1,17 +1,12 @@ -# AUTHOR INFORMATION ########################################################## -# file : bernstein_bijector.py -# brief : [Description] +# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- +# %% Author #################################################################### +# file : bernstein.py +# author : Marcel Arpogaus # -# author : Marcel Arpogaus -# created : 2020-09-11 14:14:24 -# changed : 2020-12-07 16:29:11 -# DESCRIPTION ################################################################# -# -# This project is following the PEP8 style guide: -# -# https://www.python.org/dev/peps/pep-0008/) -# -# COPYRIGHT ################################################################### +# created : 2024-07-12 14:52:28 (Marcel Arpogaus) +# changed : 2024-07-12 14:52:28 (Marcel Arpogaus) + +# %% License ################################################################### # Copyright 2020 Marcel Arpogaus # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,8 +21,10 @@ # See the License for the specific language governing permissions and # limitations under the License. ############################################################################### +# %% Description ############################################################### +"""Implement a Bernstein Polynomial as a `tfp.Bijector`.""" -# REQUIRED PYTHON MODULES ##################################################### +# %% imports ################################################################### from functools import partial from typing import Tuple @@ -37,30 +34,46 @@ from bernstein_flow.activations import get_thetas_constrain_fn from bernstein_flow.math.bernstein import ( - gen_bernstein_polynomial_with_linear_extension, - gen_bernstein_polynomial_with_linear_extrapolation, + generate_bernstein_polynomial_with_linear_extension, + generate_bernstein_polynomial_with_linear_extrapolation, ) +# %% classes ################################################################### class BernsteinPolynomial(tfp.experimental.bijectors.ScalarFunctionWithInferredInverse): - """Implementation of a Bernstein polynomials Bijector.""" + """Implement a Bernstein polynomials Bijector. + + This bijector transforms an input tensor by applying a Bernstein polynomial. + + Attributes + ---------- + thetas: The Bernstein coefficients. + extrapolation: The method to extrapolate outside of bounds. + analytic_jacobian: Whether to use the analytic Jacobian. + domain: The domain of the Bernstein polynomial. + name: The name to give Ops created by the initializer. + + """ def __init__( self, thetas: tf.Tensor, extrapolation: bool = True, analytic_jacobian: bool = True, - domain: Tuple[int, ...] = None, + domain: Tuple[float, float] = (0.0, 1.0), name: str = "bernstein_bijector", - **kwds, + **kwargs, ) -> None: - """Constructs a new instance of a Bernstein polynomial bijector. + """Construct a new instance of a Bernstein polynomial bijector. - :param thetas: The Bernstein coefficients. - :type thetas: tf.Tensor - :param extrapolation: Method to extrapolate outside of bounds. - :type extrapolation: str - :param name: The name to give Ops created by the initializer. + Args: + ---- + thetas: The Bernstein coefficients. + extrapolation: The method to extrapolate outside of bounds. + analytic_jacobian: Whether to use the analytic Jacobian. + domain: The domain of the Bernstein polynomial. + name: The name to give Ops created by the initializer. + kwargs: Keyword arguments for the parent class. """ with tf.name_scope(name) as name: @@ -76,7 +89,7 @@ def __init__( forward_log_det_jacobian, self.b_poly_inverse_extra, self.order, - ) = gen_bernstein_polynomial_with_linear_extrapolation( + ) = generate_bernstein_polynomial_with_linear_extrapolation( self.thetas, domain=domain ) else: @@ -85,16 +98,12 @@ def __init__( forward_log_det_jacobian, self.b_poly_inverse_extra, self.order, - ) = gen_bernstein_polynomial_with_linear_extension( + ) = generate_bernstein_polynomial_with_linear_extension( self.thetas, domain=domain ) - if domain: - low = tf.convert_to_tensor(domain[0], dtype=dtype) - high = tf.convert_to_tensor(domain[1], dtype=dtype) - else: - low = tf.convert_to_tensor(0, dtype=dtype) - high = tf.convert_to_tensor(1, dtype=dtype) + low = tf.convert_to_tensor(domain[0], dtype=dtype) + high = tf.convert_to_tensor(domain[1], dtype=dtype) if analytic_jacobian: self._forward_log_det_jacobian = forward_log_det_jacobian @@ -125,10 +134,11 @@ def root_search_fn(objective_fn, _, max_iterations=None): max_iterations=50, name=name, dtype=dtype, - **kwds, + **kwargs, ) - def _inverse_no_gradient(self, y): + def _inverse_no_gradient(self, y: tf.Tensor) -> tf.Tensor: + """Compute the inverse of the bijector.""" return tf.stop_gradient( self.b_poly_inverse_extra(y, inverse_approx_fn=super()._inverse_no_gradient) ) @@ -141,5 +151,6 @@ def _parameter_properties(cls, dtype=None): ), ) - def _is_increasing(self, **kwargs): + def _is_increasing(self, **kwargs) -> bool: + """Check if the bijector is increasing.""" return tf.reduce_all(self.thetas[..., 1:] >= self.thetas[..., :-1]) diff --git a/src/bernstein_flow/distributions/bernstein_flow.py b/src/bernstein_flow/distributions/bernstein_flow.py index 2ae645f..d0c1739 100644 --- a/src/bernstein_flow/distributions/bernstein_flow.py +++ b/src/bernstein_flow/distributions/bernstein_flow.py @@ -1,18 +1,13 @@ # -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- -# AUTHOR INFORMATION ########################################################### +# %% Author #################################################################### # file : bernstein_flow.py # author : Marcel Arpogaus # -# created : 2020-05-15 10:44:23 (Marcel Arpogaus) -# changed : 2024-06-25 19:32:32 (Marcel Arpogaus) -# DESCRIPTION ################################################################## -# -# This project is following the PEP8 style guide: -# -# https://www.python.org/dev/peps/pep-0008/) -# -# COPYRIGHT #################################################################### -# Copyright 2020 Marcel Arpogaus +# created : 2024-07-12 14:55:22 (Marcel Arpogaus) +# changed : 2024-07-12 15:21:35 (Marcel Arpogaus) + +# %% License ################################################################### +# Copyright 2024 Marcel Arpogaus # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,10 +20,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -################################################################################ -# REQUIRED PYTHON MODULES ###################################################### -from typing import Callable, Dict, Optional +# %% Description ############################################################### +"""Normalizing flow using Bernstein Polynomial as transformation function.""" + +# %% imports ################################################################### +from typing import Any, Callable, Dict, Optional import tensorflow as tf import tensorflow_probability as tfp @@ -44,19 +41,22 @@ from bernstein_flow.bijectors import BernsteinPolynomial -def slice_parameter_vector(params: tf.Tensor, p_spec: dict) -> dict: - """Slices parameters of the given size from a tensor. +# %% functions ################################################################# +def slice_parameter_vector( + params: tf.Tensor, p_spec: Dict[str, int] +) -> Dict[str, tf.Tensor]: + """Slice parameters of the given size from a tensor. Parameters ---------- - params : tf.Tensor + params The parameter vector. - p_spec : dict + p_spec Specification of parameter sizes in the form {'parameter_name': size}. Returns ------- - dict + Dict[str, tf.Tensor] Dictionary containing the sliced parameters. """ @@ -71,21 +71,21 @@ def slice_parameter_vector(params: tf.Tensor, p_spec: dict) -> dict: def apply_constraining_bijectors( - unconstrained_parameters: dict, - thetas_constrain_fn: Optional[Callable] = None, -) -> dict: + unconstrained_parameters: Dict[str, tf.Tensor], + thetas_constraint_fn: Optional[Callable] = None, +) -> Dict[str, tf.Tensor]: """Apply activation functions to raw parameters. Parameters ---------- - unconstrained_parameters : dict + unconstrained_parameters Dictionary of raw parameters. - thetas_constrain_fn : Callable, optional + thetas_constraint_fn Function used to constrain the Bernstein coefficients, by default None. Returns ------- - dict + Dict[str, tf.Tensor] Dictionary with constrained parameters. """ @@ -93,8 +93,8 @@ def apply_constraining_bijectors( parameters = {} for parameter_name, parameter in unconstrained_parameters.items(): - if parameter_name == "thetas" and (thetas_constrain_fn is not None): - constraining_bijector = thetas_constrain_fn + if parameter_name == "thetas" and (thetas_constraint_fn is not None): + constraining_bijector = thetas_constraint_fn else: parameter_properties = BernsteinFlow.parameter_properties( dtype=parameter.dtype @@ -112,22 +112,24 @@ def init_bijectors( a1: Optional[tf.Tensor] = None, b1: Optional[tf.Tensor] = None, a2: Optional[tf.Tensor] = None, - **bernstein_bijector_kwargs, + **bernstein_bijector_kwargs: Dict[str, Any], ) -> tfb.Bijector: """Build a normalizing flow using a Bernstein polynomial as Bijector. Parameters ---------- - thetas : tf.Tensor + thetas The Bernstein coefficients. - clip_to_bernstein_domain : bool + clip_to_bernstein_domain Whether to clip to the Bernstein domain [0, 1]. - a1 : tf.Tensor, optional + a1 The scale of f1., by default None. - b1 : tf.Tensor, optional + b1 The shift of f1., by default None. - a2 : tf.Tensor, optional + a2 The scale of f3., by default None. + bernstein_bijector_kwargs + Keyword arguments passed to the `BernsteinPolynomial` Returns ------- @@ -167,16 +169,18 @@ def init_bijectors( def get_base_distribution( - base_distribution: str, dtype: tf.DType, **kwargs + base_distribution: str, dtype: tf.DType, **kwargs: Dict[str, Any] ) -> tfd.Distribution: """Get an instance of a base distribution. Parameters ---------- - base_distribution : str + base_distribution Name of the base distribution. - dtype : tf.DType + dtype Data type of the distribution. + kwargs + Keyword arguments passed to the Distribution class. Returns ------- @@ -221,8 +225,9 @@ def get_base_distribution( return dist +# %% classes ################################################################### class BernsteinFlow(tfd.TransformedDistribution): - """Implements a `tfd.TransformedDistribution` using Bernstein polynomials.""" + """Implement a `tfd.TransformedDistribution` using Bernstein polynomials.""" def __init__( self, @@ -231,31 +236,33 @@ def __init__( b1: Optional[tf.Tensor] = None, a2: Optional[tf.Tensor] = None, base_distribution: str = "normal", - base_distribution_kwargs: dict = {}, + base_distribution_kwargs: Dict[str, Any] = {}, clip_to_bernstein_domain: bool = False, name: Optional[str] = None, - **bernstein_bijector_kwargs, + **kwargs: Dict[str, Any], ) -> None: """Initialize the BernsteinFlow. Parameters ---------- - thetas : tf.Tensor + thetas The Bernstein coefficients. - a1 : tf.Tensor, optional + a1 The scale of f1., by default None. - b1 : tf.Tensor, optional + b1 The shift of f1., by default None. - a2 : tf.Tensor, optional + a2 The scale of f3., by default None. - base_distribution : str, optional + base_distribution The base distribution, by default "normal". - base_distribution_kwargs : dict, optional + base_distribution_kwargs Keyword arguments of the base distribution, by default {}. - clip_to_bernstein_domain : bool, optional + clip_to_bernstein_domain Whether to clip to the Bernstein domain [0, 1], by default False. - name : str, optional + name The name of the flow, by default "BernsteinFlow". + kwargs + Keyword arguments passed to `init_bijectors`. """ parameters = dict(locals()) @@ -275,7 +282,9 @@ def __init__( if tf.is_tensor(a2): a2 = tensor_util.convert_nonref_to_tensor(a2, dtype=dtype, name="a2") - base_distribution = get_base_distribution(base_distribution, dtype) + base_distribution = get_base_distribution( + base_distribution, dtype, **base_distribution_kwargs + ) bijector = init_bijectors( thetas, @@ -283,7 +292,7 @@ def __init__( b1=b1, a2=a2, clip_to_bernstein_domain=clip_to_bernstein_domain, - **bernstein_bijector_kwargs, + **kwargs, ) super().__init__( @@ -316,34 +325,40 @@ def new( scale_data: bool, shift_data: bool, scale_base_distribution: bool, - get_thetas_constrain_fn: Callable = get_thetas_constrain_fn, + get_thetas_constraint_fn: Callable = get_thetas_constrain_fn, base_distribution: str = "normal", - base_distribution_kwargs: dict = {}, + base_distribution_kwargs: Dict[str, Any] = {}, clip_to_bernstein_domain: bool = False, name: Optional[str] = None, - bernstein_bijector_kwargs: Dict = {}, - **kwargs, + bernstein_bijector_kwargs: Dict[str, Any] = {}, + **kwargs: Dict[str, Any], ) -> "BernsteinFlow": """Create the distribution instance from a `params` vector. Parameters ---------- - params : tf.Tensor + params The parameters of the flow. - scale_data : bool + scale_data Whether to scale the data. - shift_data : bool + shift_data Whether to shift the data. - scale_base_distribution : bool + scale_base_distribution Whether to scale the base distribution. - base_distribution : str, optional + get_thetas_constraint_fn + Function returning a constrain function for the Bernstein coefficients. + base_distribution The base distribution, by default "normal". - base_distribution_kwargs : dict, optional + base_distribution_kwargs Keyword arguments for the base distribution, by default {}. - clip_to_bernstein_domain : bool, optional + clip_to_bernstein_domain Whether to clip to the Bernstein domain [0, 1], by default False. - name : str, optional + name The name of the flow, by default "BernsteinFlow". + bernstein_bijector_kwargs + Keyword arguments for the Bernstein bijector, by default {}. + kwargs + Keyword arguments passed to `get_thetas_constraint_fn`. Returns ------- @@ -380,7 +395,7 @@ def new( return BernsteinFlow( **apply_constraining_bijectors( unconstrained_parameters=slice_parameter_vector(params, p_spec), - thetas_constrain_fn=get_thetas_constrain_fn(**kwargs), + thetas_constraint_fn=get_thetas_constraint_fn(**kwargs), ), base_distribution=base_distribution, base_distribution_kwargs=base_distribution_kwargs, diff --git a/src/bernstein_flow/losses/bernstein_flow_loss.py b/src/bernstein_flow/losses/bernstein_flow_loss.py index 40dc23d..18aa73d 100644 --- a/src/bernstein_flow/losses/bernstein_flow_loss.py +++ b/src/bernstein_flow/losses/bernstein_flow_loss.py @@ -4,7 +4,7 @@ # author : Marcel Arpogaus # # created : 2024-07-10 10:13:31 (Marcel Arpogaus) -# changed : 2024-07-10 10:13:31 (Marcel Arpogaus) +# changed : 2024-07-12 15:17:17 (Marcel Arpogaus) # %% License ################################################################### # Copyright 2020 Marcel Arpogaus @@ -33,9 +33,7 @@ # %% Classes ################################################################### class BernsteinFlowLoss(Loss): - """This Keras Loss function implements the negative logarithmic likelihood for - a bijective transformation model using Bernstein polynomials. - """ + """NLL of a bijective transformation model using Bernstein polynomials.""" def __init__(self, **kwargs: dict) -> None: """Construct a new instance of the Keras Loss function. diff --git a/src/bernstein_flow/math/__init__.py b/src/bernstein_flow/math/__init__.py index e69de29..feb6276 100644 --- a/src/bernstein_flow/math/__init__.py +++ b/src/bernstein_flow/math/__init__.py @@ -0,0 +1 @@ +"""Defining all the math required.""" diff --git a/src/bernstein_flow/math/bernstein.py b/src/bernstein_flow/math/bernstein.py index 699ad4c..540effa 100644 --- a/src/bernstein_flow/math/bernstein.py +++ b/src/bernstein_flow/math/bernstein.py @@ -4,10 +4,10 @@ # author : Marcel Arpogaus # # created : 2024-07-10 10:10:18 (Marcel Arpogaus) -# changed : 2024-07-10 10:10:28 (Marcel Arpogaus) +# changed : 2024-07-12 15:23:22 (Marcel Arpogaus) # %% License ################################################################### -# Copyright 2020 Marcel Arpogaus +# Copyright 2024 Marcel Arpogaus # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ # %% Description ############################################################### """Mathematical definitions of Bernstein Polynomials.""" + # %% Imports ################################################################### from typing import Callable, Optional, Tuple @@ -415,7 +416,7 @@ def transform_to_support(y: Tensor, low: Tensor, high: Tensor) -> Tensor: return y * (high - low) + low -def gen_bernstein_polynomial_with_extrapolation( +def generate_bernstein_polynomial_with_extrapolation( theta: Tensor, gen_extrapolation_fn: Callable = gen_linear_extrapolation, domain: Optional[Tuple[Tensor, Tensor]] = None, @@ -576,7 +577,7 @@ def bpoly_inverse_extra_scaled( return bpoly_extra, bpoly_log_det_jacobian_extra, bpoly_inverse_extra, order -def gen_bernstein_polynomial_with_linear_extension(*args, **kwargs): +def generate_bernstein_polynomial_with_linear_extension(*args, **kwargs): """Generate a Bernstein polynomial with linear extension. Args: @@ -591,12 +592,12 @@ def gen_bernstein_polynomial_with_linear_extension(*args, **kwargs): `gen_extrapolation_fn` set to `gen_linear_extension`. """ - return gen_bernstein_polynomial_with_extrapolation( + return generate_bernstein_polynomial_with_extrapolation( *args, gen_extrapolation_fn=gen_linear_extension, **kwargs ) -def gen_bernstein_polynomial_with_linear_extrapolation(*args, **kwargs): +def generate_bernstein_polynomial_with_linear_extrapolation(*args, **kwargs): """Generate a Bernstein polynomial with linear extrapolation. Args: @@ -611,6 +612,6 @@ def gen_bernstein_polynomial_with_linear_extrapolation(*args, **kwargs): `gen_extrapolation_fn` set to `gen_linear_extrapolation`. """ - return gen_bernstein_polynomial_with_extrapolation( + return generate_bernstein_polynomial_with_extrapolation( *args, gen_extrapolation_fn=gen_linear_extrapolation, **kwargs ) diff --git a/src/bernstein_flow/util/__init__.py b/src/bernstein_flow/util/__init__.py index 97262c2..359db24 100644 --- a/src/bernstein_flow/util/__init__.py +++ b/src/bernstein_flow/util/__init__.py @@ -1,7 +1,23 @@ +"""Defines helper functions for training and plotting.""" + from functools import partial +from typing import Callable from ..distributions import BernsteinFlow -def gen_flow(**kwds): +def gen_flow(**kwds: int) -> Callable[..., BernsteinFlow]: + """Generate a Bernstein flow factory. + + Parameters + ---------- + kwds + Keyword arguments to pass to :meth:`BernsteinFlow.new`. + + Returns + ------- + FlowFactory + A factory function that creates Bernstein flows. + + """ return partial(BernsteinFlow.new, **kwds) diff --git a/src/bernstein_flow/util/visualization/__init__.py b/src/bernstein_flow/util/visualization/__init__.py index 507a9d8..7fa8f02 100644 --- a/src/bernstein_flow/util/visualization/__init__.py +++ b/src/bernstein_flow/util/visualization/__init__.py @@ -1,18 +1,13 @@ -# AUTHOR INFORMATION ########################################################## -# file : visualization.py -# brief : [Description] +# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- +# %% Author #################################################################### +# file : __init__.py +# author : Marcel Arpogaus # -# author : Marcel Arpogaus -# created : 2020-04-13 16:04:37 -# changed : 2020-10-26 10:55:48 -# DESCRIPTION ################################################################# -# -# This project is following the PEP8 style guide: -# -# https://www.python.org/dev/peps/pep-0008/) -# -# LICENSE ##################################################################### -# Copyright 2020 Marcel Arpogaus +# created : 2024-07-12 14:48:27 (Marcel Arpogaus) +# changed : 2024-07-12 14:48:27 (Marcel Arpogaus) + +# %% License ################################################################### +# Copyright 2024 Marcel Arpogaus # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,8 +20,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -############################################################################### -# REQUIRED PYTHON MODULES ##################################################### + +# %% description ############################################################### +"""Defines functions to create som insigtfull plots.""" + +# %% imports ################################################################### import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt @@ -38,8 +36,27 @@ from .plot_flow import plot_flow # noqa F401 -# function definitions ######################################################## -def vizualize_flow_from_z_domain(flow, z_min=-3, z_max=3): +# %% functions ################################################################# +def visualize_flow_from_z_domain( + flow: tfp.bijectors.Bijector, z_min: float = -3, z_max: float = 3 +) -> plt.Figure: + """Visualize a flow from the z-domain to the y-domain. + + Parameters + ---------- + flow : tfp.bijectors.Bijector + The flow to visualize. + z_min : float, optional + The minimum value of z, by default -3 + z_max : float, optional + The maximum value of z, by default 3 + + Returns + ------- + plt.Figure + The figure with the plots. + + """ bijector = flow.bijector base_dist = flow.distribution @@ -53,7 +70,7 @@ def vizualize_flow_from_z_domain(flow, z_min=-3, z_max=3): y_samples = np.squeeze(bijector.forward(z_samples)) # p_y(y) = p_z(h^-1(y))*|h^-1'(y)| = p_z(z)*|h^-1'(y)| - ildj = bijector.inverse_log_det_jacobian(y_samples, 0) + ildj = bijector.inverse_log_det_jacobian(y_samples, event_ndims=0) log_prob = base_dist.log_prob(z_samples) log_prob = log_prob + ildj @@ -92,7 +109,8 @@ def vizualize_flow_from_z_domain(flow, z_min=-3, z_max=3): p_mu_z = base_dist.prob(mu_z) p_mu_y = np.exp( - base_dist.log_prob(mu_z) + bijector.inverse_log_det_jacobian(mu_y, 0) + base_dist.log_prob(mu_z) + + bijector.inverse_log_det_jacobian(mu_y, event_ndims=0) ) cp_kwds = dict(color="darkgray", lw=1, ls="--", arrowstyle="->") @@ -122,7 +140,20 @@ def vizualize_flow_from_z_domain(flow, z_min=-3, z_max=3): return fig -def plot_chained_bijectors(flow): +def plot_chained_bijectors(flow: tfp.bijectors.Bijector) -> plt.Figure: + """Plot the chain of bijectors in a flow. + + Parameters + ---------- + flow : tfp.bijectors.Bijector + The flow to plot. + + Returns + ------- + plt.Figure + The figure of the plot. + + """ chained_bijectors = flow.bijector.bijector.bijectors base_dist = flow.distribution cols = len(chained_bijectors) + 1 @@ -141,7 +172,7 @@ def plot_chained_bijectors(flow): for i, (a, b) in enumerate(zip(ax[1:], chained_bijectors)): # we need to use the inverse here since we are going from z->y! z = b.inverse(zz) - ildj += b.forward_log_det_jacobian(z, 1) + ildj += b.forward_log_det_jacobian(z, event_ndims=1) # print(z.shape, zz.shape, ildj.shape) a.scatter(z, np.exp(log_probs + ildj)) a.set_title(b.name.replace("_", " ")) @@ -151,7 +182,34 @@ def plot_chained_bijectors(flow): return fig -def plot_x_trafo(flow, xmin=-1, xmax=1, n=20, size=3): +def plot_x_trafo( + flow: tfp.bijectors.Bijector, + xmin: float = -1, + xmax: float = 1, + n: int = 20, + size: int = 3, +) -> plt.Figure: + """Plot the transformation of x for each bijector in a flow. + + Parameters + ---------- + flow : tfp.bijectors.Bijector + The flow to plot. + xmin : float, optional + The minimum value of x, by default -1 + xmax : float, optional + The maximum value of x, by default 1 + n : int, optional + The number of points to plot, by default 20 + size : int, optional + The size of the plot, by default 3 + + Returns + ------- + plt.Figure + The figure of the plot. + + """ x = np.linspace(xmin, xmax, n, dtype=np.float32) pos = n // 2 con_kwds = dict( @@ -229,7 +287,22 @@ def plot_x_trafo(flow, xmin=-1, xmax=1, n=20, size=3): return fig -def plot_value_and_gradient(func, y): +def plot_value_and_gradient(func: callable, y: np.ndarray) -> plt.Figure: + """Plot the value and gradient of a function. + + Parameters + ---------- + func : callable + The function to plot. + y : np.ndarray + The values to evaluate the function at. + + Returns + ------- + plt.Figure + The figure of the plot. + + """ [funval, grads] = tfp.math.value_and_gradient(func, y) fig, ax = plt.subplots(1, 2, figsize=(16, 8), constrained_layout=True) diff --git a/src/bernstein_flow/util/visualization/plot_flow.py b/src/bernstein_flow/util/visualization/plot_flow.py index 01b761e..9c4fe04 100644 --- a/src/bernstein_flow/util/visualization/plot_flow.py +++ b/src/bernstein_flow/util/visualization/plot_flow.py @@ -1,19 +1,13 @@ -"""Convenience Function to plot a normalizing flow.""" # -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- -# AUTHOR INFORMATION ########################################################## +# %% Author #################################################################### # file : plot_flow.py -# author : Marcel Arpogaus -# -# created : 2022-06-01 15:21:22 (Marcel Arpogaus) -# changed : 2024-06-13 22:57:11 (Marcel Arpogaus) -# DESCRIPTION ################################################################# -# -# This project is following the PEP8 style guide: -# -# https://www.python.org/dev/peps/pep-0008/ +# author : Marcel Arpogaus # -# LICENSE ##################################################################### -# Copyright 2020 Marcel Arpogaus +# created : 2024-07-12 14:49:21 (Marcel Arpogaus) +# changed : 2024-07-12 14:49:21 (Marcel Arpogaus) + +# %% License ################################################################### +# Copyright 2024 Marcel Arpogaus # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,8 +20,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -############################################################################### -# REQUIRED PYTHON MODULES ##################################################### + +# %% Description ############################################################### +"""Convenience Function to plot a normalizing flow.""" + +# %% imports ################################################################### from functools import partial from typing import Dict, List, Tuple