Skip to content

Commit

Permalink
pr review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jul 7, 2023
1 parent d5986cf commit 0fde99b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def file_paths(self, file_paths):

msg = ('No valid files provided to DataHandler. '
f'Received file_paths={file_paths}. Aborting.')
assert len(self._file_paths) > 0 and file_paths is not None, msg
assert file_paths is not None and len(self._file_paths) > 0, msg

self._file_paths = sorted(self._file_paths)

Expand Down
38 changes: 37 additions & 1 deletion sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Loss metrics for Sup3r"""

from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
import tensorflow as tf
from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError


def gaussian_kernel(x1, x2, sigma=1.0):
Expand Down Expand Up @@ -173,6 +173,42 @@ def __call__(self, x1, x2):
return self.MSE_LOSS(x1_coarse, x2_coarse)


class SpatialExtremesLoss(tf.keras.losses.Loss):
"""Loss class that encourages accuracy of the min/max values in the
spatial domain"""

MAE_LOSS = MeanAbsoluteError()

def __call__(self, x1, x2):
"""Custom content loss that encourages temporal min/max accuracy
Parameters
----------
x1 : tf.tensor
synthetic generator output
(n_observations, spatial_1, spatial_2, features)
x2 : tf.tensor
high resolution data
(n_observations, spatial_1, spatial_2, features)
Returns
-------
tf.tensor
0D tensor with loss value
"""
x1_min = tf.reduce_min(x1, axis=(1, 2))
x2_min = tf.reduce_min(x2, axis=(1, 2))

x1_max = tf.reduce_max(x1, axis=(1, 2))
x2_max = tf.reduce_max(x2, axis=(1, 2))

mae = self.MAE_LOSS(x1, x2)
mae_min = self.MAE_LOSS(x1_min, x2_min)
mae_max = self.MAE_LOSS(x1_max, x2_max)

return mae + mae_min + mae_max


class TemporalExtremesLoss(tf.keras.losses.Loss):
"""Loss class that encourages accuracy of the min/max values in the
timeseries"""
Expand Down

0 comments on commit 0fde99b

Please sign in to comment.