diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 2c50595a..c46593eb 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -29,7 +29,7 @@ #: Type alias for a dictionary of axis keys mapped to their bounds. AxisWeights = Dict[Hashable, xr.DataArray] #: Type alias for supported spatial axis keys. -SpatialAxis = Literal["X", "Y"] +SpatialAxis = Literal["X", "Y", "Z"] SPATIAL_AXES: Tuple[SpatialAxis, ...] = get_args(SpatialAxis) #: Type alias for a tuple of floats/ints for the regional selection bounds. RegionAxisBounds = Tuple[float, float] @@ -73,10 +73,12 @@ def average( keep_weights: bool = False, lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + lev_bounds: Optional[RegionAxisBounds] = None, ) -> xr.Dataset: """ - Calculates the spatial average for a rectilinear grid over an optionally - specified regional domain. + Calculates the weighted spatial and/or vertical average for a + rectilinear grid over an optionally specified regional and/or vertical + domain. Operations include: @@ -101,7 +103,7 @@ def average( average. axis : List[SpatialAxis] List of axis dimensions to average over, by default ["X", "Y"]. - Valid axis keys include "X" and "Y". + Valid axis keys include "X", "Y", and "Z". weights : {"generate", xr.DataArray}, optional If "generate", then weights are generated. Otherwise, pass a DataArray containing the regional weights used for weighted @@ -122,6 +124,10 @@ def average( ignored if ``weights`` are supplied. The lower bound can be larger than the upper bound (e.g., across the prime meridian, dateline), by default None. + lev_bounds : Optional[RegionAxisBounds], optional + A tuple of floats/ints for the regional lower and upper level + boundaries. This arg is used when calculating axis weights, but is + ignored if ``weights`` are supplied. The default is None. Returns ------- @@ -143,11 +149,15 @@ def average( >>> >>> ds.lon.attrs["axis"] >>> X + >>> + >>> ds.level.attrs["axis"] + >>> Z Set the 'axis' attribute for the required coordinates if it isn't: >>> ds.lat.attrs["axis"] = "Y" >>> ds.lon.attrs["axis"] = "X" + >>> ds.level.attrs["axis"] = "Z" Call spatial averaging method: @@ -167,6 +177,10 @@ def average( >>> ts_zonal = ds.spatial.average("tas", axis=["X"])["tas"] + Get the vertical average (between 100 and 1000 hPa): + + >>> ta_column = ds.spatial.average("ta", axis=["Z"], lev_bounds=(100, 1000))["ta"] + Using custom weights for averaging: >>> # The shape of the weights must align with the data var. @@ -178,6 +192,12 @@ def average( >>> >>> ts_global = ds.spatial.average("tas", axis=["X", "Y"], >>> weights=weights)["tas"] + + Notes: + ------ + Weights are generally computed as the difference between the bounds. If + sub-selecting a region, the units must match the axis units (e.g., + Pa/hPa or m/km). """ ds = self._dataset.copy() dv = _get_data_var(ds, data_var) @@ -188,7 +208,11 @@ def average( self._validate_region_bounds("Y", lat_bounds) if lon_bounds is not None: self._validate_region_bounds("X", lon_bounds) - self._weights = self.get_weights(axis, lat_bounds, lon_bounds, data_var) + if lev_bounds is not None: + self._validate_region_bounds("Z", lev_bounds) + self._weights = self.get_weights( + axis, lat_bounds, lon_bounds, lev_bounds, data_var + ) elif isinstance(weights, xr.DataArray): self._weights = weights @@ -205,6 +229,7 @@ def get_weights( axis: List[SpatialAxis], lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + lev_bounds: Optional[RegionAxisBounds] = None, data_var: Optional[str] = None, ) -> xr.DataArray: """ @@ -216,9 +241,9 @@ def get_weights( weights are then combined to form a DataArray of weights that can be used to perform a weighted (spatial) average. - If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells - outside this selected regional domain are given zero weight. Grid cells - that are partially in this domain are given partial weight. + If ``lat_bounds``, ``lon_bounds``, or ``lev_bounds`` are supplied, then + grid cells outside this selected regional domain are given zero weight. + Grid cells that are partially in this domain are given partial weight. Parameters ---------- @@ -230,6 +255,9 @@ def get_weights( lon_bounds : Optional[RegionAxisBounds] Tuple of longitude boundaries for regional selection, by default None. + lev_bounds : Optional[RegionAxisBounds] + Tuple of level boundaries for vertical selection, by default + None. data_var: Optional[str] The key of the data variable, by default None. Pass this argument when the dataset has more than one bounds per axis (e.g., "lon" @@ -246,9 +274,7 @@ def get_weights( Notes ----- This method was developed for rectilinear grids only. ``get_weights()`` - recognizes and operate on latitude and longitude, but could be extended - to work with other standard geophysical dimensions (e.g., time, depth, - and pressure). + recognizes and operate on latitude, longitude, and vertical levels. """ Bounds = TypedDict( "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]} @@ -267,6 +293,12 @@ def get_weights( if lat_bounds is not None else None, }, + "Z": { + "weights_method": self._get_vertical_weights, + "region": np.array(lev_bounds, dtype="float") + if lev_bounds is not None + else None, + }, } axis_weights: AxisWeights = {} @@ -476,6 +508,32 @@ def _get_latitude_weights( weights = self._calculate_weights(d_bounds) return weights + def _get_vertical_weights( + self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray] + ) -> xr.DataArray: + """Gets weights for the vertical axis. + + This method scales the domain to a region (if selected) and returns weights + proportional to the difference between each pair of level bounds. + + Parameters + ---------- + domain_bounds : xr.DataArray + The array of bounds for the vertical domain. + region_bounds : Optional[np.ndarray] + The array of bounds for vertical selection. + + Returns + ------- + xr.DataArray + The vertical axis weights. + """ + if region_bounds is not None: + domain_bounds = self._scale_domain_to_region(domain_bounds, region_bounds) + + weights = self._calculate_weights(domain_bounds) + return weights + def _calculate_weights(self, domain_bounds: xr.DataArray): """Calculate weights for the domain.