Skip to content

Commit

Permalink
ASTE Compatibility (#108)
Browse files Browse the repository at this point in the history
* grid object for SOSE ... much easier than llc

* this works for sose

* different way to do domain default

* grid would be a nice option here...

* an old stash commit

* aste grid definition

* kill unneeded tile connection

* central_longitude available for stereo plots

* fixed: pass grid through section calcs...

* auto colorbar label a la xarray style

* can supply coords separately, no grid to section trsp at depth

* typos

* more consistent separation of coords and ds, with tests

* typos

* typos are all I type

* more consistent coords parsing, new tests fixed

* ecco domain -> global ... generalize a bit read_mds...

* first round of updated tests

* trsp tests condensed, get basin tests for global llc90, copy broadcasting dimensions

* update common fixture

* squeeze

* slightly fewer tests ... getting ridiculous

* bump version

* only grab top level mask... whoops

* update tests accordingly

* get stf in there

* another test ... and more robust optional coords stuff

* updated tests for 3D mask...

* sose isnt even LLC
  • Loading branch information
timothyas committed Oct 31, 2020
1 parent c79ccd4 commit b69826c
Show file tree
Hide file tree
Showing 20 changed files with 587 additions and 482 deletions.
76 changes: 53 additions & 23 deletions ecco_v4_py/calc_meridional_trsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
RHO_CONST = 1029
HEAT_CAPACITY = 4000

def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None):
def calc_meridional_vol_trsp(ds,lat_vals,
basin_name=None,coords=None,grid=None):
"""Compute volumetric transport across latitude band in Sverdrups
Parameters
Expand All @@ -34,6 +35,9 @@ def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None):
denote ocean basin over which to compute streamfunction
If not specified, compute global quantity
see get_basin.get_available_basin_names for options
coords : xarray Dataset
separate dataset containing the coordinate information
YC, Z, drF, dyG, dxG, optionally maskW, maskS
grid : xgcm Grid object, optional
denotes LLC90 operations for xgcm, see ecco_utils.get_llc_grid
see also the [xgcm documentation](https://xgcm.readthedocs.io/en/latest/grid_topology.html)
Expand All @@ -52,12 +56,14 @@ def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None):
dimensions: 'time' (if provided), 'lat', and 'k'
"""

x_vol = ds['UVELMASS'] * ds['drF'] * ds['dyG']
y_vol = ds['VVELMASS'] * ds['drF'] * ds['dxG']
coords = _parse_coords(ds,coords,['Z','YC','drF','dyG','dxG'])

x_vol = ds['UVELMASS'] * coords['drF'] * coords['dyG']
y_vol = ds['VVELMASS'] * coords['drF'] * coords['dxG']

# Computes salt transport in m^3/s at each depth level
ds_out = meridional_trsp_at_depth(x_vol,y_vol,
cds=ds,
coords=coords,
lat_vals=lat_vals,
basin_name=basin_name,
grid=grid)
Expand All @@ -76,7 +82,8 @@ def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None):
return ds_out


def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None):
def calc_meridional_heat_trsp(ds,lat_vals,
basin_name=None,coords=None,grid=None):
"""Compute heat transport across latitude band in Petwatts
see calc_meridional_vol_trsp for argument documentation.
The only differences are:
Expand All @@ -85,6 +92,9 @@ def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None):
----------
ds : xarray Dataset
must contain fields 'ADVx_TH','ADVy_TH','DFxE_TH','DFyE_TH'
coords : xarray Dataset, optional
in case coordinates are in a separate dataset
only needs field 'YC' and optionally maskW, maskS
Returns
-------
Expand All @@ -99,12 +109,14 @@ def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None):
dimensions: 'time' (if provided), 'lat', and 'k'
"""

coords = _parse_coords(ds,coords,['Z','YC'])

x_heat = ds['ADVx_TH'] + ds['DFxE_TH']
y_heat = ds['ADVy_TH'] + ds['DFyE_TH']

# Computes heat transport in degC * m^3/s at each depth level
ds_out = meridional_trsp_at_depth(x_heat,y_heat,
cds=ds,
coords=coords,
lat_vals=lat_vals,
basin_name=basin_name,
grid=grid)
Expand All @@ -122,7 +134,8 @@ def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None):

return ds_out

def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None):
def calc_meridional_salt_trsp(ds,lat_vals,
basin_name=None,coords=None,grid=None):
"""Compute salt transport across latitude band in psu * Sv
see calc_meridional_vol_trsp for argument documentation.
The only differences are:
Expand All @@ -131,6 +144,9 @@ def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None):
----------
ds : xarray Dataset
must contain fields 'ADVx_SLT','ADVy_SLT','DFxE_SLT','DFyE_SLT'
coords : xarray Dataset, optional
in case coordinates are in a separate dataset
only needs field 'YC' and optionally maskW, maskS
Returns
-------
Expand All @@ -145,12 +161,14 @@ def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None):
dimensions: 'time' (if provided), 'lat', and 'k'
"""

coords = _parse_coords(ds,coords,['Z','YC'])

x_salt = ds['ADVx_SLT'] + ds['DFxE_SLT']
y_salt = ds['ADVy_SLT'] + ds['DFyE_SLT']

# Computes salt transport in psu * m^3/s at each depth level
ds_out = meridional_trsp_at_depth(x_salt,y_salt,
cds=ds,
coords=coords,
lat_vals=lat_vals,
basin_name=basin_name,
grid=grid)
Expand All @@ -170,7 +188,7 @@ def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None):

# ---------------------------------------------------------------------

def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds,
def meridional_trsp_at_depth(xfld, yfld, lat_vals, coords,
basin_name=None, grid=None, less_output=True):
"""
Compute transport of vector quantity at each depth level
Expand All @@ -182,8 +200,8 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds,
3D spatial (+ time, optional) field at west and south grid cell edges
lat_vals : float or list
latitude value(s) specifying where to compute transport
cds : xarray Dataset
with all LLC90 coordinates, including: maskW, maskS, YC
coords : xarray Dataset
only needs YC, and optionally maskW, maskS (defining wet points)
basin_name : string, optional
denote ocean basin over which to compute streamfunction
If not specified, compute global quantity
Expand All @@ -203,14 +221,14 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds,
"""

if grid is None:
grid = get_llc_grid(cds)
grid = get_llc_grid(coords)

# Initialize empty DataArray with coordinates and dims
ds_out = _initialize_trsp_data_array(cds, lat_vals)
ds_out = _initialize_trsp_data_array(coords, lat_vals)

# Get basin mask
maskW = cds['maskW'] if 'maskW' in cds else xr.ones_like(xfld)
maskS = cds['maskS'] if 'maskS' in cds else xr.ones_like(yfld)
maskW = coords['maskW'] if 'maskW' in coords else xr.ones_like(xfld)
maskS = coords['maskS'] if 'maskS' in coords else xr.ones_like(yfld)
if basin_name is not None:
maskW = get_basin_mask(basin_name,maskW)
maskS = get_basin_mask(basin_name,maskS)
Expand All @@ -225,7 +243,7 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds,
print ('calculating transport for latitutde ', lat)

# Compute mask for particular latitude band
lat_maskW, lat_maskS = vector_calc.get_latitude_masks(lat, cds['YC'], grid)
lat_maskW, lat_maskS = vector_calc.get_latitude_masks(lat, coords['YC'], grid)

# Sum horizontally
lat_trsp_x = (tmp_x * lat_maskW).sum(dim=['i_g','j','tile'])
Expand All @@ -236,12 +254,12 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds,
return ds_out


def _initialize_trsp_data_array(cds, lat_vals):
def _initialize_trsp_data_array(coords, lat_vals):
"""Create an xarray DataArray with time, depth, and latitude dims
Parameters
----------
cds : xarray Dataset
coords : xarray Dataset
contains LLC coordinates 'k' and (optionally) 'time'
lat_vals : int or array of ints
latitude value(s) rounded to the nearest degree
Expand All @@ -258,19 +276,31 @@ def _initialize_trsp_data_array(cds, lat_vals):
the original depth coordinate
"""

coords = OrderedDict()
dims = ()
lat_vals = np.array(lat_vals) if isinstance(lat_vals,list) else lat_vals
lat_vals = np.array([lat_vals]) if np.isscalar(lat_vals) else lat_vals
lat_vals = xr.DataArray(lat_vals,coords={'lat':lat_vals},dims=('lat',))

xda = xr.zeros_like(lat_vals*cds['k'])
xda = xda if 'time' not in cds.dims else xda.broadcast_like(cds['time'])
xda = xr.zeros_like(coords['k']*lat_vals)
xda = xda if 'time' not in coords.dims else xda.broadcast_like(coords['time']).copy()

# Convert to dataset to add Z coordinate
xds = xda.to_dataset(name='trsp_z')
xds['Z'] = cds['Z']
xds['Z'] = coords['Z']
xds = xds.set_coords('Z')

return xds

def _parse_coords(ds,coords,coordlist):
if coords is not None:
return coords
else:
for f in set(['maskW','maskS']).intersection(ds.reset_coords().keys()):
coordlist.append(f)

if 'time' in ds.dims:
coordlist.append('time')

dsout = ds[coordlist]
if 'domain' in ds.attrs:
dsout.attrs['domain'] = ds.attrs['domain']
return dsout
Loading

0 comments on commit b69826c

Please sign in to comment.