
Commit a27110e
added recursive derivation. simplified exo handling: get highest res and then use xr.Dataset().coarsen for each model step. moved that simplified code from exo dir contents back into exo.py
bnb32 committed Aug 5, 2024
1 parent c045b81 commit a27110e
Showing 12 changed files with 384 additions and 409 deletions.
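
The simplified exo handling summarized in the commit message can be sketched as follows: load the highest-resolution exogenous data once, then coarsen it to each model step with xarray's Dataset.coarsen. The dataset contents and enhancement factors below are illustrative stand-ins, not the actual sup3r exo API:

    import numpy as np
    import xarray as xr

    # Hypothetical highest-resolution exogenous field (e.g. topography).
    hi_res_exo = xr.Dataset(
        {'topography': (('south_north', 'west_east'), np.random.rand(40, 40))}
    )

    # One coarsening per model step, instead of per-step exo file handling.
    # The enhancement factors here are made up for illustration.
    for s_enhance in (1, 2, 4):
        step_exo = hi_res_exo.coarsen(
            south_north=s_enhance, west_east=s_enhance
        ).mean()
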
227 changes: 114 additions & 113 deletions sup3r/preprocessing/accessor.py
@@ -78,6 +78,104 @@ def __init__(self, ds: Union[xr.Dataset, Self]):
self._features = None
self.time_slice = None

def __getattr__(self, attr):
"""Get attribute and cast to type(self) if a xr.Dataset is returned
first."""
out = getattr(self._ds, attr)
return type(self)(out) if isinstance(out, xr.Dataset) else out

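As a quick illustration of the delegation above (a minimal sketch; the wrapped dataset is made up):

    import numpy as np
    import xarray as xr
    from sup3r.preprocessing.accessor import Sup3rX

    sup3r_ds = Sup3rX(xr.Dataset({'u': ('time', np.random.rand(8))}))
    print(sup3r_ds.sizes)  # plain attribute, delegated to xr.Dataset
    # any delegated attribute that is itself an xr.Dataset comes back
    # wrapped as Sup3rX
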
def __mul__(self, other):
"""Multiply Sup3rX object by other. Used to compute weighted means and
stdevs."""
try:
return type(self)(other * self._ds)
except Exception as e:
raise NotImplementedError(
f'Multiplication not supported for type {type(other)}.'
) from e

def __rmul__(self, other):
return self.__mul__(other)

def __pow__(self, other):
"""Raise Sup3rX object to an integer power. Used to compute weighted
standard deviations."""
try:
return type(self)(self._ds**other)
except Exception as e:
raise NotImplementedError(
f'Exponentiation not supported for type {type(other)}.'
) from e

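Together these overloads make weighted moments one-liners. A minimal sketch, assuming `weights` is a normalized xr.DataArray and reusing the `sup3r_ds` from the sketch above:

    weights = xr.DataArray(np.full(8, 1 / 8), dims='time')
    mean = (sup3r_ds * weights).sum()              # via __mul__
    second_moment = (sup3r_ds**2 * weights).sum()  # via __pow__ then __mul__
    # variance = second_moment - mean**2; std = variance**0.5
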
def __setitem__(self, keys, data):
"""
Parameters
----------
keys : str | list | tuple
Keys to set. This can be a string like 'temperature' or a list
like ['u', 'v']. `data` will be iterated over in the latter case.
data : T_Array | xr.DataArray
Array object used to set variable data. If `keys` is a list
then this is expected to have a trailing dimension with length
equal to the length of the list.
"""
if _is_strings(keys):
if isinstance(keys, (list, tuple)):
data_dict = {v: data[..., i] for i, v in enumerate(keys)}
else:
data_dict = {keys.lower(): data}
_ = self.assign(data_dict)
elif isinstance(keys[0], str) and keys[0] not in self.coords:
feats, slices = self._parse_keys(keys)
var_array = self[feats].data
var_array[tuple(slices.values())] = data
_ = self.assign({feats: var_array})
else:
msg = f'Cannot set values for keys {keys}'
logger.error(msg)
raise KeyError(msg)

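Illustrative writes this supports (array names and shapes are assumed, not from the source):

    sup3r_ds['temperature'] = temp_array  # single feature, key lowercased
    sup3r_ds[['u', 'v']] = uv_array       # trailing dim of length 2, split per key
    sup3r_ds['u', slice(0, 10)] = patch   # feature plus numpy-style region
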
def __getitem__(self, keys) -> Union[T_Array, Self]:
"""Method for accessing variables or attributes. keys can optionally
include a feature name or list of feature names as the first entry of a
keys tuple. When keys take the form of numpy style indexing we return a
dask or numpy array, depending on whether contained data has been
loaded into memory, otherwise we return xarray or Sup3rX objects"""
features, slices = self._parse_keys(keys)
out = self._ds[features]
slices = {k: v for k, v in slices.items() if k in out.dims}
if self._needs_fancy_indexing(slices.values()):
out = self.as_array(data=out, features=features)
return out.vindex[tuple(slices.values())]

out = out.isel(**slices)
# numpy style indexing requested so we return an array (dask or np)
if isinstance(keys, (slice, tuple)) or _contains_ellipsis(keys):
return self.as_array(data=out, features=features)
if isinstance(out, xr.Dataset):
return type(self)(out)
return out.transpose(*ordered_dims(out.dims), ...)

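And the corresponding reads (again with assumed names):

    da = sup3r_ds['u']                # xr.DataArray, transposed to ordered dims
    sub = sup3r_ds[['u', 'v']]        # xr.Dataset, re-wrapped as Sup3rX
    arr = sup3r_ds['u', :10, ..., 0]  # numpy-style key -> dask or numpy array
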
def __contains__(self, vals):
"""Check if self._ds contains `vals`.
Parameters
----------
vals : str | list
Values to check. Can be a list of strings or a single string.
Examples
--------
bool(['u', 'v'] in self)
bool('u' in self)
"""
feature_check = isinstance(vals, (list, tuple)) and all(
isinstance(s, str) for s in vals
)
if feature_check:
return all(s.lower() in self._ds for s in vals)
return self._ds.__contains__(vals)

def compute(self, **kwargs):
"""Load `._ds` into memory. This updates the internal `xr.Dataset` if
it has not been loaded already."""
@@ -128,52 +226,20 @@ def update_ds(self, new_dset, attrs=None):
"""
coords = dict(self._ds.coords)
data_vars = dict(self._ds.data_vars)
coords.update(
{
k: dims_array_tuple(v)
for k, v in new_dset.items()
if k in coords
}
)
data_vars.update(
{
k: dims_array_tuple(v)
for k, v in new_dset.items()
if k not in coords
}
)
new_coords = {
k: dims_array_tuple(v) for k, v in new_dset.items() if k in coords
}
coords.update(new_coords)
new_data = {
k: dims_array_tuple(v)
for k, v in new_dset.items()
if k not in coords
}
data_vars.update(new_data)

self._ds = xr.Dataset(coords=coords, data_vars=data_vars, attrs=attrs)
return type(self)(self._ds)

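The refactor above only restructures how the updated coords and data_vars dicts are built; call sites are unchanged. A hedged usage example (`new_lat` and `new_temp` are illustrative arrays):

    # Entries whose keys match existing coordinates update coords; all
    # other entries become data variables.
    sup3r_ds = sup3r_ds.update_ds({'latitude': new_lat, 'temperature': new_temp})
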
def __getattr__(self, attr):
"""Get attribute and cast to type(self) if a xr.Dataset is returned
first."""
out = getattr(self._ds, attr)
return type(self)(out) if isinstance(out, xr.Dataset) else out

def __mul__(self, other):
"""Multiply Sup3rX object by other. Used to compute weighted means and
stdevs."""
try:
return type(self)(other * self._ds)
except Exception as e:
raise NotImplementedError(
f'Multiplication not supported for type {type(other)}.'
) from e

def __rmul__(self, other):
return self.__mul__(other)

def __pow__(self, other):
"""Raise Sup3rX object to an integer power. Used to compute weighted
standard deviations."""
try:
return type(self)(self._ds**other)
except Exception as e:
raise NotImplementedError(
f'Exponentiation not supported for type {type(other)}.'
) from e

@property
def name(self):
"""Name of dataset. Used to label datasets when grouped in
@@ -198,9 +264,13 @@ def name(self, value):
self._ds.attrs['name'] = value

def isel(self, *args, **kwargs):
"""Override xr.Dataset.sel to cast back to Sup3rX object."""
"""Override xr.Dataset.isel to cast back to Sup3rX object."""
return type(self)(self._ds.isel(*args, **kwargs))

def coarsen(self, *args, **kwargs):
"""Override xr.Dataset.coarsen to cast back to Sup3rX object."""
return type(self)(self._ds.coarsen(*args, **kwargs))

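The new coarsen override is what the simplified exo handling leans on: per-model-step downsampling without leaving the accessor. For instance (window sizes illustrative):

    coarse = sup3r_ds.coarsen(south_north=2, west_east=2).mean()
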
@property
def dims(self):
"""Return dims with our own enforced ordering."""
@@ -309,47 +379,6 @@ def _parse_keys(self, keys):
dim_keys = _parse_ellipsis(dim_keys, dim_num=len(self._ds.dims))
return features, dict(zip(ordered_dims(self._ds.dims), dim_keys))

def __getitem__(self, keys) -> Union[T_Array, Self]:
"""Method for accessing variables or attributes. keys can optionally
include a feature name or list of feature names as the first entry of a
keys tuple. When keys take the form of numpy style indexing we return a
dask or numpy array, depending on whether contained data has been
loaded into memory, otherwise we return xarray or Sup3rX objects"""
features, slices = self._parse_keys(keys)
out = self._ds[features]
slices = {k: v for k, v in slices.items() if k in out.dims}
if self._needs_fancy_indexing(slices.values()):
out = self.as_array(data=out, features=features)
return out.vindex[*slices.values()]

out = out.isel(**slices)
# numpy style indexing requested so we return an array (dask or np)
if isinstance(keys, (slice, tuple)) or _contains_ellipsis(keys):
return self.as_array(data=out, features=features)
if isinstance(out, xr.Dataset):
return type(self)(out)
return out.transpose(*ordered_dims(out.dims), ...)

def __contains__(self, vals):
"""Check if self._ds contains `vals`.
Parameters
----------
vals : str | list
Values to check. Can be a list of strings or a single string.
Examples
--------
bool(['u', 'v'] in self)
bool('u' in self)
"""
feature_check = isinstance(vals, (list, tuple)) and all(
isinstance(s, str) for s in vals
)
if feature_check:
return all(s.lower() in self._ds for s in vals)
return self._ds.__contains__(vals)

def _add_dims_to_data_dict(self, vals):
"""Add dimensions to vals entries if needed. This is used to set values
of `self._ds` which can require dimensions to be explicitly specified
@@ -415,34 +444,6 @@ def assign(self, vals: Dict[str, Union[T_Array, tuple]]):
self._ds = self._ds.assign(data_dict)
return type(self)(self._ds)

def __setitem__(self, keys, data):
"""
Parameters
----------
keys : str | list | tuple
Keys to set. This can be a string like 'temperature' or a list
like ['u', 'v']. `data` will be iterated over in the latter case.
data : T_Array | xr.DataArray
Array object used to set variable data. If `keys` is a list
then this is expected to have a trailing dimension with length
equal to the length of the list.
"""
if _is_strings(keys):
if isinstance(keys, (list, tuple)):
data_dict = {v: data[..., i] for i, v in enumerate(keys)}
else:
data_dict = {keys.lower(): data}
_ = self.assign(data_dict)
elif isinstance(keys[0], str) and keys[0] not in self.coords:
feats, slices = self._parse_keys(keys)
var_array = self[feats].data
var_array[*slices.values()] = data
_ = self.assign({feats: var_array})
else:
msg = f'Cannot set values for keys {keys}'
logger.error(msg)
raise KeyError(msg)

@property
def features(self):
"""Features in this container."""
23 changes: 15 additions & 8 deletions sup3r/preprocessing/base.py
@@ -21,6 +21,15 @@
logger = logging.getLogger(__name__)


def _get_class_info(namespace):
sig_objs = namespace.get('_signature_objs', None)
skips = namespace.get('_skip_params', None)
_sig = _doc = None
if sig_objs:
_sig, _doc = composite_info(sig_objs, skip_params=skips)
return _sig, _doc


class Sup3rMeta(ABCMeta, type):
"""Meta class to define __name__, __signature__, and __subclasscheck__ of
composite and derived classes. This allows us to still resolve a signature
@@ -29,15 +38,13 @@ class Sup3rMeta(ABCMeta, type):

def __new__(mcs, name, bases, namespace, **kwargs): # noqa: N804
"""Define __name__ and __signature__"""
name = namespace.get('__name__', name)
sig_objs = namespace.get('_signature_objs', None)
skips = namespace.get('_skip_params', None)
if sig_objs:
_sig, _doc = composite_info(sig_objs, skip_params=skips)
_sig, _doc = _get_class_info(namespace)
if _sig:
namespace['__signature__'] = _sig
if '__init__' in namespace:
namespace['__init__'].__signature__ = _sig
namespace['__init__'].__doc__ = _doc
if '__init__' in namespace and _sig:
namespace['__init__'].__signature__ = _sig
if '__init__' in namespace and _doc:
namespace['__init__'].__doc__ = _doc
return super().__new__(mcs, name, bases, namespace, **kwargs)

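A hedged sketch of how a class opts in to the composed signature (class and parameter names below are illustrative):

    class CompositeHandler(BaseA, BaseB, metaclass=Sup3rMeta):
        """Handler whose __signature__ and __init__ docs are composed
        by Sup3rMeta from the listed objects."""

        _signature_objs = (BaseA, BaseB)  # signatures to merge
        _skip_params = ('ds',)            # params dropped from the composite
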
def __subclasscheck__(cls, subclass):
