-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
doc: add / update type annotations and docstrings
- Loading branch information
Showing
10 changed files
with
313 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
""".. include:: ../../README.md""" | ||
""".. include:: ../../README.md""" # noqa: D415,D400 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,30 @@ | ||
# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- | ||
# AUTHOR INFORMATION ########################################################## | ||
# %% Author #################################################################### | ||
# file : __init__.py | ||
# author : Marcel Arpogaus <marcel dot arpogaus at gmail dot com> | ||
# author : Marcel Arpogaus <[email protected]> | ||
# | ||
# created : 2022-03-10 15:39:04 (Marcel Arpogaus) | ||
# changed : 2024-07-09 18:08:52 (Marcel Arpogaus) | ||
# DESCRIPTION ################################################################# | ||
# ... | ||
# LICENSE ##################################################################### | ||
# ... | ||
############################################################################### | ||
# created : 2024-07-12 15:12:18 (Marcel Arpogaus) | ||
# changed : 2024-07-12 15:20:43 (Marcel Arpogaus) | ||
|
||
# %% License ################################################################### | ||
# Copyright 2024 Marcel Arpogaus | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# %% Description ############################################################### | ||
"""Activation functions applied to unconstrained outputs of neural networks.""" | ||
|
||
# %% imports ################################################################### | ||
import tensorflow as tf | ||
from tensorflow_probability.python.internal import ( | ||
dtype_util, | ||
|
@@ -18,17 +33,39 @@ | |
) | ||
|
||
|
||
# %% functions ################################################################# | ||
def get_thetas_constrain_fn( | ||
bounds=(None, None), | ||
smooth_bounds=True, | ||
allow_flexible_bounds=False, | ||
bounds: tuple[float | None, float | None] = (None, None), | ||
smooth_bounds: bool = True, | ||
allow_flexible_bounds: bool = False, | ||
fn=tf.math.softplus, | ||
eps=1e-5, | ||
eps: float = 1e-5, | ||
): | ||
"""Return a function that constrains the output of a neural network. | ||
Parameters | ||
---------- | ||
bounds | ||
The lower and upper bounds of the output. | ||
smooth_bounds | ||
Whether to ensure smooth the bounds by enforcing `Be''(0)==Be(1)==0`. | ||
allow_flexible_bounds | ||
Whether to allow the bounds to be flexible, i.e. to depend on the input. | ||
fn | ||
The positive definite function to apply to the unconstrained parameters. | ||
eps | ||
A small number to add to the output to avoid numerical issues. | ||
Returns | ||
------- | ||
callable | ||
A function that constrains the output of a neural network. | ||
""" | ||
low, high = bounds | ||
|
||
# @tf.function | ||
def constrain_fn(diff): | ||
def constrain_fn(diff: tf.Tensor): | ||
dtype = dtype_util.common_dtype([diff], dtype_hint=tf.float32) | ||
diff = tensor_util.convert_nonref_to_tensor(diff, name="diff", dtype=dtype) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,12 @@ | ||
# AUTHOR INFORMATION ########################################################## | ||
# file : bernstein_bijector.py | ||
# brief : [Description] | ||
# -*- time-stamp-pattern: "changed[\s]+:[\s]+%%$"; -*- | ||
# %% Author #################################################################### | ||
# file : bernstein.py | ||
# author : Marcel Arpogaus <[email protected]> | ||
# | ||
# author : Marcel Arpogaus | ||
# created : 2020-09-11 14:14:24 | ||
# changed : 2020-12-07 16:29:11 | ||
# DESCRIPTION ################################################################# | ||
# | ||
# This project is following the PEP8 style guide: | ||
# | ||
# https://www.python.org/dev/peps/pep-0008/) | ||
# | ||
# COPYRIGHT ################################################################### | ||
# created : 2024-07-12 14:52:28 (Marcel Arpogaus) | ||
# changed : 2024-07-12 14:52:28 (Marcel Arpogaus) | ||
|
||
# %% License ################################################################### | ||
# Copyright 2020 Marcel Arpogaus | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
|
@@ -26,8 +21,10 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
############################################################################### | ||
# %% Description ############################################################### | ||
"""Implement a Bernstein Polynomial as a `tfp.Bijector`.""" | ||
|
||
# REQUIRED PYTHON MODULES ##################################################### | ||
# %% imports ################################################################### | ||
from functools import partial | ||
from typing import Tuple | ||
|
||
|
@@ -37,30 +34,46 @@ | |
|
||
from bernstein_flow.activations import get_thetas_constrain_fn | ||
from bernstein_flow.math.bernstein import ( | ||
gen_bernstein_polynomial_with_linear_extension, | ||
gen_bernstein_polynomial_with_linear_extrapolation, | ||
generate_bernstein_polynomial_with_linear_extension, | ||
generate_bernstein_polynomial_with_linear_extrapolation, | ||
) | ||
|
||
|
||
# %% classes ################################################################### | ||
class BernsteinPolynomial(tfp.experimental.bijectors.ScalarFunctionWithInferredInverse): | ||
"""Implementation of a Bernstein polynomials Bijector.""" | ||
"""Implement a Bernstein polynomials Bijector. | ||
This bijector transforms an input tensor by applying a Bernstein polynomial. | ||
Attributes | ||
---------- | ||
thetas: The Bernstein coefficients. | ||
extrapolation: The method to extrapolate outside of bounds. | ||
analytic_jacobian: Whether to use the analytic Jacobian. | ||
domain: The domain of the Bernstein polynomial. | ||
name: The name to give Ops created by the initializer. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
thetas: tf.Tensor, | ||
extrapolation: bool = True, | ||
analytic_jacobian: bool = True, | ||
domain: Tuple[int, ...] = None, | ||
domain: Tuple[float, float] = (0.0, 1.0), | ||
name: str = "bernstein_bijector", | ||
**kwds, | ||
**kwargs, | ||
) -> None: | ||
"""Constructs a new instance of a Bernstein polynomial bijector. | ||
"""Construct a new instance of a Bernstein polynomial bijector. | ||
:param thetas: The Bernstein coefficients. | ||
:type thetas: tf.Tensor | ||
:param extrapolation: Method to extrapolate outside of bounds. | ||
:type extrapolation: str | ||
:param name: The name to give Ops created by the initializer. | ||
Args: | ||
---- | ||
thetas: The Bernstein coefficients. | ||
extrapolation: The method to extrapolate outside of bounds. | ||
analytic_jacobian: Whether to use the analytic Jacobian. | ||
domain: The domain of the Bernstein polynomial. | ||
name: The name to give Ops created by the initializer. | ||
kwargs: Keyword arguments for the parent class. | ||
""" | ||
with tf.name_scope(name) as name: | ||
|
@@ -76,7 +89,7 @@ def __init__( | |
forward_log_det_jacobian, | ||
self.b_poly_inverse_extra, | ||
self.order, | ||
) = gen_bernstein_polynomial_with_linear_extrapolation( | ||
) = generate_bernstein_polynomial_with_linear_extrapolation( | ||
self.thetas, domain=domain | ||
) | ||
else: | ||
|
@@ -85,16 +98,12 @@ def __init__( | |
forward_log_det_jacobian, | ||
self.b_poly_inverse_extra, | ||
self.order, | ||
) = gen_bernstein_polynomial_with_linear_extension( | ||
) = generate_bernstein_polynomial_with_linear_extension( | ||
self.thetas, domain=domain | ||
) | ||
|
||
if domain: | ||
low = tf.convert_to_tensor(domain[0], dtype=dtype) | ||
high = tf.convert_to_tensor(domain[1], dtype=dtype) | ||
else: | ||
low = tf.convert_to_tensor(0, dtype=dtype) | ||
high = tf.convert_to_tensor(1, dtype=dtype) | ||
low = tf.convert_to_tensor(domain[0], dtype=dtype) | ||
high = tf.convert_to_tensor(domain[1], dtype=dtype) | ||
|
||
if analytic_jacobian: | ||
self._forward_log_det_jacobian = forward_log_det_jacobian | ||
|
@@ -125,10 +134,11 @@ def root_search_fn(objective_fn, _, max_iterations=None): | |
max_iterations=50, | ||
name=name, | ||
dtype=dtype, | ||
**kwds, | ||
**kwargs, | ||
) | ||
|
||
def _inverse_no_gradient(self, y): | ||
def _inverse_no_gradient(self, y: tf.Tensor) -> tf.Tensor: | ||
"""Compute the inverse of the bijector.""" | ||
return tf.stop_gradient( | ||
self.b_poly_inverse_extra(y, inverse_approx_fn=super()._inverse_no_gradient) | ||
) | ||
|
@@ -141,5 +151,6 @@ def _parameter_properties(cls, dtype=None): | |
), | ||
) | ||
|
||
def _is_increasing(self, **kwargs): | ||
def _is_increasing(self, **kwargs) -> bool: | ||
"""Check if the bijector is increasing.""" | ||
return tf.reduce_all(self.thetas[..., 1:] >= self.thetas[..., :-1]) |
Oops, something went wrong.