diff --git a/src/bernstein_flow/bijectors/bernstein.py b/src/bernstein_flow/bijectors/bernstein.py index fea2f52..01638ba 100644 --- a/src/bernstein_flow/bijectors/bernstein.py +++ b/src/bernstein_flow/bijectors/bernstein.py @@ -4,7 +4,7 @@ # author : Marcel Arpogaus # # created : 2024-07-12 14:52:28 (Marcel Arpogaus) -# changed : 2024-07-12 14:52:28 (Marcel Arpogaus) +# changed : 2024-07-18 12:01:50 (Marcel Arpogaus) # %% License ################################################################### # Copyright 2020 Marcel Arpogaus @@ -47,11 +47,8 @@ class BernsteinPolynomial(tfp.experimental.bijectors.ScalarFunctionWithInferredI Attributes ---------- - thetas: The Bernstein coefficients. - extrapolation: The method to extrapolate outside of bounds. - analytic_jacobian: Whether to use the analytic Jacobian. - domain: The domain of the Bernstein polynomial. - name: The name to give Ops created by the initializer. + thetas + The Bernstein coefficients. """ @@ -66,14 +63,20 @@ def __init__( ) -> None: """Construct a new instance of a Bernstein polynomial bijector. - Args: - ---- - thetas: The Bernstein coefficients. - extrapolation: The method to extrapolate outside of bounds. - analytic_jacobian: Whether to use the analytic Jacobian. - domain: The domain of the Bernstein polynomial. - name: The name to give Ops created by the initializer. - kwargs: Keyword arguments for the parent class. + Parameters + ---------- + thetas + The Bernstein coefficients. + extrapolation + The method to extrapolate outside of bounds. + analytic_jacobian + Whether to use the analytic Jacobian. + domain + The domain of the Bernstein polynomial. + name + The name to give Ops created by the initializer. + kwargs + Keyword arguments for the parent class. """ with tf.name_scope(name) as name: diff --git a/src/bernstein_flow/math/bernstein.py b/src/bernstein_flow/math/bernstein.py index 540effa..cff6d19 100644 --- a/src/bernstein_flow/math/bernstein.py +++ b/src/bernstein_flow/math/bernstein.py @@ -4,7 +4,7 @@ # author : Marcel Arpogaus # # created : 2024-07-10 10:10:18 (Marcel Arpogaus) -# changed : 2024-07-12 15:23:22 (Marcel Arpogaus) +# changed : 2024-07-18 11:57:17 (Marcel Arpogaus) # %% License ################################################################### # Copyright 2024 Marcel Arpogaus @@ -24,7 +24,6 @@ # %% Description ############################################################### """Mathematical definitions of Bernstein Polynomials.""" - # %% Imports ################################################################### from typing import Callable, Optional, Tuple @@ -33,43 +32,45 @@ from tensorflow_probability import distributions as tfd from tensorflow_probability.python.internal import dtype_util, prefer_static -# %% Globals ################################################################### -Tensor = tf.Tensor -Distribution = tfd.Distribution - # %% Functions ################################################################# def reshape_output( - batch_shape: tf.TensorShape, sample_shape: tf.TensorShape, y: Tensor -) -> Tensor: + batch_shape: tf.TensorShape, sample_shape: tf.TensorShape, y: tf.Tensor +) -> tf.Tensor: """Reshape tensor to output shape. - Args: - ---- - batch_shape (tf.TensorShape): The batch shape of the output. - sample_shape (tf.TensorShape): The sample shape of the output. - y (Tensor): The tensor to reshape. + Parameters + ---------- + batch_shape + The batch shape of the output. + sample_shape + The sample shape of the output. + y + The tensor to reshape. - Returns: + Returns ------- - Tensor: The reshaped tensor. + Tensor: The reshaped tensor. """ output_shape = prefer_static.broadcast_shape(sample_shape, batch_shape) return tf.reshape(y, output_shape) -def gen_basis(order: int, dtype: tf.DType = tf.float32) -> Distribution: +def gen_basis(order: int, dtype: tf.DType = tf.float32) -> tfd.Distribution: """Generate Bernstein basis polynomials from Beta distributions. - Args: - ---- - order (int): The order of the Bernstein polynomial. - dtype (tf.DType, optional): The dtype of the Beta distribution. + Parameters + ---------- + order + The order of the Bernstein polynomial. + dtype + The dtype of the Beta distribution. - Returns: + Returns ------- - Distribution: A Beta distribution. + Distribution + A Beta distribution. """ return tfd.Beta( @@ -77,17 +78,21 @@ def gen_basis(order: int, dtype: tf.DType = tf.float32) -> Distribution: ) -def gen_bernstein_polynomial(thetas: Tensor) -> Tuple[Callable[[Tensor], Tensor], int]: +def gen_bernstein_polynomial( + thetas: tf.Tensor, +) -> Tuple[Callable[[tf.Tensor], tf.Tensor], int]: """Generate Bernstein polynomial as a Callable. - Args: - ---- - thetas (Tensor): The weights of the Bernstein polynomial. + Parameters + ---------- + thetas + The weights of the Bernstein polynomial. - Returns: + Returns ------- - Callable[[Tensor], Tensor]: A function that evaluates the Bernstein polynomial. - int: The order of the polynomial. + Callable[[Tensor], Tensor]: + A function that evaluates the Bernstein polynomial. + int: The order of the polynomial. """ theta_shape = prefer_static.shape(thetas) @@ -95,16 +100,18 @@ def gen_bernstein_polynomial(thetas: Tensor) -> Tuple[Callable[[Tensor], Tensor] basis = gen_basis(order, thetas.dtype) - def b_poly(y: Tensor) -> Tensor: + def b_poly(y: tf.Tensor) -> tf.Tensor: """Evaluate the Bernstein polynomial. - Args: - ---- - y (Tensor): The input to the Bernstein polynomial. + Parameters + ---------- + y + The input to the Bernstein polynomial. - Returns: + Returns ------- - Tensor: The output of the Bernstein polynomial. + Tensor + The output of the Bernstein polynomial. """ y = y[..., tf.newaxis] @@ -115,16 +122,18 @@ def b_poly(y: Tensor) -> Tensor: return b_poly, order -def derive_thetas(thetas: Tensor) -> Tensor: +def derive_thetas(thetas: tf.Tensor) -> tf.Tensor: """Calculate the derivative of the Bernstein polynomial weights. - Args: - ---- - thetas (Tensor): The Bernstein polynomial weights. + Parameters + ---------- + thetas + The Bernstein polynomial weights. - Returns: + Returns ------- - Tensor: The derivative of the Bernstein polynomial weights. + Tensor + The derivative of the Bernstein polynomial weights. """ theta_shape = prefer_static.shape(thetas) @@ -134,18 +143,20 @@ def derive_thetas(thetas: Tensor) -> Tensor: return dtheta -def derive_bpoly(thetas: Tensor) -> Tuple[Callable[[Tensor], Tensor], int]: +def derive_bpoly(thetas: tf.Tensor) -> Tuple[Callable[[tf.Tensor], tf.Tensor], int]: """Generate the derivative of the Bernstein polynomial function. - Args: - ---- - thetas (Tensor): The Bernstein polynomial weights. + Parameters + ---------- + thetas + The Bernstein polynomial weights. - Returns: + Returns ------- - Callable[[Tensor], Tensor]: A function that evaluates the derivative of the - Bernstein polynomial. - int: The order of the polynomial. + Callable[[Tensor], Tensor]: + A function that evaluates the derivative of the Bernstein polynomial. + int + The order of the polynomial. """ dtheta = derive_thetas(thetas) @@ -153,16 +164,18 @@ def derive_bpoly(thetas: Tensor) -> Tuple[Callable[[Tensor], Tensor], int]: return b_poly_dash, order -def get_bounds(thetas: Tensor) -> Tensor: +def get_bounds(thetas: tf.Tensor) -> tf.Tensor: """Get the bounds of the Bernstein polynomial. - Args: - ---- - thetas (Tensor): The Bernstein polynomial weights. + Parameters + ---------- + thetas + The Bernstein polynomial weights. - Returns: + Returns ------- - Tensor: A tensor containing the lower and upper bounds. + Tensor + A tensor containing the lower and upper bounds. """ eps = dtype_util.eps(thetas.dtype) @@ -178,17 +191,20 @@ def get_bounds(thetas: Tensor) -> Tensor: return x -def evaluate_bpoly_on_bounds(thetas: Tensor, bounds: Tensor) -> Tensor: +def evaluate_bpoly_on_bounds(thetas: tf.Tensor, bounds: tf.Tensor) -> tf.Tensor: """Evaluate the Bernstein polynomial on the given bounds. - Args: - ---- - thetas (Tensor): The Bernstein polynomial weights. - bounds (Tensor): The bounds to evaluate the polynomial on. + Parameters + ---------- + thetas + The Bernstein polynomial weights. + bounds + The bounds to evaluate the polynomial on. - Returns: + Returns ------- - Tensor: The Bernstein polynomial evaluated at the given bounds. + Tensor + The Bernstein polynomial evaluated at the given bounds. """ b_poly, _ = gen_bernstein_polynomial(thetas) @@ -197,28 +213,33 @@ def evaluate_bpoly_on_bounds(thetas: Tensor, bounds: Tensor) -> Tensor: def gen_linear_extension( - thetas: Tensor, + thetas: tf.Tensor, ) -> Tuple[ - Callable[[Tensor], Tensor], - Callable[[Tensor], Tensor], - Callable[[Tensor], Tensor], - Tensor, - Tensor, + Callable[[tf.Tensor], tf.Tensor], + Callable[[tf.Tensor], tf.Tensor], + Callable[[tf.Tensor], tf.Tensor], + tf.Tensor, + tf.Tensor, ]: """Generate a linear extension function. - Args: - ---- - thetas (Tensor): The Bernstein polynomial weights. + Parameters + ---------- + thetas + The Bernstein polynomial weights. - Returns: + Returns ------- - Callable[[Tensor], Tensor]: The linear extension function. - Callable[[Tensor], Tensor]: The log determinant Jacobian of the - extension function. - Callable[[Tensor], Tensor]: The inverse of the extension function. - Tensor: The x bounds. - Tensor: The y bounds. + Callable[[Tensor], Tensor]: + The linear extension function. + Callable[[Tensor], Tensor]: + The log determinant Jacobian of the extension function. + Callable[[Tensor], Tensor]: + The inverse of the extension function. + Tensor + The x bounds. + Tensor + The y bounds. """ # [eps, 1 - eps] @@ -227,16 +248,18 @@ def gen_linear_extension( # [Be(eps), Be(1 - eps)] y_bounds = evaluate_bpoly_on_bounds(thetas, x_bounds) - def extra(x: Tensor) -> Tensor: + def extra(x: tf.Tensor) -> tf.Tensor: """Linear extension function. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- - Tensor: The extended output. + Tensor + The extended output. """ e0 = x + y_bounds[0] @@ -247,16 +270,18 @@ def extra(x: Tensor) -> Tensor: return y - def extra_log_det_jacobian(x: Tensor) -> Tensor: + def extra_log_det_jacobian(x: tf.Tensor) -> tf.Tensor: """Log determinant Jacobian of the linear extension function. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- - Tensor: The log determinant Jacobian. + Tensor + The log determinant Jacobian. """ y = tf.where(x <= x_bounds[0], tf.ones_like(x), np.nan) @@ -264,16 +289,18 @@ def extra_log_det_jacobian(x: Tensor) -> Tensor: return tf.math.log(tf.abs(y)) - def extra_inv(y: Tensor) -> Tensor: + def extra_inv(y: tf.Tensor) -> tf.Tensor: """Inverse of the linear extension function. - Args: - ---- - y (Tensor): The input tensor. + Parameters + ---------- + y + The input tensor. - Returns: + Returns ------- - Tensor: The inverse transformed tensor. + Tensor + The inverse transformed tensor. """ x0 = y - y_bounds[0] @@ -288,28 +315,33 @@ def extra_inv(y: Tensor) -> Tensor: def gen_linear_extrapolation( - thetas: Tensor, + thetas: tf.Tensor, ) -> Tuple[ - Callable[[Tensor], Tensor], - Callable[[Tensor], Tensor], - Callable[[Tensor], Tensor], - Tensor, - Tensor, + Callable[[tf.Tensor], tf.Tensor], + Callable[[tf.Tensor], tf.Tensor], + Callable[[tf.Tensor], tf.Tensor], + tf.Tensor, + tf.Tensor, ]: """Generate a linear extrapolation function. - Args: - ---- - thetas (Tensor): The Bernstein polynomial weights. + Parameters + ---------- + thetas + The Bernstein polynomial weights. - Returns: + Returns ------- - Callable[[Tensor], Tensor]: The linear extrapolation function. - Callable[[Tensor], Tensor]: The log determinant Jacobian of the - extrapolation function. - Callable[[Tensor], Tensor]: The inverse of the extrapolation function. - Tensor: The x bounds. - Tensor: The y bounds. + Callable[[Tensor], Tensor]: + The linear extrapolation function. + Callable[[Tensor], Tensor]: + The log determinant Jacobian of the extrapolation function. + Callable[[Tensor], Tensor]: + The inverse of the extrapolation function. + Tensor + The x bounds. + Tensor + The y bounds. """ # [eps, 1 - eps] @@ -322,16 +354,18 @@ def gen_linear_extrapolation( dtheta = derive_thetas(thetas) a = evaluate_bpoly_on_bounds(dtheta, x_bounds) - def extra(x: Tensor) -> Tensor: + def extra(x: tf.Tensor) -> tf.Tensor: """Linear extrapolation function. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- - Tensor: The extrapolated output. + Tensor + The extrapolated output. """ e0 = a[0] * x + y_bounds[0] @@ -342,16 +376,18 @@ def extra(x: Tensor) -> Tensor: return y - def extra_log_det_jacobian(x: Tensor) -> Tensor: + def extra_log_det_jacobian(x: tf.Tensor) -> tf.Tensor: """Log determinant Jacobian of the linear extrapolation function. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- - Tensor: The log determinant Jacobian. + Tensor + The log determinant Jacobian. """ y = tf.where(x <= x_bounds[0], a[0], np.nan) @@ -359,16 +395,18 @@ def extra_log_det_jacobian(x: Tensor) -> Tensor: return tf.math.log(tf.abs(y)) - def extra_inv(y: Tensor) -> Tensor: + def extra_inv(y: tf.Tensor) -> tf.Tensor: """Inverse of the linear extrapolation function. - Args: - ---- - y (Tensor): The input tensor. + Parameters + ---------- + y + The input tensor. - Returns: + Returns ------- - Tensor: The inverse transformed tensor. + Tensor + The inverse transformed tensor. """ x0 = (y - y_bounds[0]) / a[0] @@ -382,68 +420,82 @@ def extra_inv(y: Tensor) -> Tensor: return extra, extra_log_det_jacobian, extra_inv, x_bounds, y_bounds -def transform_to_bernstein_domain(x: Tensor, low: Tensor, high: Tensor) -> Tensor: +def transform_to_bernstein_domain( + x: tf.Tensor, low: tf.Tensor, high: tf.Tensor +) -> tf.Tensor: """Transform the input to the Bernstein polynomial domain. - Args: - ---- - x (Tensor): The input. - low (Tensor): The lower bound of the domain. - high (Tensor): The upper bound of the domain. + Parameters + ---------- + x + The input. + low + The lower bound of the domain. + high + The upper bound of the domain. - Returns: + Returns ------- - Tensor: The transformed input. + Tensor + The transformed input. """ return (x - low) / (high - low) -def transform_to_support(y: Tensor, low: Tensor, high: Tensor) -> Tensor: +def transform_to_support(y: tf.Tensor, low: tf.Tensor, high: tf.Tensor) -> tf.Tensor: """Transform the output from the Bernstein polynomial domain to the original domain. - Args: - ---- - y (Tensor): The output. - low (Tensor): The lower bound of the domain. - high (Tensor): The upper bound of the domain. + Parameters + ---------- + y + The output. + low + The lower bound of the domain. + high + The upper bound of the domain. - Returns: + Returns ------- - Tensor: The transformed output. + Tensor + The transformed output. """ return y * (high - low) + low def generate_bernstein_polynomial_with_extrapolation( - theta: Tensor, + theta: tf.Tensor, gen_extrapolation_fn: Callable = gen_linear_extrapolation, - domain: Optional[Tuple[Tensor, Tensor]] = None, + domain: Optional[Tuple[tf.Tensor, tf.Tensor]] = None, ) -> Tuple[ - Callable[[Tensor], Tensor], - Callable[[Tensor], Tensor], - Callable[[Tensor, Callable], Tensor], + Callable[[tf.Tensor], tf.Tensor], + Callable[[tf.Tensor], tf.Tensor], + Callable[[tf.Tensor, Callable], tf.Tensor], int, ]: """Generate a Bernstein polynomial with extrapolation. - Args: - ---- - theta (Tensor): The Bernstein polynomial weights. - gen_extrapolation_fn (Callable): The function used to generate the - extrapolation function. - domain (Optional[Tuple[Tensor, Tensor]]): The domain of the Bernstein polynomial - - Returns: + Parameters + ---------- + theta + The Bernstein polynomial weights. + gen_extrapolation_fn + The function used to generate the + extrapolation function. + domain + The domain of the Bernstein polynomial + + Returns ------- - Callable[[Tensor], Tensor]: A function that evaluates the Bernstein polynomial - with extrapolation. - Callable[[Tensor], Tensor]: A function that computes the log determinant - Jacobian of the function. - Callable[[Tensor, Callable], Tensor]: A function that computes the inverse of - the function. - int: The order of the polynomial. + Callable[[Tensor], Tensor]: + A function that evaluates the Bernstein polynomial with extrapolation. + Callable[[Tensor], Tensor]: + A function that computes the log determinant Jacobian of the function. + Callable[[Tensor, Callable], Tensor]: + A function that computes the inverse of the function. + int + The order of the polynomial. """ theta_shape = prefer_static.shape(theta) @@ -455,14 +507,15 @@ def generate_bernstein_polynomial_with_extrapolation( theta ) - def bpoly_extra(x: Tensor) -> Tensor: + def bpoly_extra(x: tf.Tensor) -> tf.Tensor: """Bernstein polynomial with extrapolation function. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- Tensor: The evaluated Bernstein polynomial with extrapolation. @@ -473,14 +526,15 @@ def bpoly_extra(x: Tensor) -> Tensor: y = tf.where(x_safe, y, extra(x)) return reshape_output(batch_shape, sample_shape, y) - def bpoly_log_det_jacobian_extra(x: Tensor) -> Tensor: + def bpoly_log_det_jacobian_extra(x: tf.Tensor) -> tf.Tensor: """Log determinant Jacobian of the Bernstein polynomial with extrapolation. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- Tensor: The log determinant Jacobian. @@ -491,17 +545,20 @@ def bpoly_log_det_jacobian_extra(x: Tensor) -> Tensor: y = tf.where(x_safe, y, extra_log_det_jacobian(x)) return reshape_output(batch_shape, sample_shape, y) - def bpoly_inverse_extra(y: Tensor, inverse_approx_fn: Callable) -> Tensor: + def bpoly_inverse_extra(y: tf.Tensor, inverse_approx_fn: Callable) -> tf.Tensor: """Inverse of the Bernstein polynomial with extrapolation. - Args: - ---- - y (Tensor): The input tensor. - inverse_approx_fn (Callable): Function to approximate the inverse. + Parameters + ---------- + y + The input tensor. + inverse_approx_fn + Function to approximate the inverse. - Returns: + Returns ------- - Tensor: The inverse transformed tensor. + Tensor + The inverse transformed tensor. """ sample_shape = prefer_static.shape(y) @@ -514,29 +571,32 @@ def bpoly_inverse_extra(y: Tensor, inverse_approx_fn: Callable) -> Tensor: low = tf.convert_to_tensor(domain[0], theta.dtype) high = tf.convert_to_tensor(domain[1], theta.dtype) - def bpoly_extra_scaled(x: Tensor) -> Tensor: + def bpoly_extra_scaled(x: tf.Tensor) -> tf.Tensor: """Scaled Bernstein polynomial with extrapolation. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- - Tensor: The scaled evaluated Bernstein polynomial with extrapolation. + Tensor + The scaled evaluated Bernstein polynomial with extrapolation. """ x = transform_to_bernstein_domain(x, low, high) return bpoly_extra(x) - def bpoly_log_det_jacobian_extra_scaled(x: Tensor) -> Tensor: + def bpoly_log_det_jacobian_extra_scaled(x: tf.Tensor) -> tf.Tensor: """Log det. Jacobian of the scaled Bernstein polynomial with extrapolation. - Args: - ---- - x (Tensor): The input tensor. + Parameters + ---------- + x + The input tensor. - Returns: + Returns ------- Tensor: The log determinant Jacobian. @@ -545,16 +605,18 @@ def bpoly_log_det_jacobian_extra_scaled(x: Tensor) -> Tensor: return bpoly_log_det_jacobian_extra(x) - tf.math.log(high - low) def bpoly_inverse_extra_scaled( - y: Tensor, inverse_approx_fn: Callable - ) -> Tensor: + y: tf.Tensor, inverse_approx_fn: Callable + ) -> tf.Tensor: """Inverse of the scaled Bernstein polynomial with extrapolation function. - Args: - ---- - y (Tensor): The input tensor. - inverse_approx_fn (Callable): Function to approximate the inverse. + Parameters + ---------- + y + The input tensor. + inverse_approx_fn + Function to approximate the inverse. - Returns: + Returns ------- Tensor: The inverse transformed tensor. @@ -580,16 +642,17 @@ def bpoly_inverse_extra_scaled( def generate_bernstein_polynomial_with_linear_extension(*args, **kwargs): """Generate a Bernstein polynomial with linear extension. - Args: - ---- - *args: Arguments passed to `gen_bernstein_polynomial_with_extrapolation`. - **kwargs: Keyword arguments passed to - `gen_bernstein_polynomial_with_extrapolation`. + Parameters + ---------- + *args + Arguments passed to `gen_bernstein_polynomial_with_extrapolation`. + **kwargs + Keyword arguments passed to `gen_bernstein_polynomial_with_extrapolation`. - Returns: + Returns ------- - Tuple: The output of `gen_bernstein_polynomial_with_extrapolation` with - `gen_extrapolation_fn` set to `gen_linear_extension`. + Tuple: The output of `gen_bernstein_polynomial_with_extrapolation` with + `gen_extrapolation_fn` set to `gen_linear_extension`. """ return generate_bernstein_polynomial_with_extrapolation( @@ -600,16 +663,17 @@ def generate_bernstein_polynomial_with_linear_extension(*args, **kwargs): def generate_bernstein_polynomial_with_linear_extrapolation(*args, **kwargs): """Generate a Bernstein polynomial with linear extrapolation. - Args: - ---- - *args: Arguments passed to `gen_bernstein_polynomial_with_extrapolation`. - **kwargs: Keyword arguments passed to - `gen_bernstein_polynomial_with_extrapolation`. + Parameters + ---------- + *args + Arguments passed to `gen_bernstein_polynomial_with_extrapolation`. + **kwargs + Keyword arguments passed to `gen_bernstein_polynomial_with_extrapolation`. - Returns: + Returns ------- - Tuple: The output of `gen_bernstein_polynomial_with_extrapolation` with - `gen_extrapolation_fn` set to `gen_linear_extrapolation`. + Tuple: The output of `gen_bernstein_polynomial_with_extrapolation` with + `gen_extrapolation_fn` set to `gen_linear_extrapolation`. """ return generate_bernstein_polynomial_with_extrapolation( diff --git a/src/bernstein_flow/util/visualization/plot_flow.py b/src/bernstein_flow/util/visualization/plot_flow.py index 9c4fe04..689c0cd 100644 --- a/src/bernstein_flow/util/visualization/plot_flow.py +++ b/src/bernstein_flow/util/visualization/plot_flow.py @@ -4,7 +4,7 @@ # author : Marcel Arpogaus # # created : 2024-07-12 14:49:21 (Marcel Arpogaus) -# changed : 2024-07-12 14:49:21 (Marcel Arpogaus) +# changed : 2024-07-18 12:02:58 (Marcel Arpogaus) # %% License ################################################################### # Copyright 2024 Marcel Arpogaus @@ -50,9 +50,9 @@ def _get_annot_map(bijector_names: List[str], bijector_name: str) -> Dict[str, s Parameters ---------- - bijector_names : List[str] + bijector_names List of bijector names. - bijector_name : str + bijector_name Name of the bijector to split the data at. Returns @@ -84,7 +84,7 @@ def _get_formulas(bijectors: List[tfb.Bijector]) -> str: Parameters ---------- - bijectors : List[tfb.Bijector] + bijectors List of bijectors. Returns @@ -110,7 +110,7 @@ def _get_bijectors_recursive(bijector: tfb.Bijector) -> List[tfb.Bijector]: Parameters ---------- - bijector : tfb.Bijector + bijector The bijector to extract from. Returns @@ -134,7 +134,7 @@ def _get_bijectors( Parameters ---------- - flow : tfp.distributions.TransformedDistribution + flow Transformed distribution. Returns @@ -151,7 +151,7 @@ def _get_bijector_names(bijectors: List[tfb.Bijector]) -> List[str]: Parameters ---------- - bijectors : List[tfb.Bijector] + bijectors List of bijectors. Returns @@ -170,9 +170,9 @@ def _split_bijector_names( Parameters ---------- - bijector_names : List[str] + bijector_names List of bijector names. - split_bijector_name : str + split_bijector_name Name of the bijector to split the list at. Returns @@ -197,17 +197,17 @@ def _get_plot_data( Parameters ---------- - flow : tfp.distributions.TransformedDistribution + flow Transformed distribution. - bijector_name : str + bijector_name Name of the bijector to split the data at. - n : int + n Number of samples - z_values : np.ndarray + z_values Predefined sample values - seed : int + seed Random seed - ignore_bijectors : Tuple[str] + ignore_bijectors Tuple containing names of bijectors to ignore during plotting Returns @@ -263,9 +263,9 @@ def _configure_axes(a: Axes, style: str): Parameters ---------- - a : Axes + a Axes object to configure. - style : str + style Style of the axes. Can be "right", "top", or "none". """ @@ -300,17 +300,17 @@ def _prepare_figure( Parameters ---------- - plot_data : Dict[str, Dict[str, np.ndarray]] + plot_data Plot data. - pre_bpoly_trafos : List[str] + pre_bpoly_trafos Pre-split bijector names. - post_bpoly_trafos : List[str] + post_bpoly_trafos Post-split bijector names. - size : int, optional + size Figure size - wspace : float + wspace Width space between subplots - hspace : float + hspace Height space between subplots Returns @@ -377,13 +377,13 @@ def _plot_data_to_axes( Parameters ---------- - axs : Dict[str, Axes] + axs Dictionary mapping bijector names to axes. - plot_data : Dict[str, Dict[str, np.ndarray]] + plot_data Plot data. - pre_bpoly_trafos : List[str] + pre_bpoly_trafos Pre-split bijector names. - post_bpoly_trafos : List[str] + post_bpoly_trafos Post-split bijector names. """ @@ -436,30 +436,30 @@ def _add_annot_to_axes( Parameters ---------- - axs : Dict[str, Axes] + axs Dictionary mapping bijector names to axes. - plot_data : Dict[str, Dict[str, np.ndarray]] + plot_data Plot data. - pre_bpoly_trafos : List[str] + pre_bpoly_trafos Pre-split bijector names. - post_bpoly_trafos : List[str] + post_bpoly_trafos Post-split bijector names. - bijector_name : str + bijector_name Name of the bijector to split the data at. - annot_map : Dict[str, str], optional + annot_map Dictionary mapping bijector names to annotations, by default {} - extra_annot_prob : Dict[str, Tuple[Tuple[float, float], str, int]], optional + extra_annot_prob Dictionary containing extra annotations for probabilities, by default {} - extra_annot_sample : Dict[str, Tuple[Tuple[float, float], str, int]], optional + extra_annot_sample Dictionary containing extra annotations for samples, by default {} - formulas : str, optional + formulas LaTeX formulas string, by default "" - pos : float, optional + pos Position of the arrows, by default 0.5 - cp_kwds : Dict, optional + cp_kwds Keyword arguments for the ConnectionPatch, by default dict(arrowstyle="-|>", shrinkA=10, shrinkB=10, color="gray") - usetex : bool, optional + usetex Whether to use LaTeX for text rendering, by default True """ @@ -590,27 +590,27 @@ def plot_flow( Parameters ---------- - flow : tfp.distributions.TransformedDistribution + flow Transformed distribution to plot. - bijector_name : str, optional + bijector_name Name of the bijector to split the data at, by default "bernstein_bijector" - n : int, optional + n Number of samples, by default 500 - z_values : np.ndarray, optional + z_values Predefined sample values, by default None - seed : int, optional + seed Random seed, by default 1 - size : float, optional + size Figure size scaling factor, by default 1.5 - wspace : float, optional + wspace Width space between subplots, by default 0.5 - hspace : float, optional + hspace Height space between subplots, by default 0.5 - usetex : bool, optional + usetex Whether to use LaTeX for text rendering, by default True - ignore_bijectors : Tuple[str], optional + ignore_bijectors Tuple containing names of bijectors to ignore during plotting, by default () - **kwds : optional + **kwds Additional keyword arguments passed to add_annot_to_axes. Returns