From 0fde99b30dca61faf5a366e842b4af24eda22601 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 7 Jul 2023 10:14:40 -0600 Subject: [PATCH] pr review changes --- sup3r/preprocessing/data_handling.py | 2 +- sup3r/utilities/loss_metrics.py | 38 +++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index b8ea75ea44..9245dc8cc5 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -271,7 +271,7 @@ def file_paths(self, file_paths): msg = ('No valid files provided to DataHandler. ' f'Received file_paths={file_paths}. Aborting.') - assert len(self._file_paths) > 0 and file_paths is not None, msg + assert file_paths is not None and len(self._file_paths) > 0, msg self._file_paths = sorted(self._file_paths) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 312397f8c5..64a0680ed7 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -1,7 +1,7 @@ """Loss metrics for Sup3r""" -from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError import tensorflow as tf +from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError def gaussian_kernel(x1, x2, sigma=1.0): @@ -173,6 +173,42 @@ def __call__(self, x1, x2): return self.MSE_LOSS(x1_coarse, x2_coarse) +class SpatialExtremesLoss(tf.keras.losses.Loss): + """Loss class that encourages accuracy of the min/max values in the + spatial domain""" + + MAE_LOSS = MeanAbsoluteError() + + def __call__(self, x1, x2): + """Custom content loss that encourages temporal min/max accuracy + + Parameters + ---------- + x1 : tf.tensor + synthetic generator output + (n_observations, spatial_1, spatial_2, features) + x2 : tf.tensor + high resolution data + (n_observations, spatial_1, spatial_2, features) + + Returns + ------- + tf.tensor + 0D tensor with loss value + """ + x1_min = tf.reduce_min(x1, axis=(1, 2)) + x2_min = tf.reduce_min(x2, axis=(1, 2)) + + x1_max = tf.reduce_max(x1, axis=(1, 2)) + x2_max = tf.reduce_max(x2, axis=(1, 2)) + + mae = self.MAE_LOSS(x1, x2) + mae_min = self.MAE_LOSS(x1_min, x2_min) + mae_max = self.MAE_LOSS(x1_max, x2_max) + + return mae + mae_min + mae_max + + class TemporalExtremesLoss(tf.keras.losses.Loss): """Loss class that encourages accuracy of the min/max values in the timeseries"""