Skip to content

Commit

Permalink
feat(ecmwf): Add full ensemble model (#224)
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Dudfield <[email protected]>
  • Loading branch information
devsjc and peterdudfield authored Jan 15, 2025
1 parent fdb9afb commit 2bb79a9
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 15 deletions.
23 changes: 14 additions & 9 deletions src/nwp_consumer/internal/entities/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import datetime as dt
import json
import logging
import math
from importlib.metadata import PackageNotFoundError, version

import dask.array
Expand Down Expand Up @@ -87,6 +88,8 @@ class NWPDimensionCoordinateMap:
"""The variables in the forecast data."""
ensemble_stat: list[str] | None = None
"""The relevant ensemble statistics of the forecast data."""
ensemble_member: list[int] | None = None
"""The ensemble member numbers making up the ensemble forecast."""
latitude: list[float] | None = None
"""The latitude coordinates of the forecast grid in degrees.
Expand Down Expand Up @@ -222,6 +225,8 @@ def from_pandas(
],
ensemble_stat=pd_indexes["ensemble_stat"].to_list() \
if "ensemble_stat" in pd_indexes else None,
ensemble_member=pd_indexes["ensemble_member"].to_list() \
if "ensemble_member" in pd_indexes else None,
latitude=pd_indexes["latitude"].to_list() \
if "latitude" in pd_indexes else None,
longitude=pd_indexes["longitude"].to_list() \
Expand Down Expand Up @@ -333,7 +338,7 @@ def determine_region(
return Failure(
KeyError(
"Cannot find slices in non-matching coordinate mappings: "
"both objects must have identical dimensions (rank and labels)."
"both objects must have identical dimensions (rank and labels). "
f"Got: {inner.dims} (inner) and {self.dims} (outer).",
),
)
Expand Down Expand Up @@ -414,17 +419,17 @@ def chunking(self, chunk_count_overrides: dict[str, int]) -> dict[str, int]:
chunk_count_overrides: A dictionary mapping dimension labels to the
number of chunks to split the dimension into.
"""
out_dict: dict[str, int] = {
"init_time": 1,
"step": 1,
"variable": 1,
} | {
dim: len(getattr(self, dim)) // chunk_count_overrides.get(dim, 2)
if len(getattr(self, dim)) > 8 else 1
default_dict: dict[str, int] = {
dim: 1
if len(getattr(self, dim)) <= 8 or dim in ["init_time", "step", "variable"]
else math.ceil(len(getattr(self, dim)))
for dim in self.dims
if dim not in ["init_time", "step", "variable"]
}

out_dict = {}
for key in default_dict:
out_dict[key] = chunk_count_overrides.get(key, default_dict[key])

return out_dict


Expand Down
27 changes: 27 additions & 0 deletions src/nwp_consumer/internal/entities/modelmetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,33 @@ class Models:
)
"""Summary statistics from ECMWF's Ensemble Forecast System."""

ECMWF_ENS_0P1DEGREE: ModelMetadata = ModelMetadata(
    name="ENS",
    resolution="0.1 degrees",
    expected_coordinates=NWPDimensionCoordinateMap(
        # No init times are listed in the model definition itself.
        init_time=[],
        # Steps 0 through 48 inclusive (49 values).
        step=list(range(49)),
        variable=[
            Parameter.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL,
            Parameter.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL,
            Parameter.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL,
            Parameter.WIND_U_COMPONENT_10m,
            Parameter.WIND_V_COMPONENT_10m,
            Parameter.SNOW_DEPTH_GL,
            Parameter.CLOUD_COVER_HIGH,
            Parameter.CLOUD_COVER_MEDIUM,
            Parameter.CLOUD_COVER_LOW,
            Parameter.CLOUD_COVER_TOTAL,
            Parameter.TEMPERATURE_SL,
            Parameter.TOTAL_PRECIPITATION_RATE_GL,
        ],
        # Ensemble members numbered 1 through 50 inclusive.
        ensemble_member=list(range(1, 51)),
        # 90.0 down to -89.9 in 0.1 decrements (1800 values).
        latitude=[tenths / 10 for tenths in range(900, -900, -1)],
        # -180.0 up to 179.9 in 0.1 increments (3600 values).
        longitude=[tenths / 10 for tenths in range(-1800, 1800)],
    ),
)
"""Full ensemble data from ECMWF's Ensemble Forecast System."""

NCEP_GFS_1DEGREE: ModelMetadata = ModelMetadata(
name="NCEP-GFS",
resolution="1 degree",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,15 @@ class _MARSRequest:
- "fc" for forecast
- "em" for ensemble mean
- "es" for ensemble standard deviation
- "pf" for perturbed forecast (full ensemble)
"""
grid: str = "0.1/0.1"
"""The grid resolution."""
number: list[int] | None = None
"""The ensemble member numbers to request.
Only relevant in full ensemble data requests.
"""

def as_ensemble_mean_request(self) -> "_MARSRequest":
"""Create a new request for the ensemble mean."""
Expand Down Expand Up @@ -144,10 +150,13 @@ class = {self.classfication},
date = {self.init_time:%Y%m%d},
expver = {self.expver},
levtype = {self.levtype},
stream = {self.stream},
param = {param},
step = {step},
stream = {self.stream},
time = {self.init_time:%H},
"""
marsReq += f"number = {'/'.join(map(str, self.number))}," if self.number is not None else ""
marsReq += f"""
type = {self.field_type},
area = {self.nwse},
grid = {self.grid},
Expand Down Expand Up @@ -215,6 +224,10 @@ def repository() -> entities.RawRepositoryMetadata:
"hres-ifs-india": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("india"),
"ens-stat-india": entities.Models.ECMWF_ENS_STAT_0P1DEGREE.with_region("india"),
"ens-stat-uk": entities.Models.ECMWF_ENS_STAT_0P1DEGREE.with_region("uk"),
"ens-uk": entities.Models.ECMWF_ENS_0P1DEGREE.with_region("uk")\
.with_chunk_count_overrides(
{"latitude": 2, "longitude": 2, "variable": 1, "ensemble_member": 5},
),
},
)

Expand Down Expand Up @@ -251,15 +264,22 @@ def authenticate(cls) -> ResultE["ECMWFMARSRawRepository"]:
def fetch_init_data(
self, it: dt.datetime,
) -> Iterator[Callable[..., ResultE[list[xr.DataArray]]]]:
base_req: _MARSRequest = _MARSRequest(
req: _MARSRequest = _MARSRequest(
params=self.model().expected_coordinates.variable,
init_time=it,
steps=self.model().expected_coordinates.step,
nwse=",".join([str(ord) for ord in self.model().expected_coordinates.nwse()]),
nwse="/".join([str(ord) for ord in self.model().expected_coordinates.nwse()]),
number=self.model().expected_coordinates.ensemble_member,
)

for req in [base_req.as_ensemble_mean_request(), base_req.as_ensemble_std_request()]:
yield delayed(self._download_and_convert)(req)
# Yield the download and convert function with the appropriate request type
if self.model().expected_coordinates.ensemble_stat is not None:
for stat_req in [req.as_ensemble_mean_request(), req.as_ensemble_std_request()]:
yield delayed(self._download_and_convert)(stat_req)
elif self.model().expected_coordinates.ensemble_member is not None:
yield delayed(self._download_and_convert)(req.as_full_ensemble_request())
else:
yield delayed(self._download_and_convert)(req.as_operational_request())

def _download_and_convert(self, mr: _MARSRequest) -> ResultE[list[xr.DataArray]]:
"""Download and convert data from the ECMWF MARS server.
Expand Down Expand Up @@ -336,6 +356,8 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
.expand_dims("init_time")
.to_dataarray(name=ECMWFMARSRawRepository.model().name)
)
if "number" in da.coords:
da = da.rename({"number": "ensemble_member"})
da = (
da.drop_vars(
names=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def repository() -> entities.RawRepositoryMetadata:
available_models={
"default": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"),
"hres-ifs-uk": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("uk"),
"hres-ifs-india": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("india"),
"hres-ifs-india": entities.Models.ECMWF_HRES_IFS_0P1DEGREE.with_region("india")\
.with_chunk_count_overrides({"variable": 1}),
},
)

Expand Down
1 change: 1 addition & 0 deletions src/nwp_consumer/internal/services/consumer_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _create_suitable_store(
year=multiple_its.year,
month=multiple_its.month,
)
its = [it.replace(tzinfo=dt.UTC) for it in its]

# Create a store for the data with the relevant init time coordinates
return entities.TensorStore.initialize_empty_store(
Expand Down

0 comments on commit 2bb79a9

Please sign in to comment.