diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py
index 74490bb2e4..aad5355789 100644
--- a/sup3r/preprocessing/accessor.py
+++ b/sup3r/preprocessing/accessor.py
@@ -78,6 +78,104 @@ def __init__(self, ds: Union[xr.Dataset, Self]):
         self._features = None
         self.time_slice = None
 
+    def __getattr__(self, attr):
+        """Get attribute and cast to type(self) if an xr.Dataset is
+        returned."""
+        out = getattr(self._ds, attr)
+        return type(self)(out) if isinstance(out, xr.Dataset) else out
+
+    def __mul__(self, other):
+        """Multiply Sup3rX object by other. Used to compute weighted means
+        and stdevs."""
+        try:
+            return type(self)(other * self._ds)
+        except Exception as e:
+            raise NotImplementedError(
+                f'Multiplication not supported for type {type(other)}.'
+            ) from e
+
+    def __rmul__(self, other):
+        return self.__mul__(other)
+
+    def __pow__(self, other):
+        """Raise Sup3rX object to an integer power. Used to compute weighted
+        standard deviations."""
+        try:
+            return type(self)(self._ds**other)
+        except Exception as e:
+            raise NotImplementedError(
+                f'Exponentiation not supported for type {type(other)}.'
+            ) from e
+
+    def __setitem__(self, keys, data):
+        """
+        Parameters
+        ----------
+        keys : str | list | tuple
+            Keys to set. This can be a string like 'temperature' or a list
+            like ['u', 'v']. `data` will be iterated over in the latter case.
+        data : T_Array | xr.DataArray
+            Array object used to set variable data. If `keys` is a list then
+            this is expected to have a trailing dimension with length equal
+            to the length of the list.
+        """
+        if _is_strings(keys):
+            if isinstance(keys, (list, tuple)):
+                data_dict = {v: data[..., i] for i, v in enumerate(keys)}
+            else:
+                data_dict = {keys.lower(): data}
+            _ = self.assign(data_dict)
+        elif isinstance(keys[0], str) and keys[0] not in self.coords:
+            feats, slices = self._parse_keys(keys)
+            var_array = self[feats].data
+            var_array[tuple(slices.values())] = data
+            _ = self.assign({feats: var_array})
+        else:
+            msg = f'Cannot set values for keys {keys}'
+            logger.error(msg)
+            raise KeyError(msg)
+
+    def __getitem__(self, keys) -> Union[T_Array, Self]:
+        """Method for accessing variables or attributes. `keys` can
+        optionally include a feature name or list of feature names as the
+        first entry of a keys tuple. When keys take the form of numpy-style
+        indexing we return a dask or numpy array, depending on whether
+        contained data has been loaded into memory; otherwise we return
+        xarray or Sup3rX objects."""
+        features, slices = self._parse_keys(keys)
+        out = self._ds[features]
+        slices = {k: v for k, v in slices.items() if k in out.dims}
+        if self._needs_fancy_indexing(slices.values()):
+            out = self.as_array(data=out, features=features)
+            return out.vindex[tuple(slices.values())]
+
+        out = out.isel(**slices)
+        # numpy style indexing requested so we return an array (dask or np)
+        if isinstance(keys, (slice, tuple)) or _contains_ellipsis(keys):
+            return self.as_array(data=out, features=features)
+        if isinstance(out, xr.Dataset):
+            return type(self)(out)
+        return out.transpose(*ordered_dims(out.dims), ...)
+
+    def __contains__(self, vals):
+        """Check if self._ds contains `vals`.
+
+        Parameters
+        ----------
+        vals : str | list
+            Values to check. Can be a list of strings or a single string.
+ + Examples + -------- + bool(['u', 'v'] in self) + bool('u' in self) + """ + feature_check = isinstance(vals, (list, tuple)) and all( + isinstance(s, str) for s in vals + ) + if feature_check: + return all(s.lower() in self._ds for s in vals) + return self._ds.__contains__(vals) + def compute(self, **kwargs): """Load `._ds` into memory. This updates the internal `xr.Dataset` if it has not been loaded already.""" @@ -128,52 +226,20 @@ def update_ds(self, new_dset, attrs=None): """ coords = dict(self._ds.coords) data_vars = dict(self._ds.data_vars) - coords.update( - { - k: dims_array_tuple(v) - for k, v in new_dset.items() - if k in coords - } - ) - data_vars.update( - { - k: dims_array_tuple(v) - for k, v in new_dset.items() - if k not in coords - } - ) + new_coords = { + k: dims_array_tuple(v) for k, v in new_dset.items() if k in coords + } + coords.update(new_coords) + new_data = { + k: dims_array_tuple(v) + for k, v in new_dset.items() + if k not in coords + } + data_vars.update(new_data) + self._ds = xr.Dataset(coords=coords, data_vars=data_vars, attrs=attrs) return type(self)(self._ds) - def __getattr__(self, attr): - """Get attribute and cast to type(self) if a xr.Dataset is returned - first.""" - out = getattr(self._ds, attr) - return type(self)(out) if isinstance(out, xr.Dataset) else out - - def __mul__(self, other): - """Multiply Sup3rX object by other. Used to compute weighted means and - stdevs.""" - try: - return type(self)(other * self._ds) - except Exception as e: - raise NotImplementedError( - f'Multiplication not supported for type {type(other)}.' - ) from e - - def __rmul__(self, other): - return self.__mul__(other) - - def __pow__(self, other): - """Raise Sup3rX object to an integer power. Used to compute weighted - standard deviations.""" - try: - return type(self)(self._ds**other) - except Exception as e: - raise NotImplementedError( - f'Exponentiation not supported for type {type(other)}.' - ) from e - @property def name(self): """Name of dataset. Used to label datasets when grouped in @@ -198,9 +264,13 @@ def name(self, value): self._ds.attrs['name'] = value def isel(self, *args, **kwargs): - """Override xr.Dataset.sel to cast back to Sup3rX object.""" + """Override xr.Dataset.isel to cast back to Sup3rX object.""" return type(self)(self._ds.isel(*args, **kwargs)) + def coarsen(self, *args, **kwargs): + """Override xr.Dataset.coarsen to cast back to Sup3rX object.""" + return type(self)(self._ds.coarsen(*args, **kwargs)) + @property def dims(self): """Return dims with our own enforced ordering.""" @@ -309,47 +379,6 @@ def _parse_keys(self, keys): dim_keys = _parse_ellipsis(dim_keys, dim_num=len(self._ds.dims)) return features, dict(zip(ordered_dims(self._ds.dims), dim_keys)) - def __getitem__(self, keys) -> Union[T_Array, Self]: - """Method for accessing variables or attributes. keys can optionally - include a feature name or list of feature names as the first entry of a - keys tuple. 
When keys take the form of numpy style indexing we return a - dask or numpy array, depending on whether contained data has been - loaded into memory, otherwise we return xarray or Sup3rX objects""" - features, slices = self._parse_keys(keys) - out = self._ds[features] - slices = {k: v for k, v in slices.items() if k in out.dims} - if self._needs_fancy_indexing(slices.values()): - out = self.as_array(data=out, features=features) - return out.vindex[*slices.values()] - - out = out.isel(**slices) - # numpy style indexing requested so we return an array (dask or np) - if isinstance(keys, (slice, tuple)) or _contains_ellipsis(keys): - return self.as_array(data=out, features=features) - if isinstance(out, xr.Dataset): - return type(self)(out) - return out.transpose(*ordered_dims(out.dims), ...) - - def __contains__(self, vals): - """Check if self._ds contains `vals`. - - Parameters - ---------- - vals : str | list - Values to check. Can be a list of strings or a single string. - - Examples - -------- - bool(['u', 'v'] in self) - bool('u' in self) - """ - feature_check = isinstance(vals, (list, tuple)) and all( - isinstance(s, str) for s in vals - ) - if feature_check: - return all(s.lower() in self._ds for s in vals) - return self._ds.__contains__(vals) - def _add_dims_to_data_dict(self, vals): """Add dimensions to vals entries if needed. This is used to set values of `self._ds` which can require dimensions to be explicitly specified @@ -415,34 +444,6 @@ def assign(self, vals: Dict[str, Union[T_Array, tuple]]): self._ds = self._ds.assign(data_dict) return type(self)(self._ds) - def __setitem__(self, keys, data): - """ - Parameters - ---------- - keys : str | list | tuple - keys to set. This can be a string like 'temperature' or a list - like ['u', 'v']. `data` will be iterated over in the latter case. - data : T_Array | xr.DataArray - array object used to set variable data. If `variable` is a list - then this is expected to have a trailing dimension with length - equal to the length of the list. - """ - if _is_strings(keys): - if isinstance(keys, (list, tuple)): - data_dict = {v: data[..., i] for i, v in enumerate(keys)} - else: - data_dict = {keys.lower(): data} - _ = self.assign(data_dict) - elif isinstance(keys[0], str) and keys[0] not in self.coords: - feats, slices = self._parse_keys(keys) - var_array = self[feats].data - var_array[*slices.values()] = data - _ = self.assign({feats: var_array}) - else: - msg = f'Cannot set values for keys {keys}' - logger.error(msg) - raise KeyError(msg) - @property def features(self): """Features in this container.""" diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index f3b2887906..319458ee85 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -21,6 +21,15 @@ logger = logging.getLogger(__name__) +def _get_class_info(namespace): + sig_objs = namespace.get('_signature_objs', None) + skips = namespace.get('_skip_params', None) + _sig = _doc = None + if sig_objs: + _sig, _doc = composite_info(sig_objs, skip_params=skips) + return _sig, _doc + + class Sup3rMeta(ABCMeta, type): """Meta class to define __name__, __signature__, and __subclasscheck__ of composite and derived classes. 
This allows us to still resolve a signature
@@ -29,15 +38,13 @@ class Sup3rMeta(ABCMeta, type):
 
     def __new__(mcs, name, bases, namespace, **kwargs):  # noqa: N804
         """Define __name__ and __signature__"""
-        name = namespace.get('__name__', name)
-        sig_objs = namespace.get('_signature_objs', None)
-        skips = namespace.get('_skip_params', None)
-        if sig_objs:
-            _sig, _doc = composite_info(sig_objs, skip_params=skips)
+        _sig, _doc = _get_class_info(namespace)
+        if _sig:
             namespace['__signature__'] = _sig
-            if '__init__' in namespace:
-                namespace['__init__'].__signature__ = _sig
-                namespace['__init__'].__doc__ = _doc
+        if '__init__' in namespace and _sig:
+            namespace['__init__'].__signature__ = _sig
+        if '__init__' in namespace and _doc:
+            namespace['__init__'].__doc__ = _doc
         return super().__new__(mcs, name, bases, namespace, **kwargs)
 
     def __subclasscheck__(cls, subclass):
diff --git a/sup3r/preprocessing/data_handlers/exo/base.py b/sup3r/preprocessing/data_handlers/exo.py
similarity index 50%
rename from sup3r/preprocessing/data_handlers/exo/base.py
rename to sup3r/preprocessing/data_handlers/exo.py
index 2f3d4ce975..5aa4ee14fe 100644
--- a/sup3r/preprocessing/data_handlers/exo/base.py
+++ b/sup3r/preprocessing/data_handlers/exo.py
@@ -1,7 +1,16 @@
-"""Base exogenous data wrangling classes.
-"""
+"""Exogenous data handler. This performs exo extraction for one or more model
+steps for requested features."""
 
 import logging
+import pathlib
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+
+from sup3r.preprocessing.names import Dimension
+from sup3r.preprocessing.rasterizers import ExoRasterizer
+from sup3r.preprocessing.utilities import log_args
 
 logger = logging.getLogger(__name__)
@@ -238,3 +247,205 @@ def get_chunk(self, input_data_shape, lr_slices):
             }
             exo_chunk[feature]['steps'].append(chunk_step)
         return exo_chunk
+
+
+@dataclass
+class ExoDataHandler:
+    """Class to extract exogenous features for multistep forward passes. e.g.
+    Multiple topography arrays at different resolutions for multiple spatial
+    enhancement steps.
+
+    This takes a list of models and information about model steps and uses
+    that info to compute needed enhancement factors for each step and extract
+    exo data corresponding to those enhancement factors. The list of steps is
+    then updated with the exo data for each step.
+
+    Parameters
+    ----------
+    file_paths : str | list
+        A single source h5 file or netcdf file to extract raster data from.
+        The string can be a unix-style file path which will be passed
+        through glob.glob. This is typically low-res WRF output or GCM
+        netcdf data, i.e. low-resolution source data intended to be sup3r
+        resolved.
+    feature : str
+        Exogenous feature to extract from file_paths
+    models : list
+        List of models used with the given steps list. This list of models is
+        used to determine the input and output resolution and enhancement
+        factors for each model step, which is then used to determine the
+        target shape for rasterized exo data. If enhancement factors are
+        provided in the steps list the model list is not needed.
+    steps : list
+        List of dictionaries containing info on which models to use for a
+        given step index and what type of exo data the step requires. e.g.
+        [{'model': 0, 'combine_type': 'input'},
+         {'model': 0, 'combine_type': 'layer'}]
+        Each step entry can also contain enhancement factors. e.g.
+        [{'model': 0, 'combine_type': 'input', 's_enhance': 1,
+          't_enhance': 1},
+         {'model': 0, 'combine_type': 'layer', 's_enhance': 3,
+          't_enhance': 1}]
+    source_file : str
+        Filepath to source wtk, nsrdb, or netcdf file to get hi-res data
+        from which will be mapped to the enhanced grid of the file_paths
+        input. Pixels from this file will be mapped to their nearest
+        low-res pixel in the file_paths input. Accordingly, the input
+        should be a significantly higher resolution than file_paths.
+        Warnings will be raised if the low-resolution pixels in file_paths
+        do not have unique nearest pixels from this exo source data.
+    input_handler_name : str
+        Data handler class used by the exo handler. Provide a string name to
+        match a :class:`~sup3r.preprocessing.rasterizers.Rasterizer`. If None
+        the correct handler will be guessed based on file type and time
+        series properties. This is passed directly to the exo handler, along
+        with input_handler_kwargs.
+    input_handler_kwargs : dict | None
+        Any kwargs for initializing the `input_handler_name` class used by
+        the exo handler.
+    cache_dir : str | None
+        Directory for storing cache data. Default is './exo_cache'. If None
+        then no data will be cached.
+    distance_upper_bound : float | None
+        Maximum distance to map high-resolution data from source_file to the
+        low-resolution file_paths input. None (default) will calculate this
+        based on the median distance between points in source_file.
+    """
+
+    file_paths: Union[str, list, pathlib.Path]
+    feature: str
+    steps: List[dict]
+    models: Optional[list] = None
+    source_file: Optional[str] = None
+    input_handler_name: Optional[str] = None
+    input_handler_kwargs: Optional[dict] = None
+    cache_dir: str = './exo_cache'
+    distance_upper_bound: Optional[float] = None
+
+    @log_args
+    def __post_init__(self):
+        """Initialize `self.data`, perform checks on enhancement factors,
+        and update `self.data` for each model step with rasterized exo data
+        for the corresponding enhancement factors."""
+        self.data = {self.feature: {'steps': []}}
+        en_check = all('s_enhance' in v for v in self.steps)
+        en_check = en_check and all('t_enhance' in v for v in self.steps)
+        en_check = en_check or self.models is not None
+        msg = (
+            f'{self.__class__.__name__} needs s_enhance and t_enhance '
+            'provided in each step in steps list or models'
+        )
+        assert en_check, msg
+        self.s_enhancements, self.t_enhancements = self._get_all_enhancement()
+        msg = (
+            'Need to provide s_enhance and t_enhance for each model '
+            'step. If the step is temporal only (spatial only) then '
+            's_enhance = 1 (t_enhance = 1).'
+        )
+        assert not any(s is None for s in self.s_enhancements), msg
+        assert not any(t is None for t in self.t_enhancements), msg
+
+        self.get_all_step_data()
+
+    def get_all_step_data(self):
+        """Get exo data for each model step. We get the maximally enhanced
+        exo data and then coarsen this to get the exo data for each
+        enhancement step. We get coarsen factors by iterating through
+        enhancement factors in reverse.
+        """
+        hr_exo = ExoRasterizer(
+            file_paths=self.file_paths,
+            source_file=self.source_file,
+            feature=self.feature,
+            s_enhance=self.s_enhancements[-1],
+            t_enhance=self.t_enhancements[-1],
+            input_handler_name=self.input_handler_name,
+            input_handler_kwargs=self.input_handler_kwargs,
+            cache_dir=self.cache_dir,
+            distance_upper_bound=self.distance_upper_bound,
+        )
+        for i, (s_coarsen, t_coarsen) in enumerate(
+            zip(self.s_enhancements[::-1], self.t_enhancements[::-1])
+        ):
+            coarsen_kwargs = dict(
+                zip(Dimension.dims_3d(), [s_coarsen, s_coarsen, t_coarsen])
+            )
+            step = SingleExoDataStep(
+                self.feature,
+                self.steps[i]['combine_type'],
+                self.steps[i]['model'],
+                data=hr_exo.data.coarsen(**coarsen_kwargs).mean().as_array(),
+            )
+            self.data[self.feature]['steps'].append(step)
+        shapes = [
+            None if step is None else step.shape
+            for step in self.data[self.feature]['steps']
+        ]
+        logger.info(
+            'Got exogenous_data of length {} with shapes: {}'.format(
+                len(self.data[self.feature]['steps']), shapes
+            )
+        )
+
+    def _get_single_step_enhance(self, step):
+        """Get enhancement factors for a single model step, using the
+        exo_kwargs entry for that step. These factors are computed using
+        stored enhance attributes of each model and the model step provided.
+        If enhancement factors are already provided in step they are not
+        overwritten.
+
+        Parameters
+        ----------
+        step : dict
+            Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'}
+
+        Returns
+        -------
+        updated_step : dict
+            Same as input dictionary with s_enhance, t_enhance added
+        """
+        if all(key in step for key in ['s_enhance', 't_enhance']):
+            return step
+
+        model_step = step['model']
+        combine_type = step.get('combine_type', None)
+        msg = (
+            f'Model index from exo_kwargs ({model_step}) exceeds number '
+            f'of model steps ({len(self.models)})'
+        )
+        assert len(self.models) > model_step, msg
+        msg = (
+            'Received exo_kwargs entry without valid combine_type '
+            '(input/layer/output)'
+        )
+        assert combine_type.lower() in ('input', 'output', 'layer'), msg
+        s_enhancements = [model.s_enhance for model in self.models]
+        t_enhancements = [model.t_enhance for model in self.models]
+        if combine_type.lower() == 'input':
+            if model_step == 0:
+                s_enhance = 1
+                t_enhance = 1
+            else:
+                s_enhance = np.prod(s_enhancements[:model_step])
+                t_enhance = np.prod(t_enhancements[:model_step])
+        else:
+            s_enhance = np.prod(s_enhancements[: model_step + 1])
+            t_enhance = np.prod(t_enhancements[: model_step + 1])
+        step.update({'s_enhance': s_enhance, 't_enhance': t_enhance})
+        return step
+
+    def _get_all_enhancement(self):
+        """Compute enhancement factors for all model steps for all features.
+ + Returns + ------- + s_enhancements: list + List of s_enhance factors for all model steps + t_enhancements: list + List of t_enhance factors for all model steps + """ + for i, step in enumerate(self.steps): + out = self._get_single_step_enhance(step) + self.steps[i] = out + s_enhancements = [step['s_enhance'] for step in self.steps] + t_enhancements = [step['t_enhance'] for step in self.steps] + return s_enhancements, t_enhancements diff --git a/sup3r/preprocessing/data_handlers/exo/__init__.py b/sup3r/preprocessing/data_handlers/exo/__init__.py deleted file mode 100644 index 20c826c9b7..0000000000 --- a/sup3r/preprocessing/data_handlers/exo/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Exo data handler module.""" -from .base import ExoData, SingleExoDataStep -from .exo import ExoDataHandler diff --git a/sup3r/preprocessing/data_handlers/exo/exo.py b/sup3r/preprocessing/data_handlers/exo/exo.py deleted file mode 100644 index f7271e4f1f..0000000000 --- a/sup3r/preprocessing/data_handlers/exo/exo.py +++ /dev/null @@ -1,247 +0,0 @@ -"""Exogenous data handler. This performs exo extraction for one or more model -steps for requested features. - -TODO: More cleaning. This does not yet fit the new style of composition and -lazy loading. -""" - -import logging -import pathlib -from dataclasses import dataclass -from inspect import signature -from typing import List, Optional, Union - -import numpy as np - -from sup3r.preprocessing.rasterizers import ExoRasterizer -from sup3r.preprocessing.utilities import log_args - -from .base import SingleExoDataStep - -logger = logging.getLogger(__name__) - - -@dataclass -class ExoDataHandler: - """Class to extract exogenous features for multistep forward passes. e.g. - Multiple topography arrays at different resolutions for multiple spatial - enhancement steps. - - This takes a list of models and information about model - steps and uses that info to compute needed enhancement factors for each - step and extract exo data corresponding to those enhancement factors. The - list of steps are then updated with the exo data for each step. - - Parameters - ---------- - file_paths : str | list - A single source h5 file or netcdf file to extract raster data from. - The string can be a unix-style file path which will be passed - through glob.glob. This is typically low-res WRF output or GCM - netcdf data that is source low-resolution data intended to be - sup3r resolved. - feature : str - Exogenous feature to extract from file_paths - models : list - List of models used with the given steps list. This list of models is - used to determine the input and output resolution and enhancement - factors for each model step which is then used to determine the target - shape for rasterized exo data. If enhancement factors are provided in - the steps list the model list is not needed. - steps : list - List of dictionaries containing info on which models to use for a - given step index and what type of exo data the step requires. e.g. - [{'model': 0, 'combine_type': 'input'}, - {'model': 0, 'combine_type': 'layer'}] - Each step entry can also contain enhancement factors. e.g. - [{'model': 0, 'combine_type': 'input', 's_enhance': 1, 't_enhance': 1}, - {'model': 0, 'combine_type': 'layer', 's_enhance': 3, 't_enhance': 1}] - source_file : str - Filepath to source wtk, nsrdb, or netcdf file to get hi-res data - from which will be mapped to the enhanced grid of the file_paths - input. Pixels from this file will be mapped to their nearest - low-res pixel in the file_paths input. 
Accordingly, the input - should be a significantly higher resolution than file_paths. - Warnings will be raised if the low-resolution pixels in file_paths - do not have unique nearest pixels from this exo source data. - input_handler_name : str - data handler class used by the exo handler. Provide a string name to - match a :class:`Rasterizer`. If None the correct handler will - be guessed based on file type and time series properties. This is - passed directly to the exo handler, along with input_handler_kwargs - input_handler_kwargs : dict | None - Any kwargs for initializing the `input_handler_name` class used by the - exo handler. - cache_dir : str | None - Directory for storing cache data. Default is './exo_cache'. If None - then no data will be cached. - """ - - file_paths: Union[str, list, pathlib.Path] - feature: str - steps: List[dict] - models: Optional[list] = None - source_file: Optional[str] = None - input_handler_name: Optional[str] = None - input_handler_kwargs: Optional[dict] = None - cache_dir: str = './exo_cache' - - @log_args - def __post_init__(self): - """Initialize `self.data`, perform checks on enhancement factors, and - update `self.data` for each model step with rasterized exo data for the - corresponding enhancement factors.""" - self.data = {self.feature: {'steps': []}} - en_check = all('s_enhance' in v for v in self.steps) - en_check = en_check and all('t_enhance' in v for v in self.steps) - en_check = en_check or self.models is not None - msg = ( - f'{self.__class__.__name__} needs s_enhance and t_enhance ' - 'provided in each step in steps list or models' - ) - assert en_check, msg - self.s_enhancements, self.t_enhancements = self._get_all_enhancement() - msg = ( - 'Need to provide s_enhance and t_enhance for each model' - 'step. If the step is temporal only (spatial only) then ' - 's_enhance = 1 (t_enhance = 1).' - ) - assert not any(s is None for s in self.s_enhancements), msg - assert not any(t is None for t in self.t_enhancements), msg - - self.get_all_step_data() - - def get_all_step_data(self): - """Get exo data for each model step. - - TODO: I think this could be simplified by getting the highest res data - first and then calling the xr.Dataset.coarsen() method according to - enhancement factors for different steps. - - """ - for i, (s_enhance, t_enhance) in enumerate( - zip(self.s_enhancements, self.t_enhancements) - ): - data = self.get_single_step_data( - feature=self.feature, - s_enhance=s_enhance, - t_enhance=t_enhance, - ).as_array() - step = SingleExoDataStep( - self.feature, - self.steps[i]['combine_type'], - self.steps[i]['model'], - data, - ) - self.data[self.feature]['steps'].append(step) - shapes = [ - None if step is None else step.shape - for step in self.data[self.feature]['steps'] - ] - logger.info( - 'Got exogenous_data of length {} with shapes: {}'.format( - len(self.data[self.feature]['steps']), shapes - ) - ) - - def _get_single_step_enhance(self, step): - """Get enhancement factors for exogenous data extraction - using exo_kwargs single model step. These factors are computed using - stored enhance attributes of each model and the model step provided. - If enhancement factors are already provided in step they are not - overwritten. - - Parameters - ---------- - step : dict - Model step dictionary. e.g. 
{'model': 0, 'combine_type': 'input'} - - Returns - ------- - updated_step : dict - Same as input dictionary with s_enhance, t_enhance added - """ - if all(key in step for key in ['s_enhance', 't_enhance']): - return step - - model_step = step['model'] - combine_type = step.get('combine_type', None) - msg = ( - f'Model index from exo_kwargs ({model_step} exceeds number ' - f'of model steps ({len(self.models)})' - ) - assert len(self.models) > model_step, msg - msg = ( - 'Received exo_kwargs entry without valid combine_type ' - '(input/layer/output)' - ) - assert combine_type.lower() in ('input', 'output', 'layer'), msg - s_enhancements = [model.s_enhance for model in self.models] - t_enhancements = [model.t_enhance for model in self.models] - if combine_type.lower() == 'input': - if model_step == 0: - s_enhance = 1 - t_enhance = 1 - else: - s_enhance = np.prod(s_enhancements[:model_step]) - t_enhance = np.prod(t_enhancements[:model_step]) - - else: - s_enhance = np.prod(s_enhancements[: model_step + 1]) - t_enhance = np.prod(t_enhancements[: model_step + 1]) - step.update({'s_enhance': s_enhance, 't_enhance': t_enhance}) - return step - - def _get_all_enhancement(self): - """Compute enhancement factors for all model steps for all features. - - Returns - ------- - s_enhancements: list - List of s_enhance factors for all model steps - t_enhancements: list - List of t_enhance factors for all model steps - """ - for i, step in enumerate(self.steps): - out = self._get_single_step_enhance(step) - self.steps[i] = out - s_enhancements = [step['s_enhance'] for step in self.steps] - t_enhancements = [step['t_enhance'] for step in self.steps] - return s_enhancements, t_enhancements - - def get_single_step_data(self, feature, s_enhance, t_enhance): - """Get the exogenous topography data - - Parameters - ---------- - feature : str - Name of feature to get exo data for - s_enhance : int - Spatial enhancement for this exogenous data step (cumulative for - all model steps up to the current step). - t_enhance : int - Temporal enhancement for this exogenous data step (cumulative for - all model steps up to the current step). - - Returns - ------- - data : Sup3rX - Sup3rX object containing exogenous data. `data.as_array()` gives - an array of shape (lats, lons, times, 1) - """ - - kwargs = { - 's_enhance': s_enhance, - 't_enhance': t_enhance, - 'feature': feature, - } - - params = signature(ExoRasterizer).parameters.values() - kwargs.update( - { - k.name: getattr(self, k.name) - for k in params - if hasattr(self, k.name) - } - ) - return ExoRasterizer(**kwargs).data diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 75fb5ee3ff..2117da5e71 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -97,12 +97,21 @@ def check_registry(self, feature) -> Union[T_Array, str, None]: if method is not None and hasattr(method, 'inputs'): fstruct = parse_feature(feature) inputs = [fstruct.map_wildcard(i) for i in method.inputs] - if all(f in self.data for f in inputs): + missing = [f for f in inputs if f not in self.data] + logger.debug('Found compute method (%s) for %s.', method, feature) + if any(missing): logger.debug( - f'Found compute method ({method}) for {feature}. ' - 'Proceeding with derivation.' + 'Missing required features %s. ' + 'Trying to derive these first.', + missing, ) - return self._run_compute(feature, method) + for f in missing: + self.data[f] = self.derive(f) + else: + logger.debug( + 'All required features %s found. 
Proceeding.', inputs + ) + return self._run_compute(feature, method) return None def _run_compute(self, feature, method): diff --git a/sup3r/preprocessing/loaders/h5.py b/sup3r/preprocessing/loaders/h5.py index 88edf4d937..92999e2af4 100644 --- a/sup3r/preprocessing/loaders/h5.py +++ b/sup3r/preprocessing/loaders/h5.py @@ -106,8 +106,8 @@ def _get_dset_tuple(self, dset, dims, chunks): warn(msg) arr_dims = Dimension.dims_4d_bc() else: - arr_dims = dims - return (arr_dims, arr, self.res.h5[dset].attrs) + arr_dims = dims[:len(arr.shape)] + return (arr_dims, arr, dict(self.res.h5[dset].attrs)) def _get_data_vars(self, dims): """Define data_vars dict for xr.Dataset construction.""" @@ -156,8 +156,12 @@ def load(self) -> xr.Dataset: """Wrap data in xarray.Dataset(). Handle differences with flattened and cached h5.""" dims = self._get_dims() - data_vars = self._get_data_vars(dims) coords = self._get_coords(dims) + data_vars = { + k: v + for k, v in self._get_data_vars(dims).items() + if k not in coords + } data_vars = {k: v for k, v in data_vars.items() if k not in coords} return xr.Dataset(coords=coords, data_vars=data_vars).astype( np.float32 diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 6634c93802..442b3ab142 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -20,12 +20,12 @@ from sup3r.postprocessing.writers.base import OutputHandler from sup3r.preprocessing.accessor import Sup3rX +from sup3r.preprocessing.base import Sup3rMeta from sup3r.preprocessing.loaders import Loader from sup3r.preprocessing.names import Dimension from sup3r.utilities.utilities import generate_random_string, nn_fill_array from ..utilities import ( - composite_info, get_class_kwargs, get_input_handler_class, get_source_type, @@ -39,9 +39,8 @@ class BaseExoRasterizer(ABC): """Class to extract high-res (4km+) data rasters for new spatially-enhanced datasets (e.g. GCM files after spatial enhancement) - using nearest neighbor mapping and aggregation from NREL datasets - (e.g. WTK or NSRDB) - + using nearest neighbor mapping and aggregation from NREL datasets (e.g. WTK + or NSRDB) Parameters ---------- @@ -382,7 +381,7 @@ def get_data(self): return Sup3rX(ds) -class ExoRasterizer: +class ExoRasterizer(BaseExoRasterizer, metaclass=Sup3rMeta): """Type agnostic `ExoRasterizer` class.""" TypeSpecificClasses: ClassVar = { @@ -405,4 +404,5 @@ def __new__(cls, file_paths, source_file, feature, **kwargs): ExoClass = cls.TypeSpecificClasses[get_source_type(source_file)] return ExoClass(**kwargs) - __signature__, __doc__ = composite_info(BaseExoRasterizer) + _signature_objs = (BaseExoRasterizer,) + __doc__ = BaseExoRasterizer.__doc__ diff --git a/tests/collections/test_stats.py b/tests/collections/test_stats.py index 0cd2c2fbaa..6fc5949be7 100644 --- a/tests/collections/test_stats.py +++ b/tests/collections/test_stats.py @@ -31,19 +31,14 @@ def test_stats_dual_data(): """Check accuracy of stats calcs across multiple containers with `type(self.data) == type(Sup3rDataset)` (e.g. 
a dual dataset).""" - dat = DummyData((10, 10, 100), ['windspeed', 'winddirection']) + feats = ['windspeed', 'winddirection'] + dat = DummyData((10, 10, 100), feats) dat.data = Sup3rDataset( low_res=Sup3rX(dat.data[0]._ds), high_res=Sup3rX(dat.data[0]._ds) ) - og_means = { - 'windspeed': np.nanmean(dat[..., 0]), - 'winddirection': np.nanmean(dat[..., 1]), - } - og_stds = { - 'windspeed': np.nanstd(dat[..., 0]), - 'winddirection': np.nanstd(dat[..., 1]), - } + og_means = {f: np.nanmean(dat[f]) for f in feats} + og_stds = {f: np.nanstd(dat[f]) for f in feats} direct_means = { 'windspeed': dat.data.mean( @@ -81,16 +76,11 @@ def test_stats_known(): """Check accuracy of stats calcs across multiple containers with known means / stds.""" - dat = DummyData((10, 10, 100), ['windspeed', 'winddirection']) + feats = ['windspeed', 'winddirection'] + dat = DummyData((10, 10, 100), feats) - og_means = { - 'windspeed': np.nanmean(dat[..., 0]), - 'winddirection': np.nanmean(dat[..., 1]), - } - og_stds = { - 'windspeed': np.nanstd(dat[..., 0]), - 'winddirection': np.nanstd(dat[..., 1]), - } + og_means = {f: np.nanmean(dat[f]) for f in feats} + og_stds = {f: np.nanstd(dat[f]) for f in feats} with TemporaryDirectory() as td: means = os.path.join(td, 'means.json') diff --git a/tests/data_handlers/test_dh_nc_cc.py b/tests/data_handlers/test_dh_nc_cc.py index 400b2282ac..64999c4cc8 100644 --- a/tests/data_handlers/test_dh_nc_cc.py +++ b/tests/data_handlers/test_dh_nc_cc.py @@ -168,10 +168,11 @@ def test_nc_cc_temp(): nc['ta'].attrs['units'] = 'K' nc = nc.swap_dims({'level': 'height'}) nc.to_netcdf(tmp_file) + + DataHandlerNCforCC.FEATURE_REGISTRY.update({'temperature': 'ta'}) dh = DataHandlerNCforCC( - tmp_file, features=['ta_100m', 'temperature_100m'] + tmp_file, features=['temperature_100m'] ) - assert dh['ta_100m'].attrs['units'] == 'C' assert dh['temperature_100m'].attrs['units'] == 'C' diff --git a/tests/data_wrapper/test_access.py b/tests/data_wrapper/test_access.py index a6efc2066e..f470653bc7 100644 --- a/tests/data_wrapper/test_access.py +++ b/tests/data_wrapper/test_access.py @@ -132,7 +132,7 @@ def test_change_values(): rand_u = RANDOM_GENERATOR.uniform(0, 20, data['u', ...].shape) data['u'] = rand_u - assert np.array_equal(rand_u, data['u', ...].compute()) + assert np.array_equal(rand_u, np.asarray(data['u', ...])) rand_v = RANDOM_GENERATOR.uniform(0, 10, data['v', ...].shape) data['v'] = rand_v @@ -140,7 +140,7 @@ def test_change_values(): data[['u', 'v']] = da.stack([rand_u, rand_v], axis=-1) assert np.array_equal( - data[['u', 'v']].as_array().data.compute(), + np.asarray(data[['u', 'v']].as_array()), da.stack([rand_u, rand_v], axis=-1).compute(), ) data['u', slice(0, 10)] = 0 diff --git a/tests/docs/test_doc_automation.py b/tests/docs/test_doc_automation.py index 0d1014cf0d..4b51adcf94 100644 --- a/tests/docs/test_doc_automation.py +++ b/tests/docs/test_doc_automation.py @@ -37,7 +37,9 @@ def test_full_docs(obj): """Make sure each arg in obj signature has an entry in the doc string.""" sig = signature(obj) - doc = NumpyDocString(obj.__init__.__doc__) + doc = obj.__init__.__doc__ + doc = doc if doc else obj.__doc__ + doc = NumpyDocString(doc) params = {p.name for p in sig.parameters.values()} doc_params = {p.name for p in doc['Parameters']} assert not params - doc_params
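
A minimal sketch of the accessor dunder methods consolidated above, assuming
a Sup3rX-wrapped xr.Dataset with 'u' and 'v' variables; the dimension names
and shapes are illustrative only, mirroring tests/data_wrapper/test_access.py:

    import numpy as np
    import xarray as xr

    from sup3r.preprocessing.accessor import Sup3rX

    dims = ('south_north', 'west_east', 'time')
    ds = xr.Dataset(
        data_vars={
            'u': (dims, np.random.rand(2, 2, 3)),
            'v': (dims, np.random.rand(2, 2, 3)),
        }
    )
    data = Sup3rX(ds)

    assert 'u' in data and ['u', 'v'] in data  # __contains__ checks features
    arr = data['u', ...]             # numpy-style keys return a dask/np array
    data['u'] = np.zeros(arr.shape)  # __setitem__ with a single feature name
    sub = data[['u', 'v']]           # a feature list returns another Sup3rX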
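
And a sketch of the relocated ExoDataHandler, called with explicit per-step
enhancement factors so that no models list is required; the file paths below
are hypothetical placeholders:

    from sup3r.preprocessing.data_handlers.exo import ExoDataHandler

    steps = [
        {'model': 0, 'combine_type': 'input', 's_enhance': 1,
         't_enhance': 1},
        {'model': 0, 'combine_type': 'layer', 's_enhance': 3,
         't_enhance': 1},
    ]
    handler = ExoDataHandler(
        file_paths='./gcm_data.nc',      # hypothetical low-res input
        feature='topography',
        steps=steps,
        source_file='./wtk_source.h5',   # hypothetical hi-res exo source
        cache_dir='./exo_cache',
    )
    # One SingleExoDataStep per entry in `steps`, coarsened from the
    # maximally enhanced raster via xr.Dataset.coarsen().mean()
    step_data = handler.data['topography']['steps']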