Skip to content

Commit

Permalink
Merge branch 'bnb/dh_refactor' into gb/presrat_updates_2
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Sep 12, 2024
2 parents a5f42b8 + c21ed3c commit ae2324b
Show file tree
Hide file tree
Showing 15 changed files with 526 additions and 232 deletions.
12 changes: 6 additions & 6 deletions sup3r/bias/bias_calc_vortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import dask
import numpy as np
import pandas as pd
import xarray as xr
from rex import Resource
from scipy.interpolate import interp1d

from sup3r.postprocessing import OutputHandler, RexOutputs
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import xr_open_mfdataset

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,7 +114,7 @@ def convert_month_height_tif(self, month, height):
os.remove(outfile)

if not os.path.exists(outfile) or self.overwrite:
ds = xr_open_mfdataset(infile)
ds = xr.open_mfdataset(infile)
ds = ds.rename(
{
'band_data': f'windspeed_{height}m',
Expand Down Expand Up @@ -142,7 +142,7 @@ def convert_all_tifs(self):
def mask(self):
"""Mask coordinates without data"""
if self._mask is None:
with xr_open_mfdataset(self.get_height_files('January')) as res:
with xr.open_mfdataset(self.get_height_files('January')) as res:
mask = (res[self.in_features[0]] != -999) & (
~np.isnan(res[self.in_features[0]])
)
Expand Down Expand Up @@ -173,13 +173,13 @@ def get_month(self, month):

if os.path.exists(month_file) and not self.overwrite:
logger.info(f'Loading month_file {month_file}.')
data = xr_open_mfdataset(month_file)
data = xr.open_mfdataset(month_file)
else:
logger.info(
'Getting mean windspeed for all heights '
f'({self.in_heights}) for {month}'
)
data = xr_open_mfdataset(self.get_height_files(month))
data = xr.open_mfdataset(self.get_height_files(month))
logger.info(
'Interpolating windspeed for all heights '
f'({self.out_heights}) for {month}.'
Expand Down Expand Up @@ -239,7 +239,7 @@ def interp(self, data):

def get_lat_lon(self):
"""Get lat lon grid"""
with xr_open_mfdataset(self.get_height_files('January')) as res:
with xr.open_mfdataset(self.get_height_files('January')) as res:
lons, lats = np.meshgrid(
res['longitude'].values, res['latitude'].values
)
Expand Down
33 changes: 28 additions & 5 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,8 @@ def save(self, out_dir):
logger.info('Saved GAN to disk in directory: {}'.format(out_dir))

@classmethod
def load(cls, model_dir, verbose=True):
"""Load the GAN with its sub-networks from a previously saved-to output
directory.
def _load(cls, model_dir, verbose=True):
"""Get gen, disc, and params for given model_dir.
Parameters
----------
Expand All @@ -166,8 +165,12 @@ def load(cls, model_dir, verbose=True):
Returns
-------
out : BaseModel
Returns a pretrained gan model that was previously saved to out_dir
fp_gen : str
Path to generator model
fp_disc : str
Path to discriminator model
params : dict
Dictionary of model params to be used in model initialization
"""
if verbose:
logger.info(
Expand All @@ -182,6 +185,26 @@ def load(cls, model_dir, verbose=True):
fp_disc = os.path.join(model_dir, 'model_disc.pkl')
params = cls.load_saved_params(model_dir, verbose=verbose)

return fp_gen, fp_disc, params

@classmethod
def load(cls, model_dir, verbose=True):
    """Build a model instance from a directory written by a previous save.

    The generator path, discriminator path, and saved parameter dictionary
    are resolved by :meth:`_load` and handed directly to the class
    constructor.

    Parameters
    ----------
    model_dir : str
        Directory to load GAN model files from.
    verbose : bool
        Flag to log information about the loaded model.

    Returns
    -------
    out : BaseModel
        Returns a pretrained gan model that was previously saved to
        model_dir
    """
    gen_path, disc_path, init_params = cls._load(model_dir, verbose=verbose)
    return cls(gen_path, disc_path, **init_params)

@property
Expand Down
30 changes: 22 additions & 8 deletions sup3r/models/multi_step.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Sup3r multi step model frameworks"""
"""Sup3r multi step model frameworks
TODO: SolarMultiStepGan can be cleaned up a little with the output padding and
t_enhance argument moved to SolarCC.
"""

import json
import logging
Expand Down Expand Up @@ -35,7 +39,7 @@ def __len__(self):
return len(self._models)

@classmethod
def load(cls, model_dirs, verbose=True):
def load(cls, model_dirs, model_kwargs=None, verbose=True):
"""Load the GANs with its sub-networks from a previously saved-to
output directory.
Expand All @@ -44,6 +48,9 @@ def load(cls, model_dirs, verbose=True):
model_dirs : list | tuple
An ordered list/tuple of one or more directories containing trained
+ saved Sup3rGan models created using the Sup3rGan.save() method.
model_kwargs : list | tuple
An ordered list/tuple of one or more dictionaries containing kwargs
for the corresponding model in model_dirs
verbose : bool
Flag to log information about the loaded model.
Expand All @@ -55,11 +62,14 @@ def load(cls, model_dirs, verbose=True):
"""

models = []

if isinstance(model_dirs, str):
model_dirs = [model_dirs]

for model_dir in model_dirs:
model_kwargs = model_kwargs or [{}] * len(model_dirs)
if isinstance(model_kwargs, dict):
model_kwargs = [model_kwargs]

for model_dir, kwargs in zip(model_dirs, model_kwargs):
fp_params = os.path.join(model_dir, 'model_params.json')
assert os.path.exists(fp_params), f'Could not find: {fp_params}'
with open(fp_params) as f:
Expand All @@ -68,7 +78,9 @@ def load(cls, model_dirs, verbose=True):
meta = params.get('meta', {'class': 'Sup3rGan'})
class_name = meta.get('class', 'Sup3rGan')
Sup3rClass = getattr(sup3r.models, class_name)
models.append(Sup3rClass.load(model_dir, verbose=verbose))
models.append(
Sup3rClass.load(model_dir, verbose=verbose, **kwargs)
)

return cls(models)

Expand Down Expand Up @@ -841,9 +853,11 @@ def load(
spatial_solar_models and the spatial_wind_models.
t_enhance : int | None
Optional argument to fix or update the temporal enhancement of the
model. This can be used with temporal_pad to manipulate the output
shape to match whatever padded shape the sup3r forward pass module
expects.
model. This can be used to manipulate the output shape to match
whatever padded shape the sup3r forward pass module expects. If
this differs from the t_enhance value based on model layers the
output will be padded so that the output shape matches low_res *
t_enhance for the time dimension.
verbose : bool
Flag to log information about the loaded model.
Expand Down
155 changes: 135 additions & 20 deletions sup3r/models/solar_cc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Sup3r model software"""

import logging

import numpy as np
import tensorflow as tf

from sup3r.models.base import Sup3rGan
Expand All @@ -20,6 +22,8 @@ class SolarCC(Sup3rGan):
daily true high res sample.
- Discriminator sees random n_days of 8-hour samples of the daily
synthetic high res sample.
- Includes padding on high resolution output of :meth:`generate` so
that forward pass always outputs a multiple of 24 hours.
"""

# starting hour is the hour that daylight starts at, daylight hours is the
Expand All @@ -34,6 +38,27 @@ class SolarCC(Sup3rGan):
DAYLIGHT_HOURS = 8
STRIDE_LEN = 4

def __init__(self, *args, t_enhance=None, **kwargs):
    """Initialize the parent model and record the temporal enhancement.

    Parameters
    ----------
    *args : list
        Positional arguments forwarded unchanged to the parent class.
    t_enhance : int | None
        Optional override for the temporal enhancement of the model. It is
        stored as ``self._t_enhance`` and written to
        ``self.meta['t_enhance']``. Any falsy value (``None`` or ``0``)
        falls back to ``self.t_enhance`` as set up by the parent class
        (presumably derived from the model layers -- confirm in parent).
        :meth:`temporal_pad` uses this value to compute the target length
        of the time dimension (low_res time steps * t_enhance).
    **kwargs : Mappable
        Keyword arguments forwarded unchanged to the parent class.
    """
    super().__init__(*args, **kwargs)

    # The parent must be initialized first: the fallback reads
    # self.t_enhance and the result is written into self.meta.
    effective_t_enhance = t_enhance if t_enhance else self.t_enhance
    self._t_enhance = effective_t_enhance
    self.meta['t_enhance'] = self._t_enhance

def init_weights(self, lr_shape, hr_shape, device=None):
"""Initialize the generator and discriminator weights with device
placement.
Expand Down Expand Up @@ -61,8 +86,14 @@ def init_weights(self, lr_shape, hr_shape, device=None):
super().init_weights(lr_shape, hr_shape, device=device)

@tf.function
def calc_loss(self, hi_res_true, hi_res_gen, weight_gen_advers=0.001,
train_gen=True, train_disc=False):
def calc_loss(
self,
hi_res_true,
hi_res_gen,
weight_gen_advers=0.001,
train_gen=True,
train_disc=False,
):
"""Calculate the GAN loss function using generated and true high
resolution data.
Expand Down Expand Up @@ -91,33 +122,43 @@ def calc_loss(self, hi_res_true, hi_res_gen, weight_gen_advers=0.001,
"""

if hi_res_gen.shape != hi_res_true.shape:
msg = ('The tensor shapes of the synthetic output {} and '
'true high res {} did not have matching shape! '
'Check the spatiotemporal enhancement multipliers in your '
'your model config and data handlers.'
.format(hi_res_gen.shape, hi_res_true.shape))
msg = (
'The tensor shapes of the synthetic output {} and '
'true high res {} did not have matching shape! '
'Check the spatiotemporal enhancement multipliers in your '
'your model config and data handlers.'.format(
hi_res_gen.shape, hi_res_true.shape
)
)
logger.error(msg)
raise RuntimeError(msg)

msg = ('Special SolarCC model can only accept multi-day hourly '
'(multiple of 24) true / synthetic high res data in the axis=3 '
'position but received shape {}'.format(hi_res_true.shape))
msg = (
'Special SolarCC model can only accept multi-day hourly '
'(multiple of 24) true / synthetic high res data in the axis=3 '
'position but received shape {}'.format(hi_res_true.shape)
)
assert hi_res_true.shape[3] % 24 == 0

t_len = hi_res_true.shape[3]
n_days = int(t_len // 24)
day_slices = [slice(self.STARTING_HOUR + x,
self.STARTING_HOUR + x + self.DAYLIGHT_HOURS)
for x in range(0, 24 * n_days, 24)]
day_slices = [
slice(
self.STARTING_HOUR + x,
self.STARTING_HOUR + x + self.DAYLIGHT_HOURS,
)
for x in range(0, 24 * n_days, 24)
]

# sample only daylight hours for disc training and gen content loss
disc_out_true = []
disc_out_gen = []
loss_gen_content = 0.0
for tslice in day_slices:
disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice, :])
gen_c = self.calc_loss_gen_content(hi_res_true[:, :, :, tslice, :],
hi_res_gen[:, :, :, tslice, :])
gen_c = self.calc_loss_gen_content(
hi_res_true[:, :, :, tslice, :], hi_res_gen[:, :, :, tslice, :]
)
disc_out_true.append(disc_t)
loss_gen_content += gen_c

Expand Down Expand Up @@ -146,10 +187,84 @@ def calc_loss(self, hi_res_true, hi_res_gen, weight_gen_advers=0.001,
elif train_disc:
loss = loss_disc

loss_details = {'loss_gen': loss_gen,
'loss_gen_content': loss_gen_content,
'loss_gen_advers': loss_gen_advers,
'loss_disc': loss_disc,
}
loss_details = {
'loss_gen': loss_gen,
'loss_gen_content': loss_gen_content,
'loss_gen_advers': loss_gen_advers,
'loss_disc': loss_disc,
}

return loss, loss_details

def temporal_pad(self, low_res, hi_res, mode='reflect'):
    """Pad the time axis of the 5D generated output to the expected length.

    The expected number of output time steps is
    ``low_res.shape[-2] * self._t_enhance``. Any missing time steps are
    added with ``np.pad``, split as evenly as possible between the start
    and the end of the time axis. When the number of missing steps is odd
    the extra step is added at the end, so the padded output always has
    exactly the expected length.

    Parameters
    ----------
    low_res : np.ndarray
        Low-resolution input data to the spatio(temporal) GAN, which is a
        5D array of shape: (1, spatial_1, spatial_2, n_temporal,
        n_features).
    hi_res : ndarray
        Synthetically generated high-resolution data output from the
        (spatio)temporal GAN with a 5D array shape:
        (1, spatial_1, spatial_2, n_temporal, n_features)
    mode : str
        Padding mode for np.pad()

    Returns
    -------
    hi_res : ndarray
        Synthetically generated high-resolution data output from the
        (spatio)temporal GAN with a 5D array shape:
        (1, spatial_1, spatial_2, n_temporal, n_features)
        where n_temporal equals ``low_res.shape[-2] * self._t_enhance``.

    Raises
    ------
    ValueError
        If ``hi_res`` already has more time steps than
        ``low_res.shape[-2] * self._t_enhance``; padding cannot shrink the
        output.
    """
    t_target = low_res.shape[-2] * self._t_enhance
    t_missing = t_target - hi_res.shape[-2]

    if t_missing < 0:
        msg = (
            'Cannot pad hi_res output with {} time steps to the smaller '
            'target of {} time steps (low_res time steps {} * t_enhance '
            '{}).'.format(
                hi_res.shape[-2],
                t_target,
                low_res.shape[-2],
                self._t_enhance,
            )
        )
        logger.error(msg)
        raise ValueError(msg)

    # Split the padding over both ends. Using the remainder for the end
    # (instead of the same truncated half on both sides) keeps the total
    # equal to t_missing when it is odd.
    pad_before = t_missing // 2
    pad_after = t_missing - pad_before
    pad_width = ((0, 0), (0, 0), (0, 0), (pad_before, pad_after), (0, 0))

    prepad_shape = hi_res.shape
    hi_res = np.pad(hi_res, pad_width, mode=mode)
    logger.debug(
        'Padded hi_res output from %s to %s', prepad_shape, hi_res.shape
    )
    return hi_res

def generate(self, low_res, **kwargs):
    """Run the parent forward pass and pad the result along the time axis.

    Parameters
    ----------
    low_res : np.ndarray
        Low-resolution input data, passed to the parent ``generate`` as the
        ``low_res`` keyword and to :meth:`temporal_pad` as the reference
        for the target output length.
    **kwargs : Mappable
        Additional keyword arguments forwarded unchanged to the parent
        ``generate`` method.

    Returns
    -------
    hi_res : ndarray
        Output of the parent ``generate`` after :meth:`temporal_pad` has
        been applied with its default padding mode.
    """
    unpadded = super().generate(low_res=low_res, **kwargs)
    hi_res = self.temporal_pad(low_res, unpadded)

    logger.debug('Final SolarCC output has shape: {}'.format(hi_res.shape))

    return hi_res

@classmethod
def load(cls, model_dir, t_enhance=None, verbose=True):
    """Build a model instance from a directory written by a previous save,
    optionally overriding the temporal enhancement.

    The generator path, discriminator path, and saved parameter dictionary
    are resolved by ``cls._load`` and handed to the class constructor
    together with ``t_enhance``.

    Parameters
    ----------
    model_dir : str
        Directory to load GAN model files from.
    t_enhance : int | None
        Optional argument to fix or update the temporal enhancement of the
        model. This can be used to manipulate the output shape to match
        whatever padded shape the sup3r forward pass module expects. If
        this differs from the t_enhance value based on model layers the
        output will be padded so that the output shape matches low_res *
        t_enhance for the time dimension.
    verbose : bool
        Flag to log information about the loaded model.

    Returns
    -------
    out : BaseModel
        Returns a pretrained gan model that was previously saved to
        model_dir
    """
    gen_path, disc_path, init_params = cls._load(model_dir, verbose=verbose)
    return cls(gen_path, disc_path, t_enhance=t_enhance, **init_params)
Loading

0 comments on commit ae2324b

Please sign in to comment.