Commit 7c8fe53: era downloader test fixes
bnb32 committed Sep 16, 2024
1 parent 2aa353b commit 7c8fe53
Showing 4 changed files with 170 additions and 82 deletions.
11 changes: 10 additions & 1 deletion sup3r/preprocessing/cachers/base.py
@@ -12,6 +12,7 @@
import dask
import dask.array as da
import numpy as np
from warnings import warn

from sup3r.preprocessing.base import Container
from sup3r.preprocessing.names import Dimension
@@ -179,9 +180,17 @@ def get_chunksizes(cls, dset, data, chunks):
data_var = data_var.unify_chunks()
chunksizes = tuple(d[0] for d in data_var.chunksizes.values())
chunksizes = chunksizes if chunksizes else None
if chunksizes is not None:
chunkmem = np.prod(chunksizes) * data_var.dtype.itemsize / 1e9
if chunkmem > 4:
msg = (
'Chunks cannot be larger than 4GB. Given chunksizes %s '
'result in %sGB. Will use chunksizes = None')
logger.warning(msg, chunksizes, chunkmem)
warn(msg % (chunksizes, chunkmem))
chunksizes = None
return data_var, chunksizes

# pylint: disable=unused-argument
@classmethod
def write_h5(
cls,
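The new guard in get_chunksizes above caps any single netCDF chunk at 4 GB. Here is a minimal sketch of the size arithmetic, using a hypothetical float32 chunk shape that is not part of this commit:

import numpy as np

# Hypothetical chunk shape (time, south_north, west_east) and dtype.
chunksizes = (8760, 256, 256)
itemsize = np.dtype('float32').itemsize  # 4 bytes per element

# Same arithmetic as the new check: chunk size in GB.
chunkmem = np.prod(chunksizes) * itemsize / 1e9
print(f'{chunkmem:.2f} GB')  # ~2.30 GB, under the 4 GB cap

if chunkmem > 4:
    chunksizes = None  # fall back and let the writer pick chunk sizes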
193 changes: 117 additions & 76 deletions sup3r/utilities/era_downloader.py
@@ -13,12 +13,11 @@
from calendar import monthrange
from warnings import warn

import dask
import dask.array as da
import numpy as np
from rex import init_logger

from sup3r.preprocessing import Loader
from sup3r.preprocessing import Cacher, Loader
from sup3r.preprocessing.loaders.utilities import (
standardize_names,
standardize_values,
@@ -31,6 +30,8 @@
)
from sup3r.preprocessing.utilities import log_args

IGNORE_VARS = ('number', 'expver')

logger = logging.getLogger(__name__)


@@ -572,19 +573,23 @@ def run_month(
cls.make_monthly_file(year, month, monthly_file_pattern, variables)

@classmethod
def run_year(
def run(
cls,
year,
area,
levels,
monthly_file_pattern,
yearly_file=None,
yearly_file_pattern=None,
months=None,
overwrite=False,
max_workers=None,
variables=None,
product_type='reanalysis',
chunks='auto',
combine_all_files=False,
res_kwargs=None,
):
"""Run routine for all months in the requested year.
"""Run routine for all requested months in the requested year.
Parameters
----------
@@ -595,7 +600,7 @@ def run_year(
[max_lat, min_lon, min_lat, max_lon]
levels : list
List of pressure levels to download.
monthly_file_pattern : str
file_pattern : str
Pattern for combined monthly output file. Must include year and
month format keys. e.g. 'era5_{year}_{month}_combined.nc'
yearly_file : str
@@ -611,83 +616,108 @@
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'
combine_all_files : bool
Whether to combine separate yearly variable files into a single
yearly file with all variables included
"""
for var in variables:
cls.run_for_var(
year=year,
area=area,
levels=levels,
months=months,
monthly_file_pattern=monthly_file_pattern,
yearly_file_pattern=yearly_file_pattern,
overwrite=overwrite,
variable=var,
product_type=product_type,
max_workers=max_workers,
chunks=chunks,
res_kwargs=res_kwargs,
)

if (
yearly_file is not None
and os.path.exists(yearly_file)
and not overwrite
cls.all_vars_exist(
year=year,
file_pattern=yearly_file_pattern,
variables=variables,
)
and combine_all_files
):
logger.info('%s already exists and overwrite=False.', yearly_file)
msg = (
'monthly_file_pattern must have {year}, {month}, and {var} '
'format keys'
)
assert all(
key in monthly_file_pattern
for key in ('{year}', '{month}', '{var}')
), msg

tasks = []
for month in range(1, 13):
for var in variables:
task = dask.delayed(cls.run_month)(
year=year,
month=month,
area=area,
levels=levels,
monthly_file_pattern=monthly_file_pattern,
overwrite=overwrite,
variables=[var],
product_type=product_type,
)
tasks.append(task)

if max_workers == 1:
dask.compute(*tasks, scheduler='single-threaded')
else:
dask.compute(*tasks, scheduler='threads', num_workers=max_workers)

if yearly_file is not None:
cls.make_yearly_file(year, monthly_file_pattern, yearly_file)
cls.make_yearly_file(
year,
yearly_file_pattern,
variables,
chunks=chunks,
res_kwargs=res_kwargs,
)
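For readers unfamiliar with the fan-out used for the monthly downloads, here is a self-contained sketch of the dask.delayed pattern, with download_one as a hypothetical stand-in for the real per-month routine:

import dask

def download_one(month, var):
    # Hypothetical stand-in for the per-month download routine.
    return f'era5_2000_{month:02d}_{var}_combined.nc'

# One lazy task per (month, variable) pair.
tasks = [
    dask.delayed(download_one)(month, var)
    for month in range(1, 13)
    for var in ('u_wind', 'v_wind')
]

max_workers = 4
if max_workers == 1:
    results = dask.compute(*tasks, scheduler='single-threaded')
else:
    results = dask.compute(
        *tasks, scheduler='threads', num_workers=max_workers
    )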

@classmethod
def make_monthly_file(cls, year, month, file_pattern, variables):
"""Combine monthly variable files into a single monthly file.
def make_yearly_var_file(
cls,
year,
monthly_file_pattern,
yearly_file_pattern,
variable,
chunks='auto',
res_kwargs=None,
):
"""Combine monthly variable files into a single yearly variable file.
Parameters
----------
year : int
Year used to download data
month : int
Month used to download data
file_pattern : str
monthly_file_pattern : str
File pattern for monthly variable files. Must have year, month, and
var format keys. e.g. './era_{year}_{month}_{var}_combined.nc'
variables : list
List of variables downloaded.
yearly_file_pattern : str
File pattern for yearly variable files. Must have year and var
format keys. e.g. './era_{year}_{var}_combined.nc'
variable : string
Variable name for the files to be combined.
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
res_kwargs : None | dict
Keyword arguments for base resource handler, like
``xr.open_mfdataset``. This is passed to a ``Loader`` object and
then used in the base loader contained by that object.
"""
msg = (
f'Not all variable files with file_pattern {file_pattern} for '
f'year {year} and month {month} exist.'
)
assert cls.all_vars_exist(year, month, file_pattern, variables), msg

files = [
file_pattern.format(year=year, month=str(month).zfill(2), var=var)
for var in variables
monthly_file_pattern.format(
year=year, month=str(month).zfill(2), var=variable
)
for month in range(1, 13)
]

outfile = file_pattern.replace('_{var}', '').format(
year=year, month=str(month).zfill(2)
outfile = yearly_file_pattern.format(year=year, var=variable)
cls._combine_files(
files, outfile, chunks=chunks, res_kwargs=res_kwargs
)
cls._combine_files(files, outfile)

@classmethod
def _combine_files(cls, files, outfile, kwargs=None):
def _combine_files(cls, files, outfile, chunks='auto', res_kwargs=None):
if not os.path.exists(outfile):
logger.info(f'Combining {files} into {outfile}.')
try:
cls._write_dsets(files, out_file=outfile, kwargs=kwargs)
res_kwargs = res_kwargs or {}
loader = Loader(files, res_kwargs=res_kwargs)
tmp_file = cls.get_tmp_file(outfile)
for ignore_var in IGNORE_VARS:
if ignore_var in loader.coords:
loader.data = loader.data.drop_vars(ignore_var)
Cacher.write_netcdf(
data=loader.data,
out_file=tmp_file,
max_workers=1,
chunks=chunks,
)
os.replace(tmp_file, outfile)
logger.info('Moved %s to %s.', tmp_file, outfile)
except Exception as e:
msg = f'Error combining {files}.'
logger.error(msg)
@@ -696,33 +726,44 @@ def _combine_files(cls, files, outfile, kwargs=None):
logger.info(f'{outfile} already exists.')
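Two details of the rewritten _combine_files are worth illustrating: ERA5 files served by the CDS API can carry scalar 'number' (ensemble member) and 'expver' (experiment version) coordinates that interfere with merging, so they are dropped, and output is written to a temporary file and renamed so a failed write never leaves a truncated outfile. A standalone sketch with a toy dataset, not real ERA5 data:

import os
import xarray as xr

IGNORE_VARS = ('number', 'expver')

# Toy dataset standing in for the loaded ERA5 data.
ds = xr.Dataset(
    data_vars={'u': ('time', [1.0, 2.0])},
    coords={'time': [0, 1], 'number': 0, 'expver': 1},
)

# Drop the coordinates that interfere with concatenation.
for ignore_var in IGNORE_VARS:
    if ignore_var in ds.coords:
        ds = ds.drop_vars(ignore_var)

# Write to a temp file first, then rename; os.replace is atomic on POSIX.
outfile = 'combined.nc'
tmp_file = outfile + '.tmp'
ds.to_netcdf(tmp_file)
os.replace(tmp_file, outfile)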

@classmethod
def make_yearly_file(cls, year, file_pattern, yearly_file):
"""Combine monthly files into a single file.
def make_yearly_file(
cls, year, file_pattern, variables, chunks='auto', res_kwargs=None
):
"""Combine yearly variable files into a single file.
Parameters
----------
year : int
Year of monthly data to make into a yearly file.
Year for the data to make into a yearly file.
file_pattern : str
File pattern for monthly files. Must have year and month format
keys. e.g. './era_uv_{year}_{month}_combined.nc'
yearly_file : str
Name of yearly file made from monthly files.
File pattern for output files. Must have year and var
format keys. e.g. './era_{year}_{var}_combined.nc'
variables : list
List of variables corresponding to the yearly variable files to
combine.
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
res_kwargs : None | dict
Keyword arguments for base resource handler, like
``xr.open_mfdataset``. This is passed to a ``Loader`` object and
then used in the base loader contained by that object.
"""
msg = (
f'Not all monthly files with file_pattern {file_pattern} for '
f'year {year} exist.'
)
assert cls.all_months_exist(year, file_pattern), msg

files = [
file_pattern.replace('_{var}', '').format(
year=year, month=str(month).zfill(2)
)
for month in range(1, 13)
]
kwargs = {'combine': 'nested', 'concat_dim': 'time'}
cls._combine_files(files, yearly_file, kwargs)
files = [file_pattern.format(year=year, var=var) for var in variables]
yearly_file = (
file_pattern.replace('_{var}_', '')
.replace('_{var}', '')
.format(year=year)
)
cls._combine_files(
files, yearly_file, res_kwargs=res_kwargs, chunks=chunks
)
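The combined-file name is derived by stripping the '{var}' key from the yearly pattern before formatting, e.g. with the pattern from the docstring above:

# How make_yearly_file derives the combined-file name from the pattern.
file_pattern = './era_{year}_{var}_combined.nc'
yearly_file = (
    file_pattern.replace('_{var}_', '')
    .replace('_{var}', '')
    .format(year=2000)
)
print(yearly_file)  # './era_2000combined.nc'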

@classmethod
def run_qa(cls, file, res_kwargs=None, log_file=None):
31 changes: 31 additions & 0 deletions sup3r/utilities/utilities.py
@@ -13,11 +13,42 @@
from packaging import version
from scipy import ndimage as nd

from sup3r.preprocessing.utilities import get_class_kwargs

logger = logging.getLogger(__name__)

RANDOM_GENERATOR = np.random.default_rng(seed=42)
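The module-level generator above (unchanged context) is what keeps stochastic utilities reproducible; two generators built with the same seed produce identical streams:

import numpy as np

a = np.random.default_rng(seed=42)
b = np.random.default_rng(seed=42)
# Identical seeds give identical draws.
assert (a.uniform(size=3) == b.uniform(size=3)).all()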


def merge_datasets(files, **kwargs):
"""Merge xr.Datasets after some standardization. This useful when
xr.open_mfdatasets fails due to different time index formats or coordinate
names, for example."""
dsets = [xr.open_mfdataset(f, **kwargs) for f in files]
time_indices = []
for i, dset in enumerate(dsets):
if 'time' in dset and dset.time.size > 1:
ti = pd.DatetimeIndex(dset.time)
dset['time'] = ti
dsets[i] = dset
time_indices.append(ti.to_series())
if 'latitude' in dset.dims:
dset = dset.swap_dims({'latitude': 'south_north'})
dsets[i] = dset
if 'longitude' in dset.dims:
dset = dset.swap_dims({'longitude': 'west_east'})
dsets[i] = dset
out = xr.merge(dsets, **get_class_kwargs(xr.merge, kwargs))
msg = (
'Merged time index does not have the same number of time steps '
'(%s) as the sum of the individual time index steps (%s).'
)
merged_size = out.time.size
summed_size = pd.concat(time_indices).drop_duplicates().size
assert merged_size == summed_size, msg % (merged_size, summed_size)
return out
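A hypothetical usage sketch for the new helper (the file names are illustrative): keyword arguments go to the per-file xr.open_mfdataset calls, while get_class_kwargs filters out the subset that xr.merge accepts:

# Hypothetical monthly files whose dimension names may differ
# ('latitude'/'longitude' vs 'south_north'/'west_east').
files = ['era5_2000_01_u.nc', 'era5_2000_02_u.nc']
ds = merge_datasets(files, compat='override', engine='netcdf4')
print(ds.time.size)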


def xr_open_mfdataset(files, **kwargs):
"""Wrapper for xr.open_mfdataset with default opening options."""
default_kwargs = {'engine': 'netcdf4'}
17 changes: 12 additions & 5 deletions tests/utilities/test_era_downloader.py
@@ -67,7 +67,7 @@ def test_era_dl(tmpdir_factory):
month=month,
area=area,
levels=levels,
monthly_file_pattern=file_pattern,
file_pattern=file_pattern,
variables=variables,
)
for v in variables:
@@ -86,18 +86,25 @@ def test_era_dl_year(tmpdir_factory):
file_pattern = os.path.join(
tmpdir_factory.mktemp('tmp'), 'era5_{year}_{month}_{var}.nc'
)
yearly_file = os.path.join(tmpdir_factory.mktemp('tmp'), 'era5_final.nc')
EraDownloaderTester.run_year(
yearly_file_pattern = os.path.join(
tmpdir_factory.mktemp('tmp'), 'era5_{year}_{var}_final.nc'
)
EraDownloaderTester.run(
year=2000,
area=[50, -130, 23, -65],
levels=[1000, 900, 800],
variables=variables,
monthly_file_pattern=file_pattern,
yearly_file=yearly_file,
yearly_file_pattern=yearly_file_pattern,
max_workers=1,
combine_all_files=True,
res_kwargs={'compat': 'override', 'engine': 'netcdf4'},
)

tmp = xr_open_mfdataset(yearly_file)
combined_file = yearly_file_pattern.replace('_{var}_', '').format(
year=2000
)
tmp = xr_open_mfdataset(combined_file)
for v in variables:
standard_name = FEATURE_NAMES.get(v, v)
assert standard_name in tmp
