From 852cb655d16d77edb0d6bd6a3411ea7ecbd63711 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 5 Sep 2023 18:30:31 -0600 Subject: [PATCH] added weighted selection of data handler in batch handler --- sup3r/preprocessing/batch_handling.py | 601 +++++++----------- sup3r/preprocessing/data_handling/base.py | 11 + .../data_handling/dual_data_handling.py | 5 + sup3r/utilities/era_downloader.py | 101 +-- .../data_handling/test_dual_data_handling.py | 1 + 5 files changed, 295 insertions(+), 424 deletions(-) diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index c89196a625..019a18228e 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -94,18 +94,17 @@ def reduce_features(high_res, output_features_ind=None): # pylint: disable=W0613 @classmethod - def get_coarse_batch( - cls, - high_res, - s_enhance, - t_enhance=1, - temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - training_features=None, - smoothing=None, - smoothing_ignore=None, - ): + def get_coarse_batch(cls, + high_res, + s_enhance, + t_enhance=1, + temporal_coarsening_method='subsample', + output_features_ind=None, + output_features=None, + training_features=None, + smoothing=None, + smoothing_ignore=None, + ): """Coarsen high res data and return Batch with high res and low res data @@ -155,13 +154,11 @@ def get_coarse_batch( smoothing_ignore = [] if t_enhance != 1: - low_res = temporal_coarsening( - low_res, t_enhance, temporal_coarsening_method - ) + low_res = temporal_coarsening(low_res, t_enhance, + temporal_coarsening_method) - low_res = smooth_data( - low_res, training_features, smoothing_ignore, smoothing - ) + low_res = smooth_data(low_res, training_features, smoothing_ignore, + smoothing) high_res = cls.reduce_features(high_res, output_features_ind) batch = cls(low_res, high_res) @@ -174,18 +171,16 @@ class ValidationData: # Classes to use for handling an individual batch obj. 
BATCH_CLASS = Batch - def __init__( - self, - data_handlers, - batch_size=8, - s_enhance=3, - t_enhance=1, - temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - smoothing=None, - smoothing_ignore=None, - ): + def __init__(self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + temporal_coarsening_method='subsample', + output_features_ind=None, + output_features=None, + smoothing=None, + smoothing_ignore=None): """ Parameters ---------- @@ -256,23 +251,32 @@ def _get_val_indices(self): if h.val_data is not None: for _ in range(h.val_data.shape[2]): spatial_slice = uniform_box_sampler( - h.val_data, self.sample_shape[:2] - ) + h.val_data, self.sample_shape[:2]) temporal_slice = uniform_time_sampler( - h.val_data, self.sample_shape[2] - ) - tuple_index = tuple( - [ - *spatial_slice, - temporal_slice, - np.arange(h.val_data.shape[-1]), - ] - ) - val_indices.append( - {'handler_index': i, 'tuple_index': tuple_index} - ) + h.val_data, self.sample_shape[2]) + tuple_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.val_data.shape[-1]), + ]) + val_indices.append({ + 'handler_index': i, + 'tuple_index': tuple_index + }) return val_indices + @property + def handler_weights(self): + """Get weights used to sample from different data handlers based on + relative sizes""" + sizes = [dh.size for dh in self.data_handlers] + weights = sizes / np.sum(sizes) + return weights + + def get_handler_index(self): + """Get random handler index based on handler weights""" + indices = np.arange(0, len(self.data_handlers)) + return np.random.choice(indices, p=self.handler_weights) + def any(self): """Return True if any validation data exists""" return any(self.val_indices) @@ -291,12 +295,9 @@ def shape(self): time_steps = 0 for h in self.data_handlers: time_steps += h.val_data.shape[2] - return ( - self.data_handlers[0].val_data.shape[0], - self.data_handlers[0].val_data.shape[1], - time_steps, - self.data_handlers[0].val_data.shape[3], - ) + return (self.data_handlers[0].val_data.shape[0], + self.data_handlers[0].val_data.shape[1], time_steps, + self.data_handlers[0].val_data.shape[3]) def __iter__(self): self._i = 0 @@ -334,8 +335,7 @@ def batch_next(self, high_res): output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, - ) + output_features=self.output_features) def __next__(self): """Get validation data batch @@ -354,20 +354,13 @@ def __next__(self): n_obs = self._remaining_observations high_res = np.zeros( - ( - n_obs, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.data_handlers[0].shape[-1], - ), - dtype=np.float32, - ) + (n_obs, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.data_handlers[0].shape[-1]), + dtype=np.float32) for i in range(high_res.shape[0]): val_index = self.val_indices[self._i + i] - high_res[i, ...] = self.data_handlers[ - val_index['handler_index'] - ].val_data[val_index['tuple_index']] + high_res[i, ...] 
= self.data_handlers[val_index[ + 'handler_index']].val_data[val_index['tuple_index']] self._remaining_observations -= 1 self.current_batch_indices.append(val_index['handler_index']) @@ -388,24 +381,22 @@ class BatchHandler: BATCH_CLASS = Batch DATA_HANDLER_CLASS = None - def __init__( - self, - data_handlers, - batch_size=8, - s_enhance=3, - t_enhance=1, - means=None, - stds=None, - norm=True, - n_batches=10, - temporal_coarsening_method='subsample', - stdevs_file=None, - means_file=None, - overwrite_stats=False, - smoothing=None, - smoothing_ignore=None, - worker_kwargs=None, - ): + def __init__(self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + means=None, + stds=None, + norm=True, + n_batches=10, + temporal_coarsening_method='subsample', + stdevs_file=None, + means_file=None, + overwrite_stats=False, + smoothing=None, + smoothing_ignore=None, + worker_kwargs=None): """ Parameters ---------- @@ -507,19 +498,17 @@ def __init__( f for f in self.training_features if f not in self.smoothing_ignore ] - logger.info( - f'Initializing BatchHandler with smoothing={smoothing}. ' - f'Using stats_workers={self.stats_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'load_workers={self.load_workers}.' - ) + logger.info(f'Initializing BatchHandler with ' + f'{len(self.data_handlers)} data handlers with handler' + f'weights={self.handler_weights}, smoothing={smoothing}. ' + f'Using stats_workers={self.stats_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'load_workers={self.load_workers}.') now = dt.now() self.parallel_load() - logger.debug( - f'Finished loading data of shape {self.shape} ' - f'for BatchHandler in {dt.now() - now}.' - ) + logger.debug(f'Finished loading data of shape {self.shape} ' + f'for BatchHandler in {dt.now() - now}.') log_mem(logger, log_level='INFO') if norm: @@ -542,6 +531,24 @@ def __init__( logger.info('Finished initializing BatchHandler.') log_mem(logger, log_level='INFO') + @property + def handler_weights(self): + """Get weights used to sample from different data handlers based on + relative sizes""" + sizes = [dh.size for dh in self.data_handlers] + weights = sizes / np.sum(sizes) + return weights + + def get_handler_index(self): + """Get random handler index based on handler weights""" + indices = np.arange(0, len(self.data_handlers)) + return np.random.choice(indices, p=self.handler_weights) + + def get_current_handler(self): + """Get random handler based on handler weights""" + self.current_handler_index = self.get_handler_index() + return self.data_handlers[self.current_handler_index] + @property def feature_mem(self): """Get memory used by each feature in data handlers""" @@ -551,18 +558,16 @@ def feature_mem(self): def stats_workers(self): """Get max workers for calculating stats based on memory usage""" proc_mem = self.feature_mem - stats_workers = estimate_max_workers( - self._stats_workers, proc_mem, len(self.data_handlers) - ) + stats_workers = estimate_max_workers(self._stats_workers, proc_mem, + len(self.data_handlers)) return stats_workers @property def load_workers(self): """Get max workers for loading data handler based on memory usage""" proc_mem = len(self.data_handlers[0].features) * self.feature_mem - max_workers = estimate_max_workers( - self._load_workers, proc_mem, len(self.data_handlers) - ) + max_workers = estimate_max_workers(self._load_workers, proc_mem, + len(self.data_handlers)) return max_workers @property @@ -570,9 +575,8 @@ def norm_workers(self): """Get max workers used for calculating and normalization 
across features""" proc_mem = 2 * self.feature_mem - norm_workers = estimate_max_workers( - self._norm_workers, proc_mem, len(self.training_features) - ) + norm_workers = estimate_max_workers(self._norm_workers, proc_mem, + len(self.training_features)) return norm_workers @property @@ -595,8 +599,7 @@ def output_features_ind(self): return None else: out = [ - i - for i, feature in enumerate(self.training_features) + i for i, feature in enumerate(self.training_features) if feature in self.output_features ] return out @@ -613,12 +616,9 @@ def shape(self): dimension """ time_steps = np.sum([h.shape[-2] for h in self.data_handlers]) - return ( - self.data_handlers[0].shape[0], - self.data_handlers[0].shape[1], - time_steps, - self.data_handlers[0].shape[-1], - ) + return (self.data_handlers[0].shape[0], self.data_handlers[0].shape[1], + time_steps, self.data_handlers[0].shape[-1], + ) def parallel_normalization(self): """Normalize data in all data handlers in parallel.""" @@ -635,25 +635,19 @@ def parallel_normalization(self): future = exe.submit(d.normalize, self.means, self.stds) futures[future] = i - logger.info( - f'Started normalizing {len(self.data_handlers)} ' - f'data handlers in {dt.now() - now}.' - ) + logger.info(f'Started normalizing {len(self.data_handlers)} ' + f'data handlers in {dt.now() - now}.') for i, _ in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - 'Error normalizing data handler number ' - f'{futures[future]}' - ) + msg = ('Error normalizing data handler number ' + f'{futures[future]}') logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of {len(futures)} data handlers' - ' normalized.' - ) + logger.debug(f'{i+1} out of {len(futures)} data handlers' + ' normalized.') def parallel_load(self): """Load data handler data in parallel""" @@ -672,31 +666,25 @@ def parallel_load(self): future = exe.submit(d.load_cached_data) futures[future] = i - logger.info( - f'Started loading all {len(self.data_handlers)} ' - f'data handlers in {dt.now() - now}.' - ) + logger.info(f'Started loading all {len(self.data_handlers)} ' + f'data handlers in {dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - 'Error loading data handler number ' - f'{futures[future]}' - ) + msg = ('Error loading data handler number ' + f'{futures[future]}') logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of {len(futures)} handlers ' 'loaded.' - ) + logger.debug(f'{i+1} out of {len(futures)} handlers ' + 'loaded.') def parallel_stats(self): """Get standard deviations and means for training features in parallel.""" - logger.info( - f'Calculating stats for {len(self.training_features)} ' 'features.' - ) + logger.info(f'Calculating stats for {len(self.training_features)} ' + 'features.') max_workers = self.norm_workers if max_workers == 1: for f in self.training_features: @@ -709,27 +697,21 @@ def parallel_stats(self): future = exe.submit(self.get_stats_for_feature, f) futures[future] = i - logger.info( - 'Started calculating stats for ' - f'{len(self.training_features)} features in ' - f'{dt.now() - now}.' 
- ) + logger.info('Started calculating stats for ' + f'{len(self.training_features)} features in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - 'Error calculating stats for ' - f'{self.training_features[futures[future]]}' - ) + msg = ('Error calculating stats for ' + f'{self.training_features[futures[future]]}') logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of ' - f'{len(self.training_features)} stats ' - 'calculated.' - ) + logger.debug(f'{i+1} out of ' + f'{len(self.training_features)} stats ' + 'calculated.') def __len__(self): """Use user input of n_batches to specify length @@ -752,9 +734,8 @@ def check_cached_stats(self): stds : ndarray Array of stdevs for each feature """ - stdevs_check = ( - self.stdevs_file is not None and not self.overwrite_stats - ) + stdevs_check = (self.stdevs_file is not None + and not self.overwrite_stats) stdevs_check = stdevs_check and os.path.exists(self.stdevs_file) means_check = self.means_file is not None and not self.overwrite_stats means_check = means_check and os.path.exists(self.means_file) @@ -766,12 +747,10 @@ def check_cached_stats(self): with open(self.means_file, 'rb') as fh: self.means = pickle.load(fh) - msg = ( - 'The training features and cached statistics are ' - 'incompatible. Number of training features is ' - f'{len(self.training_features)} and number of stats is' - f' {len(self.stds)}' - ) + msg = ('The training features and cached statistics are ' + 'incompatible. Number of training features is ' + f'{len(self.training_features)} and number of stats is' + f' {len(self.stds)}') check = len(self.means) == len(self.training_features) check = check and (len(self.stds) == len(self.training_features)) assert check, msg @@ -822,9 +801,8 @@ def get_handler_mean(self, feature_idx, handler_idx): float Feature mean """ - return np.nanmean( - self.data_handlers[handler_idx].data[..., feature_idx] - ) + return np.nanmean(self.data_handlers[handler_idx].data[..., + feature_idx]) def get_handler_variance(self, feature_idx, handler_idx, mean): """Get feature variance for a given handler @@ -887,18 +865,14 @@ def get_means_for_feature(self, feature, max_workers=None): future = exe.submit(self.get_handler_mean, idx, didx) futures[future] = didx - logger.info( - 'Started calculating means for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.' - ) + logger.info('Started calculating means for ' + f'{len(self.data_handlers)} data_handlers in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): self.means[idx] += future.result() - logger.debug( - f'{i+1} out of {len(self.data_handlers)} ' - 'means calculated.' 
- ) + logger.debug(f'{i+1} out of {len(self.data_handlers)} ' + 'means calculated.') self.means[idx] /= len(self.data_handlers) return self.means[idx] @@ -918,30 +892,24 @@ def get_stdevs_for_feature(self, feature, max_workers=None): if max_workers == 1: for didx, _ in enumerate(self.data_handlers): self.stds[idx] += self.get_handler_variance( - idx, didx, self.means[idx] - ) + idx, didx, self.means[idx]) else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = {} now = dt.now() for didx, _ in enumerate(self.data_handlers): - future = exe.submit( - self.get_handler_variance, idx, didx, self.means[idx] - ) + future = exe.submit(self.get_handler_variance, idx, didx, + self.means[idx]) futures[future] = didx - logger.info( - 'Started calculating stdevs for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.' - ) + logger.info('Started calculating stdevs for ' + f'{len(self.data_handlers)} data_handlers in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): self.stds[idx] += future.result() - logger.debug( - f'{i+1} out of {len(self.data_handlers)} ' - 'stdevs calculated.' - ) + logger.debug(f'{i+1} out of {len(self.data_handlers)} ' + 'stdevs calculated.') self.stds[idx] /= len(self.data_handlers) self.stds[idx] = np.sqrt(self.stds[idx]) return self.stds[idx] @@ -962,18 +930,15 @@ def normalize(self, means=None, stds=None): self.get_stats() elif means is not None and stds is not None: if not np.array_equal(means, self.means) or not np.array_equal( - stds, self.stds - ): + stds, self.stds): self.unnormalize() self.means = means self.stds = stds now = dt.now() logger.info('Normalizing data in each data handler.') self.parallel_normalization() - logger.info( - 'Finished normalizing data in all data handlers in ' - f'{dt.now() - now}.' - ) + logger.info('Finished normalizing data in all data handlers in ' + f'{dt.now() - now}.') def unnormalize(self): """Remove normalization from stored means and stds""" @@ -995,19 +960,11 @@ def __next__(self): """ self.current_batch_indices = [] if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next() @@ -1022,8 +979,7 @@ def __next__(self): output_features=self.output_features, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - ) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch @@ -1070,9 +1026,7 @@ def __next__(self): if self._i >= self.n_batches: raise StopIteration - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() low_res = None high_res = None @@ -1082,8 +1036,7 @@ def __next__(self): self.current_batch_indices.append(handler.current_obs_index) obs_hourly = self.BATCH_CLASS.reduce_features( - obs_hourly, self.output_features_ind - ) + obs_hourly, self.output_features_ind) if low_res is None: lr_shape = (self.batch_size, *obs_daily_avg.shape) @@ -1097,25 +1050,22 @@ def __next__(self): high_res = self.reduce_high_res_sub_daily(high_res) low_res = spatial_coarsening(low_res, self.s_enhance) - if ( - self.output_features is not None - and 'clearsky_ratio' in self.output_features - ): + if (self.output_features is not None + and 'clearsky_ratio' in self.output_features): i_cs = self.output_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: feat_iter = [ - j - for j in range(low_res.shape[-1]) + j for j in range(low_res.shape[-1]) if self.training_features[j] not in self.smoothing_ignore ] for i in range(low_res.shape[0]): for j in feat_iter: - low_res[i, ..., j] = gaussian_filter( - low_res[i, ..., j], self.smoothing, mode='nearest' - ) + low_res[i, ..., j] = gaussian_filter(low_res[i, ..., j], + self.smoothing, + mode='nearest') batch = self.BATCH_CLASS(low_res, high_res) @@ -1182,9 +1132,7 @@ def __next__(self): if self._i >= self.n_batches: raise StopIteration - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() high_res = None @@ -1196,12 +1144,9 @@ def __next__(self): hr_shape = (self.batch_size, *obs_daily_avg.shape) high_res = np.zeros(hr_shape, dtype=np.float32) - msg = ( - 'SpatialBatchHandlerCC can only use n_temporal==1 ' - 'but received HR shape {} with n_temporal={}.'.format( - hr_shape, hr_shape[3] - ) - ) + msg = ('SpatialBatchHandlerCC can only use n_temporal==1 ' + 'but received HR shape {} with n_temporal={}.'.format( + hr_shape, hr_shape[3])) assert hr_shape[3] == 1, msg high_res[i] = obs_daily_avg @@ -1210,29 +1155,25 @@ def __next__(self): low_res = low_res[:, :, :, 0, :] high_res = high_res[:, :, :, 0, :] - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind - ) + high_res = self.BATCH_CLASS.reduce_features(high_res, + self.output_features_ind) - if ( - self.output_features is not None - and 'clearsky_ratio' in self.output_features - ): + if (self.output_features is not None + and 'clearsky_ratio' in self.output_features): i_cs = self.output_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: feat_iter = [ - j - for j in range(low_res.shape[-1]) + j for j in range(low_res.shape[-1]) if self.training_features[j] not in self.smoothing_ignore ] for i in range(low_res.shape[0]): for j in feat_iter: - low_res[i, ..., j] = gaussian_filter( - 
low_res[i, ..., j], self.smoothing, mode='nearest' - ) + low_res[i, ..., j] = gaussian_filter(low_res[i, ..., j], + self.smoothing, + mode='nearest') batch = self.BATCH_CLASS(low_res, high_res) @@ -1245,17 +1186,10 @@ class SpatialBatchHandler(BatchHandler): def __next__(self): if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - handler = self.data_handlers[handler_index] - high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.shape[-1], - ), - dtype=np.float32, - ) + handler = self.get_current_handler() + high_res = np.zeros((self.batch_size, self.sample_shape[0], + self.sample_shape[1], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...] = handler.get_next()[..., 0, :] @@ -1265,8 +1199,7 @@ def __next__(self): output_features_ind=self.output_features_ind, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - ) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch @@ -1295,69 +1228,58 @@ def _get_val_indices(self): val_indices = {} for t in range(self.N_TIME_BINS): val_indices[t] = [] - h_idx = np.random.choice(np.arange(len(self.data_handlers))) + h_idx = self.get_handler_index() h = self.data_handlers[h_idx] for _ in range(self.batch_size): - spatial_slice = uniform_box_sampler( - h.data, self.sample_shape[:2] - ) + spatial_slice = uniform_box_sampler(h.data, + self.sample_shape[:2]) weights = np.zeros(self.N_TIME_BINS) weights[t] = 1 - temporal_slice = weighted_time_sampler( - h.data, self.sample_shape[2], weights - ) - tuple_index = tuple( - [ - *spatial_slice, - temporal_slice, - np.arange(h.data.shape[-1]), - ] - ) - val_indices[t].append( - {'handler_index': h_idx, 'tuple_index': tuple_index} - ) + temporal_slice = weighted_time_sampler(h.data, + self.sample_shape[2], + weights) + tuple_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.data.shape[-1]) + ]) + val_indices[t].append({ + 'handler_index': h_idx, + 'tuple_index': tuple_index + }) for s in range(self.N_SPACE_BINS): val_indices[s + self.N_TIME_BINS] = [] - h_idx = np.random.choice(np.arange(len(self.data_handlers))) + h_idx = self.get_handler_index() h = self.data_handlers[h_idx] for _ in range(self.batch_size): weights = np.zeros(self.N_SPACE_BINS) weights[s] = 1 - spatial_slice = weighted_box_sampler( - h.data, self.sample_shape[:2], weights - ) - temporal_slice = uniform_time_sampler( - h.data, self.sample_shape[2] - ) - tuple_index = tuple( - [ - *spatial_slice, - temporal_slice, - np.arange(h.data.shape[-1]), - ] - ) - val_indices[s + self.N_TIME_BINS].append( - {'handler_index': h_idx, 'tuple_index': tuple_index} - ) + spatial_slice = weighted_box_sampler(h.data, + self.sample_shape[:2], + weights) + temporal_slice = uniform_time_sampler(h.data, + self.sample_shape[2]) + tuple_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.data.shape[-1]) + ]) + val_indices[s + self.N_TIME_BINS].append({ + 'handler_index': + h_idx, + 'tuple_index': + tuple_index + }) return val_indices def __next__(self): if self._i < len(self.val_indices.keys()): high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.data_handlers[0].shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.data_handlers[0].shape[-1]), + dtype=np.float32) val_indices = self.val_indices[self._i] for i, 
idx in enumerate(val_indices): high_res[i, ...] = self.data_handlers[ - idx['handler_index'] - ].data[idx['tuple_index']] + idx['handler_index']].data[idx['tuple_index']] batch = self.BATCH_CLASS.get_coarse_batch( high_res, @@ -1367,8 +1289,7 @@ def __next__(self): output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, - ) + output_features=self.output_features) self._i += 1 return batch else: @@ -1389,19 +1310,13 @@ class ValidationDataSpatialDC(ValidationDataDC): def __next__(self): if self._i < len(self.val_indices.keys()): high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.data_handlers[0].shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.data_handlers[0].shape[-1]), + dtype=np.float32) val_indices = self.val_indices[self._i] for i, idx in enumerate(val_indices): high_res[i, ...] = self.data_handlers[ - idx['handler_index'] - ].data[idx['tuple_index']][..., 0, :] + idx['handler_index']].data[idx['tuple_index']][..., 0, :] batch = self.BATCH_CLASS.get_coarse_batch( high_res, @@ -1409,8 +1324,7 @@ def __next__(self): output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, - ) + output_features=self.output_features) self._i += 1 return batch else: @@ -1440,15 +1354,12 @@ def __init__(self, *args, **kwargs): self.old_temporal_weights = [0] * self.val_data.N_TIME_BINS bin_range = self.data_handlers[0].data.shape[2] bin_range -= self.sample_shape[2] - 1 - self.temporal_bins = np.array_split( - np.arange(0, bin_range), self.val_data.N_TIME_BINS - ) + self.temporal_bins = np.array_split(np.arange(0, bin_range), + self.val_data.N_TIME_BINS) self.temporal_bins = [b[0] for b in self.temporal_bins] - logger.info( - 'Using temporal weights: ' - f'{[round(w, 3) for w in self.temporal_weights]}' - ) + logger.info('Using temporal weights: ' + f'{[round(w, 3) for w in self.temporal_weights]}') self.temporal_sample_record = [0] * self.val_data.N_TIME_BINS self.norm_temporal_record = [0] * self.val_data.N_TIME_BINS @@ -1467,24 +1378,15 @@ def __iter__(self): def __next__(self): self.current_batch_indices = [] if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next( - temporal_weights=self.temporal_weights - ) + temporal_weights=self.temporal_weights) self.current_batch_indices.append(handler.current_obs_index) self.update_training_sample_record() @@ -1498,8 +1400,7 @@ def __next__(self): output_features=self.output_features, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - ) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch @@ -1538,15 +1439,12 @@ def __init__(self, *args, **kwargs): self.max_cols = self.data_handlers[0].data.shape[1] + 1 self.max_cols -= self.sample_shape[1] bin_range = self.max_rows * self.max_cols - self.spatial_bins = np.array_split( - np.arange(0, bin_range), self.val_data.N_SPACE_BINS - ) + self.spatial_bins = np.array_split(np.arange(0, bin_range), + self.val_data.N_SPACE_BINS) self.spatial_bins = [b[0] for b in self.spatial_bins] - logger.info( - 'Using spatial weights: ' - f'{[round(w, 3) for w in self.spatial_weights]}' - ) + logger.info('Using spatial weights: ' + f'{[round(w, 3) for w in self.spatial_weights]}') self.spatial_sample_record = [0] * self.val_data.N_SPACE_BINS self.norm_spatial_record = [0] * self.val_data.N_SPACE_BINS @@ -1568,23 +1466,16 @@ def __iter__(self): def __next__(self): self.current_batch_indices = [] if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] - high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.shape[-1], - ), - dtype=np.float32, - ) + handler = self.get_current_handler() + high_res = np.zeros((self.batch_size, self.sample_shape[0], + self.sample_shape[1], self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next( - spatial_weights=self.spatial_weights - )[..., 0, :] + spatial_weights=self.spatial_weights)[..., 0, :] self.current_batch_indices.append(handler.current_obs_index) self.update_training_sample_record() diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 3c646eaa64..fede69dd98 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -922,6 +922,17 @@ def shape(self): """ return self.data.shape + @property + def size(self): + """Size of data array + + Returns + ------- + size : int + Number of total elements contained in data array + """ + return self.data.size + def cache_data(self, cache_file_paths): """Cache feature data to file and delete from memory diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 5f5786d509..58d1ba17a0 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -334,6 +334,11 @@ def shape(self): """Get low_res shape""" return (*self.lr_required_shape, len(self.features)) + @property + def size(self): + """Get low_res size""" + return self.lr_data.size + @property def hr_required_shape(self): """Return required shape for high_res data""" diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index de0080ebf1..c6e6dc2ca4 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -8,11 +8,9 @@ import logging import os from calendar import monthrange -from concurrent.futures import ( - ProcessPoolExecutor, - ThreadPoolExecutor, - as_completed, -) +from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor, + as_completed, + ) from glob import glob from typing import ClassVar from warnings import warn @@ -46,66 +44,34 @@ class EraDownloader: assert os.path.exists(req_file), msg VALID_VARIABLES: ClassVar[list] = [ - 'u', - 'v', - 'pressure', - 'temperature', - 'relative_humidity', - 'specific_humidity', - 'total_precipitation', + 'u', 'v', 'pressure', 'temperature', 'relative_humidity', + 'specific_humidity', 'total_precipitation', ] - KEEP_VARIABLES: ClassVar[list] = [ - 'orog', - 'time', - 'latitude', - 'longitude', - ] + KEEP_VARIABLES: ClassVar[list] = ['orog'] KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'zg', - 'orog', - 'u', - 'v', - 'u_10m', - 'v_10m', - 'u_100m', - 'v_100m', - 'temperature', - 'pressure', + 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', + 'temperature', 'pressure', ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', - '10m_v_component_of_wind', - '100m_u_component_of_wind', - '100m_v_component_of_wind', - 'u_component_of_wind', - 'v_component_of_wind', - '2m_temperature', - 'temperature', - 'surface_pressure', - 'relative_humidity', + '10m_u_component_of_wind', '10m_v_component_of_wind', + '100m_u_component_of_wind', '100m_v_component_of_wind', + 'u_component_of_wind', 'v_component_of_wind', '2m_temperature', + 'temperature', 'surface_pressure', 'relative_humidity', 'total_precipitation', ] SFC_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', - '10m_v_component_of_wind', - '100m_u_component_of_wind', - '100m_v_component_of_wind', - 'surface_pressure', - '2m_temperature', - 'geopotential', + '10m_u_component_of_wind', '10m_v_component_of_wind', + '100m_u_component_of_wind', '100m_v_component_of_wind', + 
'surface_pressure', '2m_temperature', 'geopotential', 'total_precipitation', ] LEVEL_VARS: ClassVar[list] = [ - 'u_component_of_wind', - 'v_component_of_wind', - 'geopotential', - 'temperature', - 'relative_humidity', - 'specific_humidity', + 'u_component_of_wind', 'v_component_of_wind', 'geopotential', + 'temperature', 'relative_humidity', 'specific_humidity', ] NAME_MAP: ClassVar[dict] = { 'u10': 'u_10m', @@ -119,7 +85,7 @@ class EraDownloader: 'sp': 'pressure_0m', 'r': 'relative_humidity', 'q': 'specific_humidity', - 'tp': 'total_precip', + 'tp': 'total_precipitation', } def __init__(self, @@ -418,11 +384,10 @@ def map_vars(self, old_ds, ds): """ for old_name, new_name in self.NAME_MAP.items(): if old_name in old_ds.variables: - _ = ds.createVariable( - new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) + _ = ds.createVariable(new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) vals = old_ds.variables[old_name][:] if 'temperature' in new_name: vals -= 273.15 @@ -528,6 +493,7 @@ def good_file(self, file, required_shape): Whether or not data has required shape and variables. """ out = self.check_single_file(file, + var_list=self.variables, check_nans=False, check_heights=False, required_shape=required_shape) @@ -896,11 +862,9 @@ def _check_single_file(cls, Percent of data which consists of NaNs across all given variables. """ good_vars = all(var in res for var in var_list) - res_shape = ( - *res['level'].shape, - *res['latitude'].shape, - *res['longitude'].shape, - ) + res_shape = (*res['level'].shape, *res['latitude'].shape, + *res['longitude'].shape, + ) good_shape = ('NA' if required_shape is None else (res_shape == required_shape)) good_hgts = ('NA' if not check_heights else cls.check_heights( @@ -912,8 +876,8 @@ def _check_single_file(cls, res, var_list=var_list)) if not good_vars: - mask = np.array([var not in res for var in var_list]) - missing_vars = var_list[mask] + mask = [var not in res for var in var_list] + missing_vars = np.array(var_list)[mask] logger.error(f'Missing variables: {missing_vars}.') if good_shape != 'NA' and not good_shape: logger.error(f'Bad shape: {res_shape} != {required_shape}.') @@ -961,11 +925,10 @@ def check_heights(cls, res, max_interp_height=200, max_workers=10): futures = [] with ProcessPoolExecutor(max_workers=max_workers) as exe: for idt in range(heights.shape[0]): - future = exe.submit( - cls._check_heights_single_ts, - heights[idt], - max_interp_height=max_interp_height, - ) + future = exe.submit(cls._check_heights_single_ts, + heights[idt], + max_interp_height=max_interp_height, + ) futures.append(future) msg = (f'Submitted height check for {idt + 1} of ' f'{heights.shape[0]}') diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index a72a12c37b..e74bb5a0b2 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -169,6 +169,7 @@ def test_st_dual_batch_handler(log=False, s_enhance=s_enhance, t_enhance=t_enhance, n_batches=10) + assert np.allclose(batch_handler.handler_weights, 0.5) for batch in batch_handler:
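
The core of this change is size-weighted sampling of data handlers: instead of picking a handler uniformly at random for each batch, the new handler_weights property weights each handler by the element count of its data array, and get_handler_index / get_current_handler draw from that distribution. Below is a minimal, self-contained sketch of that selection logic. MockHandler and the demo loop at the bottom are illustrative stand-ins, not part of the patch; in the real classes, size comes from self.data.size (base handler) or self.lr_data.size (dual handler).

    import numpy as np


    class MockHandler:
        """Illustrative stand-in for a DataHandler; only `size` matters here."""

        def __init__(self, size):
            self.size = size


    def handler_weights(handlers):
        """Sampling weights proportional to each handler's data size."""
        sizes = np.array([h.size for h in handlers], dtype=np.float64)
        return sizes / sizes.sum()


    def get_handler_index(handlers):
        """Draw a handler index with probability equal to its weight."""
        indices = np.arange(len(handlers))
        return np.random.choice(indices, p=handler_weights(handlers))


    handlers = [MockHandler(1_000), MockHandler(3_000)]
    print(handler_weights(handlers))   # [0.25 0.75]

    counts = np.zeros(len(handlers))
    for _ in range(10_000):
        counts[get_handler_index(handlers)] += 1
    print(counts / counts.sum())       # approaches [0.25 0.75]

Because every __next__ implementation in the batch handlers now routes through get_current_handler(), batches are drawn from large handlers proportionally more often than from small ones, rather than uniformly as before.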
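
The new size properties are what make these weights well defined across handler types: the base handler reports self.data.size while the dual handler reports self.lr_data.size. The updated test builds two handlers of equal size, so the expected weights are exactly one half each. A sketch of that expectation, assuming two equally sized low-res arrays as stand-ins for the two identical handlers in test_st_dual_batch_handler:

    import numpy as np

    # Two hypothetical low-res arrays of identical shape.
    lr_a = np.zeros((20, 20, 100, 2))
    lr_b = np.zeros((20, 20, 100, 2))

    sizes = np.array([lr_a.size, lr_b.size], dtype=np.float64)
    weights = sizes / sizes.sum()
    assert np.allclose(weights, 0.5)  # matches the new test assertion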
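
On the era_downloader side, KEEP_VARIABLES is trimmed to 'orog' plus one prefix per valid variable (e.g. 'u_' for 'u'), and the 'tp' short name now maps to 'total_precipitation'. The code that consumes KEEP_VARIABLES is outside this hunk, so the following prefix check is only an assumption about how such a keep-list is typically applied:

    VALID_VARIABLES = ['u', 'v', 'pressure', 'temperature',
                       'relative_humidity', 'specific_humidity',
                       'total_precipitation']
    KEEP_VARIABLES = ['orog'] + [f'{v}_' for v in VALID_VARIABLES]


    def keep(var_name):
        """Hypothetical filter: keep exact matches ('orog') and any name
        starting with a kept prefix, e.g. 'u_100m' matches 'u_'."""
        return any(var_name == v or var_name.startswith(v)
                   for v in KEEP_VARIABLES)


    assert keep('orog') and keep('u_100m') and keep('temperature_2m')
    assert not keep('zg')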
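
The _check_single_file hunk also fixes a latent bug: a plain Python list cannot be indexed with a boolean mask, so the array must be built before masking. A minimal reproduction of the patched behavior:

    import numpy as np

    var_list = ['u', 'v', 'zg']
    res = {'u': ..., 'v': ...}  # 'zg' missing from the file

    mask = [var not in res for var in var_list]
    # var_list[mask] would raise TypeError; the patched form works:
    missing_vars = np.array(var_list)[mask]
    assert list(missing_vars) == ['zg']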