Skip to content

Commit

Permalink
doc: add / update type annotations and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
MArpogaus committed Jul 12, 2024
1 parent fb12483 commit 92afafc
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 164 deletions.
2 changes: 1 addition & 1 deletion src/bernstein_flow/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""".. include:: ../../README.md"""
""".. include:: ../../README.md""" # noqa: D415,D400
65 changes: 51 additions & 14 deletions src/bernstein_flow/activations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*-
# AUTHOR INFORMATION ##########################################################
# %% Author ####################################################################
# file : __init__.py
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com>
# author : Marcel Arpogaus <[email protected]>
#
# created : 2022-03-10 15:39:04 (Marcel Arpogaus)
# changed : 2024-07-09 18:08:52 (Marcel Arpogaus)
# DESCRIPTION #################################################################
# ...
# LICENSE #####################################################################
# ...
###############################################################################
# created : 2024-07-12 15:12:18 (Marcel Arpogaus)
# changed : 2024-07-12 15:20:43 (Marcel Arpogaus)

# %% License ###################################################################
# Copyright 2024 Marcel Arpogaus
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %% Description ###############################################################
"""Activation functions applied to unconstrained outputs of neural networks."""

# %% imports ###################################################################
import tensorflow as tf
from tensorflow_probability.python.internal import (
dtype_util,
Expand All @@ -18,17 +33,39 @@
)


# %% functions #################################################################
def get_thetas_constrain_fn(
bounds=(None, None),
smooth_bounds=True,
allow_flexible_bounds=False,
bounds: tuple[float | None, float | None] = (None, None),
smooth_bounds: bool = True,
allow_flexible_bounds: bool = False,
fn=tf.math.softplus,
eps=1e-5,
eps: float = 1e-5,
):
"""Return a function that constrains the output of a neural network.
Parameters
----------
bounds
The lower and upper bounds of the output.
smooth_bounds
Whether to ensure smooth bounds by enforcing `Be''(0)==Be(1)==0`.
allow_flexible_bounds
Whether to allow the bounds to be flexible, i.e. to depend on the input.
fn
The positive definite function to apply to the unconstrained parameters.
eps
A small number to add to the output to avoid numerical issues.
Returns
-------
callable
A function that constrains the output of a neural network.
"""
low, high = bounds

# @tf.function
def constrain_fn(diff):
def constrain_fn(diff: tf.Tensor):
dtype = dtype_util.common_dtype([diff], dtype_hint=tf.float32)
diff = tensor_util.convert_nonref_to_tensor(diff, name="diff", dtype=dtype)

Expand Down
83 changes: 47 additions & 36 deletions src/bernstein_flow/bijectors/bernstein.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
# AUTHOR INFORMATION ##########################################################
# file : bernstein_bijector.py
# brief : [Description]
# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*-
# %% Author ####################################################################
# file : bernstein.py
# author : Marcel Arpogaus <[email protected]>
#
# author : Marcel Arpogaus
# created : 2020-09-11 14:14:24
# changed : 2020-12-07 16:29:11
# DESCRIPTION #################################################################
#
# This project is following the PEP8 style guide:
#
# https://www.python.org/dev/peps/pep-0008/)
#
# COPYRIGHT ###################################################################
# created : 2024-07-12 14:52:28 (Marcel Arpogaus)
# changed : 2024-07-12 14:52:28 (Marcel Arpogaus)

# %% License ###################################################################
# Copyright 2020 Marcel Arpogaus
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -26,8 +21,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# %% Description ###############################################################
"""Implement a Bernstein Polynomial as a `tfp.Bijector`."""

# REQUIRED PYTHON MODULES #####################################################
# %% imports ###################################################################
from functools import partial
from typing import Tuple

Expand All @@ -37,30 +34,46 @@

from bernstein_flow.activations import get_thetas_constrain_fn
from bernstein_flow.math.bernstein import (
gen_bernstein_polynomial_with_linear_extension,
gen_bernstein_polynomial_with_linear_extrapolation,
generate_bernstein_polynomial_with_linear_extension,
generate_bernstein_polynomial_with_linear_extrapolation,
)


# %% classes ###################################################################
class BernsteinPolynomial(tfp.experimental.bijectors.ScalarFunctionWithInferredInverse):
"""Implementation of a Bernstein polynomials Bijector."""
"""Implement a Bernstein polynomials Bijector.
This bijector transforms an input tensor by applying a Bernstein polynomial.
Attributes
----------
thetas: The Bernstein coefficients.
extrapolation: The method to extrapolate outside of bounds.
analytic_jacobian: Whether to use the analytic Jacobian.
domain: The domain of the Bernstein polynomial.
name: The name to give Ops created by the initializer.
"""

def __init__(
self,
thetas: tf.Tensor,
extrapolation: bool = True,
analytic_jacobian: bool = True,
domain: Tuple[int, ...] = None,
domain: Tuple[float, float] = (0.0, 1.0),
name: str = "bernstein_bijector",
**kwds,
**kwargs,
) -> None:
"""Constructs a new instance of a Bernstein polynomial bijector.
"""Construct a new instance of a Bernstein polynomial bijector.
:param thetas: The Bernstein coefficients.
:type thetas: tf.Tensor
:param extrapolation: Method to extrapolate outside of bounds.
:type extrapolation: str
:param name: The name to give Ops created by the initializer.
Args:
----
thetas: The Bernstein coefficients.
extrapolation: The method to extrapolate outside of bounds.
analytic_jacobian: Whether to use the analytic Jacobian.
domain: The domain of the Bernstein polynomial.
name: The name to give Ops created by the initializer.
kwargs: Keyword arguments for the parent class.
"""
with tf.name_scope(name) as name:
Expand All @@ -76,7 +89,7 @@ def __init__(
forward_log_det_jacobian,
self.b_poly_inverse_extra,
self.order,
) = gen_bernstein_polynomial_with_linear_extrapolation(
) = generate_bernstein_polynomial_with_linear_extrapolation(
self.thetas, domain=domain
)
else:
Expand All @@ -85,16 +98,12 @@ def __init__(
forward_log_det_jacobian,
self.b_poly_inverse_extra,
self.order,
) = gen_bernstein_polynomial_with_linear_extension(
) = generate_bernstein_polynomial_with_linear_extension(
self.thetas, domain=domain
)

if domain:
low = tf.convert_to_tensor(domain[0], dtype=dtype)
high = tf.convert_to_tensor(domain[1], dtype=dtype)
else:
low = tf.convert_to_tensor(0, dtype=dtype)
high = tf.convert_to_tensor(1, dtype=dtype)
low = tf.convert_to_tensor(domain[0], dtype=dtype)
high = tf.convert_to_tensor(domain[1], dtype=dtype)

if analytic_jacobian:
self._forward_log_det_jacobian = forward_log_det_jacobian
Expand Down Expand Up @@ -125,10 +134,11 @@ def root_search_fn(objective_fn, _, max_iterations=None):
max_iterations=50,
name=name,
dtype=dtype,
**kwds,
**kwargs,
)

def _inverse_no_gradient(self, y):
def _inverse_no_gradient(self, y: tf.Tensor) -> tf.Tensor:
    """Compute the inverse of the bijector without gradient flow.

    Parameters
    ----------
    y
        The tensor to invert.

    Returns
    -------
    tf.Tensor
        The inverse image of `y`, with gradients blocked.
    """
    # Delegate to the extrapolation-aware inverse, using the parent's
    # root-search-based inverse as the approximation inside the bounds.
    # tf.stop_gradient prevents gradients from propagating through the
    # iterative inverse approximation.
    return tf.stop_gradient(
        self.b_poly_inverse_extra(y, inverse_approx_fn=super()._inverse_no_gradient)
    )
Expand All @@ -141,5 +151,6 @@ def _parameter_properties(cls, dtype=None):
),
)

def _is_increasing(self, **kwargs):
def _is_increasing(self, **kwargs) -> bool:
    """Return whether the polynomial is monotonically increasing.

    The Bernstein polynomial is non-decreasing iff its coefficients
    form a non-decreasing sequence along the last axis.
    """
    thetas = self.thetas
    return tf.reduce_all(thetas[..., :-1] <= thetas[..., 1:])
Loading

0 comments on commit 92afafc

Please sign in to comment.