Bnb/dev #152

Merged · 3 commits · Jul 7, 2023
28 changes: 17 additions & 11 deletions sup3r/models/abstract.py
@@ -2,23 +2,25 @@
"""
Abstract class to define the required interface for Sup3r model subclasses
"""
from abc import ABC, abstractmethod
import json
import logging
import os
import pprint
import time
import json
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from inspect import signature
from warnings import warn

import numpy as np
import tensorflow as tf
from phygnn import CustomNetwork
from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat
from rex.utilities.utilities import safe_json_load
import tensorflow as tf
from tensorflow.keras import optimizers
import numpy as np
import logging
import pprint
from warnings import warn

from sup3r.utilities import VERSION_RECORD
import sup3r.utilities.loss_metrics
from sup3r.utilities import VERSION_RECORD

logger = logging.getLogger(__name__)

@@ -491,7 +493,10 @@ def init_optimizer(optimizer, learning_rate):
if isinstance(optimizer, dict):
class_name = optimizer['name']
OptimizerClass = getattr(optimizers, class_name)
optimizer = OptimizerClass.from_config(optimizer)
sig = signature(OptimizerClass)
optimizer_kwargs = {k: v for k, v in optimizer.items()
if k in sig.parameters}
optimizer = OptimizerClass.from_config(optimizer_kwargs)
elif optimizer is None:
optimizer = optimizers.Adam(learning_rate=learning_rate)
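
The signature filter above guards against saved configs that carry keys the optimizer's constructor does not accept. A minimal sketch of the same pattern, assuming a hypothetical saved config with an extra bookkeeping key that would otherwise be passed straight through to the constructor:

    from inspect import signature

    from tensorflow.keras import optimizers

    # Hypothetical saved config: 'extra_key' is not an Adam constructor
    # argument, so an unfiltered from_config() could raise a TypeError.
    config = {'name': 'Adam', 'learning_rate': 1e-4, 'extra_key': 'state'}

    OptimizerClass = getattr(optimizers, config['name'])
    sig = signature(OptimizerClass)
    kwargs = {k: v for k, v in config.items() if k in sig.parameters}
    optimizer = OptimizerClass.from_config(kwargs)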

@@ -518,7 +523,7 @@ def load_saved_params(out_dir, verbose=True):
"""

fp_params = os.path.join(out_dir, 'model_params.json')
with open(fp_params, 'r') as f:
with open(fp_params) as f:
params = json.load(f)

# using the saved model dir makes this more portable
@@ -705,7 +710,7 @@ def early_stop(history, column, threshold=0.005, n_epoch=5):
return stop

@abstractmethod
def save(self, low_res):
def save(self, out_dir):
"""Save the model with its sub-networks to a directory.

Parameters
@@ -922,6 +927,7 @@ class AbstractWindInterface(ABC):
Abstract class to define the required training interface
for Sup3r wind model subclasses
"""

# pylint: disable=E0211
@staticmethod
def set_model_params(**kwargs):
9 changes: 4 additions & 5 deletions sup3r/models/base.py
@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""Sup3r model software"""
import copy
import logging
import os
import pprint
import time
import logging
from warnings import warn

import numpy as np
import pprint
import pandas as pd
import tensorflow as tf
from tensorflow.keras import optimizers
from warnings import warn

from sup3r.models.abstract import AbstractInterface, AbstractSingleModel
from sup3r.utilities import VERSION_RECORD


logger = logging.getLogger(__name__)


@@ -240,7 +240,6 @@ def _tf_generate(self, low_res):
hi_res : tf.Tensor
Synthetically generated high-resolution data
"""

hi_res = self.generator.layers[0](low_res)
for i, layer in enumerate(self.generator.layers[1:]):
try:
20 changes: 13 additions & 7 deletions sup3r/pipeline/forward_pass.py
@@ -429,6 +429,8 @@ def get_hr_slices(slices, enhancement, step=None):
Low resolution slices to be enhanced
enhancement : int
Enhancement factor
step : int | None
Step size for slices

Returns
-------
@@ -1587,16 +1589,20 @@ def _reshape_data_chunk(model, data_chunk, exo_data):
i_lr_s : int
Axis index for the low-resolution spatial_1 dimension
"""

current_model = None
if exo_data is not None:
for i, arr in enumerate(exo_data):
if arr is not None:
current_model = (model if not hasattr(model, 'models')
else model.models[i])
if current_model.input_dims == 4:
exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3))
else:
exo_data[i] = np.expand_dims(arr, axis=0)
if not hasattr(model, 'models'):
current_model = model
elif i < len(model.models):
current_model = model.models[i]

if current_model is not None:
if current_model.input_dims == 4:
exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3))
else:
exo_data[i] = np.expand_dims(arr, axis=0)

if model.input_dims == 4:
i_lr_t = 0
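
A sketch of the two reshapes in the loop above, assuming for illustration that exo arrays arrive shaped (lat, lon, time, features): 4D spatial-only models want observations first, while 5D spatiotemporal models just need a leading batch axis.

    import numpy as np

    # Hypothetical exo array shaped (lat, lon, time, features)
    arr = np.zeros((10, 10, 8, 1))

    # 4D (spatial-only) model input: (obs, lat, lon, features)
    assert np.transpose(arr, axes=(2, 0, 1, 3)).shape == (8, 10, 10, 1)

    # 5D (spatiotemporal) model input: (obs, lat, lon, time, features)
    assert np.expand_dims(arr, axis=0).shape == (1, 10, 10, 8, 1)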
37 changes: 24 additions & 13 deletions sup3r/preprocessing/data_handling.py
@@ -264,11 +264,15 @@ def file_paths(self, file_paths):
"""
self._file_paths = file_paths
if isinstance(self._file_paths, str):
if '*' in self._file_paths:
if '*' in file_paths:
self._file_paths = glob.glob(self._file_paths)
else:
self._file_paths = [self._file_paths]

msg = ('No valid files provided to DataHandler. '
f'Received file_paths={file_paths}. Aborting.')
assert len(self._file_paths) > 0 and file_paths is not None, msg
Member: Probably better to check the None condition first. If file_paths is None, you'll get a confusing TypeError from len(None) instead of the intended assertion message, I think?

Collaborator (author): Yeah, good call.


self._file_paths = sorted(self._file_paths)
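
Per the review exchange above, the None test should run before len(); a minimal sketch of the reordered guard, with a hypothetical helper name mirroring the setter's logic:

    def _validate_file_paths(file_paths, resolved_paths):
        # Check None first: len(None) would raise a TypeError and mask
        # the intended error message.
        msg = ('No valid files provided to DataHandler. '
               f'Received file_paths={file_paths}. Aborting.')
        assert file_paths is not None and len(resolved_paths) > 0, msg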

@property
@@ -376,6 +380,16 @@ def lat_lon(self, lat_lon):
"""Update lat lon"""
self._lat_lon = lat_lon

@property
def latitude(self):
"""Return latitude array"""
return self.lat_lon[..., 0]

@property
def longitude(self):
"""Return longitude array"""
return self.lat_lon[..., 1]
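
The new properties read the two channels of the lat_lon grid; a quick illustration, assuming lat_lon is shaped (rows, cols, 2) with latitude in channel 0:

    import numpy as np

    lats = np.linspace(40, 39, 4)              # descending row latitudes
    lons = np.linspace(-105, -104, 5)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    lat_lon = np.dstack([lat_grid, lon_grid])  # shape (4, 5, 2)

    assert lat_lon[..., 0].shape == (4, 5)          # latitude channel
    assert np.allclose(lat_lon[..., 1], lon_grid)   # longitude channel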

@property
def invert_lat(self):
"""Whether to invert the latitude axis during data extraction. This is
@@ -535,7 +549,7 @@ def time_freq_hours(self):
"""Get the time frequency in hours as a float"""
ti_deltas = self.raw_time_index - np.roll(self.raw_time_index, 1)
ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600
time_freq = float(mode(ti_deltas_hours).mode[0])
time_freq = float(mode(ti_deltas_hours).mode)
return time_freq
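
The dropped [0] tracks scipy's API: in recent scipy (1.11+, where keepdims defaults to False for 1-D input), ModeResult.mode is already a scalar. A sketch with a hypothetical hourly index:

    import numpy as np
    import pandas as pd
    from scipy.stats import mode

    time_index = pd.date_range('2023-01-01', periods=24, freq='1h')
    ti_deltas = time_index - np.roll(time_index, 1)
    ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600
    # .mode is a scalar here, so no [0] indexing
    time_freq = float(mode(ti_deltas_hours).mode)
    assert time_freq == 1.0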

@property
@@ -718,9 +732,6 @@ def __init__(self, file_paths, features, target=None, shape=None,
raster_index=raster_index,
temporal_slice=temporal_slice)

msg = 'No files provided to DataHandler. Aborting.'
assert file_paths is not None and bool(file_paths), msg

self.file_paths = file_paths
self.features = (features if isinstance(features, (list, tuple))
else [features])
@@ -1366,7 +1377,7 @@ def get_observation_index(self):
spatial_slice = uniform_box_sampler(self.data, self.sample_shape[:2])
temporal_slice = uniform_time_sampler(self.data, self.sample_shape[2])
return tuple(
[*spatial_slice, temporal_slice] + [np.arange(len(self.features))])
[*spatial_slice, temporal_slice, np.arange(len(self.features))])
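
The observation index is consumed as a numpy fancy index; a minimal sketch with hypothetical shapes:

    import numpy as np

    data = np.zeros((20, 20, 48, 3))           # (lat, lon, time, features)
    spatial_slice = (slice(0, 10), slice(5, 15))
    temporal_slice = slice(0, 24)
    obs_index = (*spatial_slice, temporal_slice, np.arange(3))

    assert data[obs_index].shape == (10, 10, 24, 3)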

def get_next(self):
"""Get data for observation using random observation index. Loops
@@ -2182,7 +2193,7 @@ def direct_extract(cls, handle, feature, raster_index, time_slice):
# Sometimes xarray returns fields with (Times, time, lats, lons)
# with a single entry in the 'time' dimension so we include this [0]
if len(handle[feature].dims) == 4:
idx = tuple([time_slice] + [0] + raster_index)
idx = tuple([time_slice, 0, *raster_index])
elif len(handle[feature].dims) == 3:
idx = tuple([time_slice, *raster_index])
else:
@@ -2535,7 +2546,7 @@ def get_clearsky_ghi(self):

ti_deltas = ti_nsrdb - np.roll(ti_nsrdb, 1)
ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600
time_freq = float(mode(ti_deltas_hours).mode[0])
time_freq = float(mode(ti_deltas_hours).mode)
t_start = self.temporal_slice.start or 0
t_end_target = self.temporal_slice.stop or len(self.raw_time_index)
t_start = int(t_start * 24 * (1 / time_freq))
@@ -2898,11 +2909,11 @@ def get_observation_index(self):
t_slice_hourly = slice(t_slice_0.start, t_slice_1.stop)
t_slice_daily = slice(rand_day_ind, rand_day_ind + n_days)

obs_ind_hourly = tuple([*spatial_slice, t_slice_hourly]
+ [np.arange(len(self.features))])
obs_ind_hourly = tuple([*spatial_slice, t_slice_hourly,
np.arange(len(self.features))])

obs_ind_daily = tuple([*spatial_slice, t_slice_daily]
+ [np.arange(len(self.features))])
obs_ind_daily = tuple([*spatial_slice, t_slice_daily,
np.arange(len(self.features))])

return obs_ind_hourly, obs_ind_daily

@@ -3121,7 +3132,7 @@ def get_observation_index(self, temporal_weights=None,
self.sample_shape[2])

return tuple(
[*spatial_slice, temporal_slice] + [np.arange(len(self.features))])
[*spatial_slice, temporal_slice, np.arange(len(self.features))])

def get_next(self, temporal_weights=None, spatial_weights=None):
"""Get data for observation using weighted random observation index.
101 changes: 91 additions & 10 deletions tests/forward_pass/test_forward_pass.py
@@ -2,23 +2,21 @@
"""pytests for data handling"""
import json
import os
import pytest
import tempfile
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import pytest
import tensorflow as tf
import xarray as xr
import matplotlib.pyplot as plt
from rex import ResourceX, init_logger

from sup3r import TEST_DATA_DIR, CONFIG_DIR, __version__
from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__
from sup3r.models import LinearInterp, Sup3rGan, WindGan
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
from sup3r.preprocessing.data_handling import DataHandlerNC
from sup3r.pipeline.forward_pass import (ForwardPass, ForwardPassStrategy)
from sup3r.models import Sup3rGan, WindGan, LinearInterp
from sup3r.utilities.pytest import make_fake_nc_files

from rex import ResourceX
from rex import init_logger


FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)
FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']
@@ -562,6 +560,89 @@ def test_fwp_multi_step_model_topo_exoskip(log=False):
'topography']


def test_fwp_multi_step_spatial_model_topo_noskip():
"""Test the forward pass with a multi step spatial only model class using
exogenous data for all model steps"""
Sup3rGan.seed()
fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json')
fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')
s1_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4)
s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
s1_model.meta['output_features'] = ['U_100m', 'V_100m']
s1_model.meta['s_enhance'] = 2
s1_model.meta['t_enhance'] = 1
_ = s1_model.generate(np.ones((4, 10, 10, 3)))

s2_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4)
s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
s2_model.meta['output_features'] = ['U_100m', 'V_100m']
s2_model.meta['s_enhance'] = 2
s2_model.meta['t_enhance'] = 1
_ = s2_model.generate(np.ones((4, 10, 10, 3)))

with tempfile.TemporaryDirectory() as td:
input_files = make_fake_nc_files(td, INPUT_FILE, 8)

s1_out_dir = os.path.join(td, 's1_gan')
s2_out_dir = os.path.join(td, 's2_gan')
s1_model.save(s1_out_dir)
s2_model.save(s2_out_dir)

max_workers = 1
fwp_chunk_shape = (4, 4, 8)
s_enhancements = [2, 2, 1]
s_enhance = np.product(s_enhancements)

exo_kwargs = {'file_paths': input_files,
'features': ['topography'],
'source_file': FP_WTK,
'target': target,
'shape': shape,
's_enhancements': [1, 2, 2],
'agg_factors': [12, 4, 2]
}

model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]}

out_files = os.path.join(td, 'out_{file_id}.h5')
input_handler_kwargs = dict(
target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=max_workers),
overwrite_cache=True)
handler = ForwardPassStrategy(
input_files, model_kwargs=model_kwargs,
model_class='MultiStepGan',
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=1, temporal_pad=1,
input_handler_kwargs=input_handler_kwargs,
out_pattern=out_files,
worker_kwargs=dict(max_workers=max_workers),
exo_kwargs=exo_kwargs,
max_nodes=1)

forward_pass = ForwardPass(handler)
forward_pass.run(handler, node_index=0)

with ResourceX(handler.out_files[0]) as fh:
assert fh.shape == (
len(input_files),
s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1])
assert all(f in fh.attrs for f in ('windspeed_100m',
'winddirection_100m'))

assert fh.global_attrs['package'] == 'sup3r'
assert fh.global_attrs['version'] == __version__
assert 'full_version_record' in fh.global_attrs
version_record = json.loads(fh.global_attrs['full_version_record'])
assert version_record['tensorflow'] == tf.__version__
assert 'gan_meta' in fh.global_attrs
gan_meta = json.loads(fh.global_attrs['gan_meta'])
assert len(gan_meta) == 2 # two step model
assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m',
'topography']


def test_fwp_multi_step_model_topo_noskip():
"""Test the forward pass with a multi step model class using exogenous data
for all model steps"""