Moved normalization outside of batch iteration for better performance.
bnb32 committed Jul 23, 2024
1 parent 0a66f82 commit 508812c
Showing 9 changed files with 89 additions and 102 deletions.
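The gist of the change: normalization used to happen inside the batch queue's post-processing, so the same arithmetic was repeated for every dequeued batch on every epoch; it now happens once, up front, when statistics are collected. A minimal standalone sketch of the difference (plain numpy, not the sup3r API):

import numpy as np

# Minimal sketch (plain numpy, not the sup3r API) of the cost difference
# between normalizing every dequeued batch and normalizing the data once.
rng = np.random.default_rng(0)
data = rng.normal(10.0, 3.0, size=(2000, 8))   # stand-in for container data
means, stds = data.mean(axis=0), data.std(axis=0)

# Before: the batch queue normalized each batch, every epoch.
def batches_normalized_per_batch(data, batch_size=64):
    for i in range(0, len(data), batch_size):
        yield (data[i:i + batch_size] - means) / stds  # repeated arithmetic

# After: the containers are normalized once, so batching is just slicing.
data = (data - means) / stds
def batches_pre_normalized(data, batch_size=64):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]
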
6 changes: 6 additions & 0 deletions sup3r/preprocessing/accessor.py
@@ -296,6 +296,12 @@ def std(self, **kwargs):
)
return type(self)(out) if isinstance(out, xr.Dataset) else out

def normalize(self, means, stds):
"""Normalize dataset using given means and stds. These are provided as
dictionaries."""
for f in self.features:
self._ds[f] = (self._ds[f] - means[f]) / stds[f]

def interpolate_na(self, **kwargs):
"""Use `xr.DataArray.interpolate_na` to fill NaN values with a dask
compatible method."""
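The new accessor method applies an in-place standardization per feature. A self-contained illustration of the same arithmetic with plain xarray (feature names and stats are invented; this is not the sup3r accessor itself):

import numpy as np
import xarray as xr

# Invented features and stats, standardized the same way as Sup3rX.normalize.
ds = xr.Dataset(
    {
        'windspeed_100m': ('time', np.array([4.0, 8.0, 12.0])),
        'temperature_2m': ('time', np.array([270.0, 280.0, 290.0])),
    }
)
means = {'windspeed_100m': 8.0, 'temperature_2m': 280.0}
stds = {'windspeed_100m': 4.0, 'temperature_2m': 10.0}

# Same loop as the accessor method above: standardize each feature in place.
for f in ds.data_vars:
    ds[f] = (ds[f] - means[f]) / stds[f]
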
7 changes: 6 additions & 1 deletion sup3r/preprocessing/base.py
@@ -240,9 +240,14 @@ def std(self, **kwargs):
kwargs['skipna'] = kwargs.get('skipna', True)
return self._ds[-1].std(**kwargs)

def normalize(self, means, stds):
"""Normalize dataset using the given mean and stds. These are provided
as dictionaries."""
_ = [d.normalize(means=means, stds=stds) for d in self._ds]

def compute(self, **kwargs):
"""Load data into memory for each data member."""
_ = [data.compute(**kwargs) for data in self._ds]
_ = [d.compute(**kwargs) for d in self._ds]

@property
def loaded(self):
68 changes: 7 additions & 61 deletions sup3r/preprocessing/batch_queues/abstract.py
@@ -8,16 +8,13 @@
import threading
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Dict, List, Optional, Tuple, Union
from warnings import warn
from typing import Dict, List, Optional, Union

import numpy as np
import tensorflow as tf
from rex import safe_json_load

from sup3r.preprocessing.collections.base import Collection
from sup3r.preprocessing.samplers import DualSampler, Sampler
from sup3r.typing import T_Array
from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer

logger = logging.getLogger(__name__)
@@ -249,17 +246,16 @@ def transform(self, samples, **kwargs):

def _post_proc(self, samples) -> Batch:
"""Performs some post proc on dequeued samples before sending out for
training. Post processing can include normalization, coarsening on
high-res data (if :class:`Collection` consists of :class:`Sampler`
objects and not :class:`DualSampler` objects), smoothing, etc
training. Post processing can include coarsening on high-res data (if
:class:`Collection` consists of :class:`Sampler` objects and not
:class:`DualSampler` objects), smoothing, etc
Returns
-------
Batch : namedtuple
namedtuple with `low_res` and `high_res` attributes
"""
lr, hr = self.transform(samples, **self.transform_kwargs)
lr, hr = self.normalize(lr, hr)
return self.Batch(low_res=lr, high_res=hr)

def start(self) -> None:
@@ -328,64 +324,14 @@ def __next__(self) -> Batch:
samples = tuple(s[..., 0, :] for s in samples)
else:
samples = samples[..., 0, :]
batch = self.timer(self._post_proc, log=True)(samples)
batch = self.timer(self._post_proc, log=True)(
samples
)
self._batch_counter += 1
else:
raise StopIteration
return batch

@staticmethod
def _get_stats(means, stds, features):
msg = (f'Some of the features: {features} not found in the provided '
f'means: {means}')
assert all(f in means for f in features), msg
msg = (f'Some of the features: {features} not found in the provided '
f'stds: {stds}')
assert all(f in stds for f in features), msg
f_means = np.array([means[k] for k in features]).astype(np.float32)
f_stds = np.array([stds[k] for k in features]).astype(np.float32)
return f_means, f_stds

def get_stats(self, means, stds):
"""Get means / stds from given files / dicts and group these into
low-res / high-res stats."""
means = means if isinstance(means, dict) else safe_json_load(means)
stds = stds if isinstance(stds, dict) else safe_json_load(stds)
msg = (
f'Received means = {means} with self.features = '
f'{self.features}. Make sure the means are valid, since they '
'clearly come from a different training run.'
)

if len(means) != len(self.features):
logger.warning(msg)
warn(msg)
msg = (
f'Received stds = {stds} with self.features = '
f'{self.features}. Make sure the stds are valid, since they '
'clearly come from a different training run.'
)
if len(stds) != len(self.features):
logger.warning(msg)
warn(msg)

lr_means, lr_stds = self._get_stats(means, stds, self.lr_features)
hr_means, hr_stds = self._get_stats(means, stds, self.hr_features)
return means, lr_means, hr_means, stds, lr_stds, hr_stds

@staticmethod
def _normalize(array, means, stds):
"""Normalize an array with given means and stds."""
return (array - means) / stds

def normalize(self, lr, hr) -> Tuple[T_Array, T_Array]:
"""Normalize a low-res / high-res pair with the stored means and
stdevs."""
return (
self._normalize(lr, self.lr_means, self.lr_stds),
self._normalize(hr, self.hr_means, self.hr_stds),
)

def get_container_index(self):
"""Get random container index based on weights"""
indices = np.arange(0, len(self.containers))
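With the stats and normalization helpers removed from the queue, `_post_proc` only transforms samples and packs them into the `Batch` namedtuple. A rough stand-in for what remains (the coarsening here is a placeholder, not the real transform):

from collections import namedtuple

import numpy as np

# Sketch of the post-processing left in the queue after this change: apply a
# transform and pack the pair into a Batch namedtuple, with no per-batch
# normalization.
Batch = namedtuple('Batch', ['low_res', 'high_res'])

def post_proc(samples, s_enhance=2):
    hr = np.asarray(samples, dtype=np.float32)
    lr = hr[:, ::s_enhance, ::s_enhance, :]  # stand-in spatial coarsening
    return Batch(low_res=lr, high_res=hr)

batch = post_proc(np.random.rand(4, 16, 16, 3))
print(batch.low_res.shape, batch.high_res.shape)  # (4, 8, 8, 3) (4, 16, 16, 3)
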
1 change: 0 additions & 1 deletion sup3r/preprocessing/batch_queues/conditional.py
@@ -160,7 +160,6 @@ def _post_proc(self, samples):
attributes
"""
lr, hr = self.transform(samples, **self.transform_kwargs)
lr, hr = self.normalize(lr, hr)
mask = self.make_mask(high_res=hr)
output = self.make_output(samples=(lr, hr))
return self.ConditionalBatch(
10 changes: 9 additions & 1 deletion sup3r/preprocessing/collections/stats.py
@@ -42,6 +42,7 @@ def __init__(self, containers, means=None, stds=None):
self.means = self.get_means(means)
self.stds = self.get_stds(stds)
self.save_stats(stds=stds, means=means)
self.normalize(means=self.means, stds=self.stds)

def _get_stat(self, stat_type):
"""Get either mean or std for all features and all containers."""
@@ -95,7 +96,7 @@ def get_stds(self, stds):
stds = dict.fromkeys(all_feats, 0)
logger.info(f'Computing stds for {all_feats}.')
cstds = [
w * cm ** 2
w * cm**2
for cm, w in zip(self._get_stat('std'), self.container_weights)
]
for f in all_feats:
@@ -117,3 +118,10 @@ def save_stats(self, stds, means):
with open(means, 'w') as f:
f.write(safe_serialize(self.means))
logger.info(f'Saved means {self.means} to {means}.')

def normalize(self, stds, means):
"""Normalize container data with computed stats."""
logger.info(
f'Normalizing container data with means: {means}, stds: {stds}.'
)
_ = [c.normalize(means=means, stds=stds) for c in self.containers]
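StatsCollection now finishes its constructor by normalizing every container with the stats it just computed. A hedged sketch of that flow on invented data, mirroring the weighted mean and the `w * cm**2` pooled-variance sum above:

import numpy as np

# Two invented containers for a single invented feature.
containers = [np.random.rand(100), np.random.rand(250)]
weights = [len(c) / sum(len(c) for c in containers) for c in containers]

# Weighted mean and pooled std across containers, as in get_means / get_stds.
mean = sum(w * c.mean() for c, w in zip(containers, weights))
std = np.sqrt(sum(w * c.std() ** 2 for c, w in zip(containers, weights)))

# One-time normalization of every container with the shared stats.
containers = [(c - mean) / std for c in containers]
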
8 changes: 6 additions & 2 deletions sup3r/preprocessing/derivers/base.py
@@ -10,7 +10,7 @@
import numpy as np

from sup3r.preprocessing.base import Container
from sup3r.preprocessing.utilities import Dimension, parse_to_list
from sup3r.preprocessing.utilities import (
Dimension,
_rechunk_if_dask,
parse_to_list,
)
from sup3r.typing import T_Array, T_Dataset
from sup3r.utilities.interpolation import Interpolator

@@ -250,7 +254,7 @@ def do_level_interpolation(
level=np.float32(level),
interp_method=interp_method,
)
return out
return _rechunk_if_dask(out)


class Deriver(BaseDeriver):
51 changes: 31 additions & 20 deletions sup3r/preprocessing/extracters/dual.py
@@ -140,31 +140,38 @@ def update_hr_data(self):
hr_data.shape is divisible by s_enhance. If not, take the largest
shape that can be."""
msg = (
f'hr_data.shape {self.hr_data.shape[:3]} is not '
f'divisible by s_enhance ({self.s_enhance}). Using shape = '
f'hr_data.shape: {self.hr_data.shape[:3]} is not '
f'divisible by s_enhance: {self.s_enhance}. Using shape: '
f'{self.hr_required_shape} instead.'
)
if self.hr_data.shape[:3] != self.hr_required_shape[:3]:
need_new_shape = self.hr_data.shape[:3] != self.hr_required_shape[:3]
if need_new_shape:
logger.warning(msg)
warn(msg)

hr_data_new = {
f: self.hr_data[
f,
slice(self.hr_required_shape[0]),
slice(self.hr_required_shape[1]),
slice(self.hr_required_shape[2]),
]
for f in self.hr_data.data_vars
}
hr_coords_new = {
Dimension.LATITUDE: self.hr_lat_lon[..., 0],
Dimension.LONGITUDE: self.hr_lat_lon[..., 1],
Dimension.TIME: self.hr_data.indexes['time'][
: self.hr_required_shape[2]
],
}
self.hr_data = self.hr_data.update_ds({**hr_coords_new, **hr_data_new})
hr_data_new = {
f: self.hr_data[
f,
slice(self.hr_required_shape[0]),
slice(self.hr_required_shape[1]),
slice(self.hr_required_shape[2]),
]
for f in self.hr_data.data_vars
}
hr_coords_new = {
Dimension.LATITUDE: self.hr_lat_lon[..., 0],
Dimension.LONGITUDE: self.hr_lat_lon[..., 1],
Dimension.TIME: self.hr_data.indexes['time'][
: self.hr_required_shape[2]
],
}
logger.info(
'Updating self.hr_data with new shape: '
f'{self.hr_required_shape[:3]}'
)
self.hr_data = self.hr_data.update_ds(
{**hr_coords_new, **hr_data_new}
)

def get_regridder(self):
"""Get regridder object"""
@@ -197,6 +204,7 @@ def update_lr_data(self):
: self.lr_required_shape[2]
],
}
logger.info('Updating self.lr_data with regridded data.')
self.lr_data = self.lr_data.update_ds(
{**lr_coords_new, **lr_data_new}
)
@@ -205,6 +213,9 @@ def check_regridded_lr_data(self):
"""Check for NaNs after regridding and do NN fill if needed."""
fill_feats = []
for f in self.lr_data.data_vars:
logger.info(
f'Checking for NaNs after regridding, for feature: {f}'
)
nan_perc = (
100
* np.isnan(self.lr_data[f].data).sum()
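The `update_hr_data` docstring says the high-res shape is trimmed to the largest shape divisible by `s_enhance`. A small sketch of that idea (the helper name and the exact rule are assumptions; sup3r computes `hr_required_shape` elsewhere):

import numpy as np

# Trim the spatial dims to the largest size divisible by s_enhance
# (hypothetical helper, not the sup3r implementation).
def largest_divisible_shape(shape, s_enhance):
    return tuple(
        dim - dim % s_enhance if i < 2 else dim for i, dim in enumerate(shape)
    )

hr = np.zeros((101, 101, 49, 3))
print(largest_divisible_shape(hr.shape[:3], s_enhance=4))  # (100, 100, 49)
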
33 changes: 17 additions & 16 deletions sup3r/preprocessing/loaders/h5.py
@@ -7,6 +7,7 @@

import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr
from rex import MultiFileWindX

@@ -57,7 +58,7 @@ def load(self) -> xr.Dataset:
dims = (Dimension.FLATTENED_SPATIAL,)
if not self._time_independent:
dims = (Dimension.TIME, *dims)
coords[Dimension.TIME] = self.res['time_index']
coords[Dimension.TIME] = pd.DatetimeIndex(self.res['time_index'])

chunks = (
tuple(self.chunks[d] for d in dims)
@@ -76,22 +77,22 @@
dims,
da.asarray(elev, dtype=np.float32, chunks=chunks),
)
        data_vars = {
            **data_vars,
            **{
                f: (
                    dims,
                    da.asarray(
                        self.res.h5[f],
                        dtype=np.float32,
                        chunks=chunks,
                    )
                    / self.scale_factor(f),
                )
                for f in self.res.h5.datasets
                if f not in ('meta', 'time_index', 'coordinates')
            },
        }
        feats = [
            f
            for f in self.res.h5.datasets
            if f not in ('meta', 'time_index', 'coordinates')
        ]
        for f in feats:
            logger.debug(f'Rechunking "{f}" with chunks: {self.chunks}')
            data_vars[f] = (
                dims,
                da.asarray(
                    self.res.h5[f],
                    dtype=np.float32,
                    chunks=chunks,
                )
                / self.scale_factor(f),
            )
coords.update(
{
Dimension.LATITUDE: (
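The loader builds one lazy dask array per feature, scaled by its stored scale factor. A hedged, plain-h5py version of that pattern (file name, dataset name, and attribute name are invented; rex handles this internally in sup3r):

import dask.array as da
import h5py
import numpy as np

# The file handle must stay open while the dask array is still lazy.
fh = h5py.File('example_wtk.h5', 'r')
dset = fh['windspeed_100m']
scale = dset.attrs.get('scale_factor', 1)
arr = da.from_array(dset, chunks='auto').astype(np.float32) / scale  # lazy
print(arr.mean().compute())  # data is only read from disk here
fh.close()
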
7 changes: 7 additions & 0 deletions sup3r/preprocessing/utilities.py
@@ -108,6 +108,13 @@ def _compute_if_dask(arr):
return arr.compute() if hasattr(arr, 'compute') else arr


def _rechunk_if_dask(arr, chunks='auto'):

if hasattr(arr, 'rechunk'):
return arr.rechunk(chunks)
return arr


def _parse_time_slice(value):
return (
value
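For context, the new helper is a no-op for plain numpy arrays and only rechunks dask arrays. A quick usage sketch (the function body is copied from the diff above):

import dask.array as da
import numpy as np

def _rechunk_if_dask(arr, chunks='auto'):
    # Only dask arrays have .rechunk; anything else passes through unchanged.
    if hasattr(arr, 'rechunk'):
        return arr.rechunk(chunks)
    return arr

print(type(_rechunk_if_dask(np.ones((4, 4)))))                  # numpy, unchanged
print(_rechunk_if_dask(da.ones((4, 4), chunks=(1, 1))).chunks)  # rechunked to 'auto'
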
