From eb3fc9c6ba1a558a4ef31355fb45c58ea28950c8 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 29 Jun 2023 09:32:13 -0600 Subject: [PATCH 1/3] misc --- sup3r/models/abstract.py | 28 +++++++++++++---------- sup3r/models/base.py | 9 ++++---- sup3r/pipeline/forward_pass.py | 2 ++ sup3r/preprocessing/data_handling.py | 33 ++++++++++++++++++---------- tests/training/test_train_gan.py | 27 +++++++++++------------ 5 files changed, 58 insertions(+), 41 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 4106336d26..a5da6dfa9a 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -2,23 +2,25 @@ """ Abstract class to define the required interface for Sup3r model subclasses """ -from abc import ABC, abstractmethod +import json +import logging import os +import pprint import time -import json +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from inspect import signature +from warnings import warn + +import numpy as np +import tensorflow as tf from phygnn import CustomNetwork from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat from rex.utilities.utilities import safe_json_load -import tensorflow as tf from tensorflow.keras import optimizers -import numpy as np -import logging -import pprint -from warnings import warn -from sup3r.utilities import VERSION_RECORD import sup3r.utilities.loss_metrics +from sup3r.utilities import VERSION_RECORD logger = logging.getLogger(__name__) @@ -491,7 +493,10 @@ def init_optimizer(optimizer, learning_rate): if isinstance(optimizer, dict): class_name = optimizer['name'] OptimizerClass = getattr(optimizers, class_name) - optimizer = OptimizerClass.from_config(optimizer) + sig = signature(OptimizerClass) + optimizer_kwargs = {k: v for k, v in optimizer.items() + if k in sig.parameters} + optimizer = OptimizerClass.from_config(optimizer_kwargs) elif optimizer is None: optimizer = optimizers.Adam(learning_rate=learning_rate) @@ -518,7 +523,7 @@ def load_saved_params(out_dir, verbose=True): """ fp_params = os.path.join(out_dir, 'model_params.json') - with open(fp_params, 'r') as f: + with open(fp_params) as f: params = json.load(f) # using the saved model dir makes this more portable @@ -705,7 +710,7 @@ def early_stop(history, column, threshold=0.005, n_epoch=5): return stop @abstractmethod - def save(self, low_res): + def save(self, out_dir): """Save the model with its sub-networks to a directory. Parameters @@ -922,6 +927,7 @@ class AbstractWindInterface(ABC): Abstract class to define the required training interface for Sup3r wind model subclasses """ + # pylint: disable=E0211 @staticmethod def set_model_params(**kwargs): diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 6fdb4231d0..7f7220cc75 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -1,20 +1,20 @@ # -*- coding: utf-8 -*- """Sup3r model software""" import copy +import logging import os +import pprint import time -import logging +from warnings import warn + import numpy as np -import pprint import pandas as pd import tensorflow as tf from tensorflow.keras import optimizers -from warnings import warn from sup3r.models.abstract import AbstractInterface, AbstractSingleModel from sup3r.utilities import VERSION_RECORD - logger = logging.getLogger(__name__) @@ -240,7 +240,6 @@ def _tf_generate(self, low_res): hi_res : tf.Tensor Synthetically generated high-resolution data """ - hi_res = self.generator.layers[0](low_res) for i, layer in enumerate(self.generator.layers[1:]): try: diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index c030d03363..d09b024296 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -429,6 +429,8 @@ def get_hr_slices(slices, enhancement, step=None): Low resolution slices to be enhanced enhancement : int Enhancement factor + step : int | None + Step size for slices Returns ------- diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index d40972c53b..ac5b20f90a 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -264,11 +264,15 @@ def file_paths(self, file_paths): """ self._file_paths = file_paths if isinstance(self._file_paths, str): - if '*' in self._file_paths: + if '*' in file_paths: self._file_paths = glob.glob(self._file_paths) else: self._file_paths = [self._file_paths] + msg = ('No valid files provided to DataHandler. ' + f'Received file_paths={file_paths}. Aborting.') + assert len(self._file_paths) > 0 and file_paths is not None, msg + self._file_paths = sorted(self._file_paths) @property @@ -376,6 +380,16 @@ def lat_lon(self, lat_lon): """Update lat lon""" self._lat_lon = lat_lon + @property + def latitude(self): + """Return latitude array""" + return self.lat_lon[..., 0] + + @property + def longitude(self): + """Return longitude array""" + return self.lat_lon[..., 1] + @property def invert_lat(self): """Whether to invert the latitude axis during data extraction. This is @@ -718,9 +732,6 @@ def __init__(self, file_paths, features, target=None, shape=None, raster_index=raster_index, temporal_slice=temporal_slice) - msg = 'No files provided to DataHandler. Aborting.' - assert file_paths is not None and bool(file_paths), msg - self.file_paths = file_paths self.features = (features if isinstance(features, (list, tuple)) else [features]) @@ -1366,7 +1377,7 @@ def get_observation_index(self): spatial_slice = uniform_box_sampler(self.data, self.sample_shape[:2]) temporal_slice = uniform_time_sampler(self.data, self.sample_shape[2]) return tuple( - [*spatial_slice, temporal_slice] + [np.arange(len(self.features))]) + [*spatial_slice, temporal_slice, np.arange(len(self.features))]) def get_next(self): """Get data for observation using random observation index. Loops @@ -2182,7 +2193,7 @@ def direct_extract(cls, handle, feature, raster_index, time_slice): # Sometimes xarray returns fields with (Times, time, lats, lons) # with a single entry in the 'time' dimension so we include this [0] if len(handle[feature].dims) == 4: - idx = tuple([time_slice] + [0] + raster_index) + idx = tuple([time_slice, 0, *raster_index]) elif len(handle[feature].dims) == 3: idx = tuple([time_slice, *raster_index]) else: @@ -2898,11 +2909,11 @@ def get_observation_index(self): t_slice_hourly = slice(t_slice_0.start, t_slice_1.stop) t_slice_daily = slice(rand_day_ind, rand_day_ind + n_days) - obs_ind_hourly = tuple([*spatial_slice, t_slice_hourly] - + [np.arange(len(self.features))]) + obs_ind_hourly = tuple([*spatial_slice, t_slice_hourly, + np.arange(len(self.features))]) - obs_ind_daily = tuple([*spatial_slice, t_slice_daily] - + [np.arange(len(self.features))]) + obs_ind_daily = tuple([*spatial_slice, t_slice_daily, + np.arange(len(self.features))]) return obs_ind_hourly, obs_ind_daily @@ -3121,7 +3132,7 @@ def get_observation_index(self, temporal_weights=None, self.sample_shape[2]) return tuple( - [*spatial_slice, temporal_slice] + [np.arange(len(self.features))]) + [*spatial_slice, temporal_slice, np.arange(len(self.features))]) def get_next(self, temporal_weights=None, spatial_weights=None): """Get data for observation using weighted random observation index. diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 0eae292f73..885ad18032 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -1,28 +1,27 @@ # -*- coding: utf-8 -*- """Test the basic training of super resolution GAN""" -import os import json +import os +import tempfile + import numpy as np import pytest -import tempfile import tensorflow as tf -from tensorflow.python.framework.errors_impl import InvalidArgumentError - from rex import init_logger +from tensorflow.python.framework.errors_impl import InvalidArgumentError -from sup3r import TEST_DATA_DIR -from sup3r import CONFIG_DIR +from sup3r import CONFIG_DIR, TEST_DATA_DIR from sup3r.models import Sup3rGan from sup3r.models.data_centric import Sup3rGanDC, Sup3rGanSpatialDC -from sup3r.preprocessing.data_handling import (DataHandlerH5, - DataHandlerDCforH5) -from sup3r.preprocessing.batch_handling import (BatchHandler, - BatchHandlerDC, - SpatialBatchHandler, - BatchHandlerSpatialDC) +from sup3r.preprocessing.batch_handling import ( + BatchHandler, + BatchHandlerDC, + BatchHandlerSpatialDC, + SpatialBatchHandler, +) +from sup3r.preprocessing.data_handling import DataHandlerDCforH5, DataHandlerH5 from sup3r.utilities.loss_metrics import MmdMseLoss - FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) FEATURES = ['U_100m', 'V_100m'] @@ -295,7 +294,7 @@ def test_train_st(n_epoch=2, log=False): model.save(out_dir) loaded = model.load(out_dir) - with open(os.path.join(out_dir, 'model_params.json'), 'r') as f: + with open(os.path.join(out_dir, 'model_params.json')) as f: model_params = json.load(f) assert np.allclose(model_params['optimizer']['learning_rate'], 5e-5) From d5986cff37300cd51ea15dfc6a9396cec6449609 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 5 Jul 2023 12:51:58 -0600 Subject: [PATCH 2/3] edits for multi step spatial only model --- sup3r/pipeline/forward_pass.py | 18 +++-- sup3r/preprocessing/data_handling.py | 4 +- tests/forward_pass/test_forward_pass.py | 101 +++++++++++++++++++++--- 3 files changed, 104 insertions(+), 19 deletions(-) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index d09b024296..1467a0b484 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -1589,16 +1589,20 @@ def _reshape_data_chunk(model, data_chunk, exo_data): i_lr_s : int Axis index for the low-resolution spatial_1 dimension """ - + current_model = None if exo_data is not None: for i, arr in enumerate(exo_data): if arr is not None: - current_model = (model if not hasattr(model, 'models') - else model.models[i]) - if current_model.input_dims == 4: - exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3)) - else: - exo_data[i] = np.expand_dims(arr, axis=0) + if not hasattr(model, 'models'): + current_model = model + elif i < len(model.models): + current_model = model.models[i] + + if current_model is not None: + if current_model.input_dims == 4: + exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3)) + else: + exo_data[i] = np.expand_dims(arr, axis=0) if model.input_dims == 4: i_lr_t = 0 diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index ac5b20f90a..b8ea75ea44 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -549,7 +549,7 @@ def time_freq_hours(self): """Get the time frequency in hours as a float""" ti_deltas = self.raw_time_index - np.roll(self.raw_time_index, 1) ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600 - time_freq = float(mode(ti_deltas_hours).mode[0]) + time_freq = float(mode(ti_deltas_hours).mode) return time_freq @property @@ -2546,7 +2546,7 @@ def get_clearsky_ghi(self): ti_deltas = ti_nsrdb - np.roll(ti_nsrdb, 1) ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600 - time_freq = float(mode(ti_deltas_hours).mode[0]) + time_freq = float(mode(ti_deltas_hours).mode) t_start = self.temporal_slice.start or 0 t_end_target = self.temporal_slice.stop or len(self.raw_time_index) t_start = int(t_start * 24 * (1 / time_freq)) diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index efb544d08a..e712206c07 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -2,23 +2,21 @@ """pytests for data handling""" import json import os -import pytest import tempfile -import tensorflow as tf + +import matplotlib.pyplot as plt import numpy as np +import pytest +import tensorflow as tf import xarray as xr -import matplotlib.pyplot as plt +from rex import ResourceX, init_logger -from sup3r import TEST_DATA_DIR, CONFIG_DIR, __version__ +from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__ +from sup3r.models import LinearInterp, Sup3rGan, WindGan +from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC -from sup3r.pipeline.forward_pass import (ForwardPass, ForwardPassStrategy) -from sup3r.models import Sup3rGan, WindGan, LinearInterp from sup3r.utilities.pytest import make_fake_nc_files -from rex import ResourceX -from rex import init_logger - - FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) FEATURES = ['U_100m', 'V_100m', 'BVF2_200m'] @@ -562,6 +560,89 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): 'topography'] +def test_fwp_multi_step_spatial_model_topo_noskip(): + """Test the forward pass with a multi step spatial only model class using + exogenous data for all model steps""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 1] + s_enhance = np.product(s_enhancements) + + exo_kwargs = {'file_paths': input_files, + 'features': ['topography'], + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [12, 4, 2] + } + + model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]} + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == ( + len(input_files), + s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs for f in ('windspeed_100m', + 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 2 # two step model + assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', + 'topography'] + + def test_fwp_multi_step_model_topo_noskip(): """Test the forward pass with a multi step model class using exogenous data for all model steps""" From 0fde99b30dca61faf5a366e842b4af24eda22601 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 7 Jul 2023 10:14:40 -0600 Subject: [PATCH 3/3] pr review changes --- sup3r/preprocessing/data_handling.py | 2 +- sup3r/utilities/loss_metrics.py | 38 +++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index b8ea75ea44..9245dc8cc5 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -271,7 +271,7 @@ def file_paths(self, file_paths): msg = ('No valid files provided to DataHandler. ' f'Received file_paths={file_paths}. Aborting.') - assert len(self._file_paths) > 0 and file_paths is not None, msg + assert file_paths is not None and len(self._file_paths) > 0, msg self._file_paths = sorted(self._file_paths) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 312397f8c5..64a0680ed7 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -1,7 +1,7 @@ """Loss metrics for Sup3r""" -from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError import tensorflow as tf +from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError def gaussian_kernel(x1, x2, sigma=1.0): @@ -173,6 +173,42 @@ def __call__(self, x1, x2): return self.MSE_LOSS(x1_coarse, x2_coarse) +class SpatialExtremesLoss(tf.keras.losses.Loss): + """Loss class that encourages accuracy of the min/max values in the + spatial domain""" + + MAE_LOSS = MeanAbsoluteError() + + def __call__(self, x1, x2): + """Custom content loss that encourages temporal min/max accuracy + + Parameters + ---------- + x1 : tf.tensor + synthetic generator output + (n_observations, spatial_1, spatial_2, features) + x2 : tf.tensor + high resolution data + (n_observations, spatial_1, spatial_2, features) + + Returns + ------- + tf.tensor + 0D tensor with loss value + """ + x1_min = tf.reduce_min(x1, axis=(1, 2)) + x2_min = tf.reduce_min(x2, axis=(1, 2)) + + x1_max = tf.reduce_max(x1, axis=(1, 2)) + x2_max = tf.reduce_max(x2, axis=(1, 2)) + + mae = self.MAE_LOSS(x1, x2) + mae_min = self.MAE_LOSS(x1_min, x2_min) + mae_max = self.MAE_LOSS(x1_max, x2_max) + + return mae + mae_min + mae_max + + class TemporalExtremesLoss(tf.keras.losses.Loss): """Loss class that encourages accuracy of the min/max values in the timeseries"""