Commit

trim exo_handler_kwargs for features in model.lr_features or model.hr_exo_features
bnb32 committed Sep 21, 2024
1 parent f4619a4 commit f9f9ac5
Showing 3 changed files with 25 additions and 3 deletions.
12 changes: 12 additions & 0 deletions sup3r/models/utilities.py
@@ -6,6 +6,7 @@

import numpy as np
from scipy.interpolate import RegularGridInterpolator
from tensorflow.keras import optimizers

logger = logging.getLogger(__name__)

@@ -56,6 +57,17 @@ def run(self):
model_thread.join()


def get_optimizer_class(conf):
    """Get optimizer class from keras"""
    if hasattr(optimizers, conf['name']):
        optimizer_class = getattr(optimizers, conf['name'])
    else:
        msg = '%s not found in keras optimizers.'
        logger.error(msg, conf['name'])
        raise ValueError(msg % conf['name'])
    return optimizer_class


def st_interp(low, s_enhance, t_enhance, t_centered=False):
"""Spatiotemporal bilinear interpolation for low resolution field on a
regular grid. Used to provide baseline for comparison with gan output
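For reference, a minimal sketch of how the new get_optimizer_class helper might be used to build an optimizer from a config dict; the 'Adam' name and learning rate below are illustrative values, not part of this commit:

from sup3r.models.utilities import get_optimizer_class

# hypothetical optimizer config; only the 'name' key is used by the helper,
# which looks the class up on tensorflow.keras.optimizers
conf = {'name': 'Adam', 'learning_rate': 1e-4}

optimizer_class = get_optimizer_class(conf)
optimizer = optimizer_class(learning_rate=conf['learning_rate'])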
9 changes: 7 additions & 2 deletions sup3r/pipeline/strategy.py
@@ -275,7 +275,9 @@ def init_input_handler(self):
def _init_features(self, model):
    """Initialize feature attributes."""
    self.exo_handler_kwargs = self.exo_handler_kwargs or {}
-   exo_features = list(self.exo_handler_kwargs)
+   possible_exo_feats = set(model.hr_exo_features + model.lr_features)
+   exo_kwargs_feats = list(self.exo_handler_kwargs)
+   exo_features = list(possible_exo_feats.intersection(exo_kwargs_feats))
    features = [f for f in model.lr_features if f not in exo_features]
    return features, exo_features
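A small illustration of the trimming behavior described in the commit title, with made-up feature names standing in for a real model's lr_features and hr_exo_features:

# hypothetical model feature lists and user-supplied exo handler kwargs
lr_features = ['u_10m', 'v_10m', 'topography']
hr_exo_features = ['topography']
exo_handler_kwargs = {'topography': {'source_file': '...'},
                      'sza': {'source_file': '...'}}

possible_exo_feats = set(hr_exo_features + lr_features)
exo_features = list(possible_exo_feats.intersection(exo_handler_kwargs))
# 'sza' is dropped because it is not in lr_features or hr_exo_features
assert exo_features == ['topography']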

@@ -334,7 +336,7 @@ def preflight(self):
self.lr_slices, self.lr_pad_slices, self.hr_slices = out

non_masked = self.fwp_slicer.n_spatial_chunks - sum(self.fwp_mask)
-non_masked *= self.fwp_slicer.n_time_chunks
+non_masked *= int(self.fwp_slicer.n_time_chunks)
log_dict = {
    'n_nodes': len(self.node_chunks),
    'n_spatial_chunks': self.fwp_slicer.n_spatial_chunks,
@@ -455,6 +457,9 @@ def prep_chunk_data(self, chunk_index=0):
kwargs = dict(zip(Dimension.dims_2d(), lr_pad_slice))
kwargs[Dimension.TIME] = ti_pad_slice
input_data = self.input_handler.isel(**kwargs)
logger.info(
    'Loading data for chunk_index=%s into memory.', chunk_index
)
input_data.load()

if self.bias_correct_kwargs != {}:
7 changes: 6 additions & 1 deletion sup3r/utilities/utilities.py
@@ -39,6 +39,11 @@ def merge_datasets(files, **kwargs):
    if 'longitude' in dset.dims:
        dset = dset.swap_dims({'longitude': 'west_east'})
    dsets[i] = dset
    # temporary to handle downloaded era files
    if 'expver' in dsets[i]:
        dsets[i] = dsets[i].drop_vars('expver')
    if 'number' in dsets[i]:
        dsets[i] = dsets[i].drop_vars('number')
out = xr.merge(dsets, **get_class_kwargs(xr.merge, kwargs))
msg = ('Merged time index does not have the same number of time steps '
'(%s) as the sum of the individual time index steps (%s).')
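As a reminder of why the dropped variables need to be reassigned, a tiny self-contained example with a made-up dataset standing in for a downloaded ERA file: xarray's Dataset.drop_vars returns a new dataset rather than modifying in place.

import numpy as np
import xarray as xr

# minimal stand-in for a downloaded ERA file with an 'expver' coordinate
ds = xr.Dataset(
    {'t2m': (('time',), np.zeros(3))},
    coords={'time': np.arange(3), 'expver': 1},
)

ds.drop_vars('expver')       # no effect: the returned dataset is discarded
assert 'expver' in ds

ds = ds.drop_vars('expver')  # reassign to actually remove the coordinate
assert 'expver' not in ds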
@@ -50,7 +55,7 @@

def xr_open_mfdataset(files, **kwargs):
    """Wrapper for xr.open_mfdataset with default opening options."""
-   default_kwargs = {'engine': 'netcdf4'}
+   default_kwargs = {'engine': 'netcdf4', 'chunks': 'auto'}
    default_kwargs.update(kwargs)
    try:
        return xr.open_mfdataset(files, **default_kwargs)
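The update order means caller-supplied options still override the new defaults; a quick sketch of that merge behavior (the override value is just an example). With chunks='auto', dask picks the chunk sizes, so the opened data stays lazily loaded.

# defaults mirror the updated wrapper; user kwargs override them
default_kwargs = {'engine': 'netcdf4', 'chunks': 'auto'}
user_kwargs = {'chunks': {'time': 24}}

default_kwargs.update(user_kwargs)
assert default_kwargs == {'engine': 'netcdf4', 'chunks': {'time': 24}}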
