Skip to content

Commit

Permalink
Update Initial Condition + Alternate Rotation Function
Browse files Browse the repository at this point in the history
  • Loading branch information
manishvenu committed Jan 7, 2025
1 parent b5b3e38 commit cd5e408
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 53 deletions.
133 changes: 83 additions & 50 deletions regional_mom6/regional_mom6.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ def setup_initial_condition(
varnames,
arakawa_grid="A",
vcoord_type="height",
rotational_method=rot.RotationMethod.GIVEN_ANGLE,
):
"""
Reads the initial condition from files in ``ic_path``, interpolates to the
Expand All @@ -1127,6 +1128,7 @@ def setup_initial_condition(
Either ``'A'`` (default), ``'B'``, or ``'C'``.
vcoord_type (Optional[str]): The type of vertical coordinate used in the forcing files.
Either ``'height'`` or ``'thickness'``.
rotational_method (Optional[RotationMethod]): The method used to rotate the velocities.
"""

# Remove time dimension if present in the IC.
Expand Down Expand Up @@ -1249,28 +1251,6 @@ def setup_initial_condition(
+ "Terminating!"
)

## Construct the xq, yh and xh, yq grids
ugrid = (
self.hgrid[["x", "y"]]
.isel(nxp=slice(None, None, 2), nyp=slice(1, None, 2))
.rename({"x": "lon", "y": "lat"})
.set_coords(["lat", "lon"])
)
vgrid = (
self.hgrid[["x", "y"]]
.isel(nxp=slice(1, None, 2), nyp=slice(None, None, 2))
.rename({"x": "lon", "y": "lat"})
.set_coords(["lat", "lon"])
)

## Construct the cell centre grid for tracers (xh, yh).
tgrid = (
self.hgrid[["x", "y"]]
.isel(nxp=slice(1, None, 2), nyp=slice(1, None, 2))
.rename({"x": "lon", "y": "lat", "nxp": "nx", "nyp": "ny"})
.set_coords(["lat", "lon"])
)

# NaNs might be here from the land mask of the model that the IC has come from.
# If they're not removed then the coastlines from this other grid will be retained!
# The land mask comes from the bathymetry file, so we don't need NaNs
Expand Down Expand Up @@ -1309,39 +1289,75 @@ def setup_initial_condition(
.ffill("lat")
.bfill("lat")
)
renamed_hgrid = self.hgrid # This is not a deep copy
renamed_hgrid["lon"] = renamed_hgrid["x"]
renamed_hgrid["lat"] = renamed_hgrid["y"]
tgrid = (
rgd.get_hgrid_arakawa_c_points(self.hgrid, "t")
.rename({"tlon": "lon", "tlat": "lat", "nxp": "nx", "nyp": "ny"})
.set_coords(["lat", "lon"])
)

## Make our three horizontal regridders
regridder_u = xe.Regridder(
ic_raw_u,
ugrid,
"bilinear",

regridder_u = rgd.create_regridder(
ic_raw_u, renamed_hgrid, locstream_out=False, method="bilinear"
)
regridder_v = xe.Regridder(
ic_raw_v,
vgrid,
"bilinear",
regridder_v = rgd.create_regridder(
ic_raw_v, renamed_hgrid, locstream_out=False, method="bilinear"
)
regridder_t = rgd.create_regridder(
ic_raw_tracers, tgrid, locstream_out=False, method="bilinear"
) # Doesn't need to be rotated, so we can regrid to just tracers

regridder_t = xe.Regridder(
ic_raw_tracers,
tgrid,
"bilinear",
)
# ugrid= rgd.get_hgrid_arakawa_c_points(self.hgrid, "u").rename({"ulon": "lon", "ulat": "lat"}).set_coords(["lat", "lon"])
# vgrid = rgd.get_hgrid_arakawa_c_points(self.hgrid, "v").rename({"vlon": "lon", "vlat": "lat"}).set_coords(["lat", "lon"])

## Construct the cell centre grid for tracers (xh, yh).
print("INITIAL CONDITIONS")

## Regrid all fields horizontally.

print("Regridding Velocities... ", end="")

regridded_u = regridder_u(ic_raw_u)
regridded_v = regridder_v(ic_raw_v)
if rotational_method == rot.RotationMethod.GIVEN_ANGLE:
rotated_u, rotated_v = segment.rotate(
None,
regridded_u,
regridded_v,
radian_angle=np.radians(self.hgrid.angle_dx.values),
)
elif rotational_method == rot.RotationMethod.EXPAND_GRID:
self.hgrid["angle_dx_rm6"] = (
rot.initialize_grid_rotation_angles_using_expanded_hgrid(self.hgrid)
)
rotated_u, rotated_v = segment.rotate(
regridded_u,
regridded_v,
radian_angle=np.radians(self.hgrid.angle_dx_rm6.values),
)
elif rotational_method == rot.RotationMethod.NO_ROTATION:
rotated_u, rotated_v = regridded_u, regridded_v
# Slice the velocites to the u and v grid.
u_points = rgd.get_hgrid_arakawa_c_points(self.hgrid, "u")
v_points = rgd.get_hgrid_arakawa_c_points(self.hgrid, "v")
rotated_v = rotated_v[:, v_points.v_points_y.values, v_points.v_points_x.values]
rotated_u = rotated_u[:, u_points.u_points_y.values, u_points.u_points_x.values]
rotated_u["lon"] = u_points.ulon
rotated_u["lat"] = u_points.ulat
rotated_v["lon"] = v_points.vlon
rotated_v["lat"] = v_points.vlat

# Merge Vels
vel_out = xr.merge(
[
regridder_u(ic_raw_u)
.rename({"lon": "xq", "lat": "yh", "nyp": "ny", varnames["zl"]: "zl"})
.rename("u"),
regridder_v(ic_raw_v)
.rename({"lon": "xh", "lat": "yq", "nxp": "nx", varnames["zl"]: "zl"})
.rename("v"),
rotated_u.rename(
{"lon": "xq", "lat": "yh", "nyp": "ny", varnames["zl"]: "zl"}
).rename("u"),
rotated_v.rename(
{"lon": "xh", "lat": "yq", "nxp": "nx", varnames["zl"]: "zl"}
).rename("v"),
]
)

Expand Down Expand Up @@ -1400,14 +1416,11 @@ def setup_initial_condition(
eta_out.attrs = ic_raw_eta.attrs

## Regrid the fields vertically

if (
vcoord_type == "thickness"
): ## In this case construct the vertical profile by summing thickness
tracers_out["zl"] = tracers_out["zl"].diff("zl")
dz = tracers_out[self.z].diff(self.z)
dz.name = "dz"
dz = xr.concat([dz, dz[-1]], dim=self.z)
dz = rgd.generate_dz(tracers_out, self.z)

tracers_out = tracers_out.interp({"zl": self.vgrid.zl.values})
vel_out = vel_out.interp({"zl": self.vgrid.zl.values})
Expand Down Expand Up @@ -2934,9 +2947,30 @@ def __init__(
self.segment_name = segment_name
self.repeat_year_forcing = repeat_year_forcing

def rotate(self, u, v, radian_angle):
# Make docstring
def rotate_complex(self, u, v, radian_angle):
"""
Rotate velocities to grid orientation using complex number math (Same as rotate)
Args:
u (xarray.DataArray): The u-component of the velocity.
v (xarray.DataArray): The v-component of the velocity.
radian_angle (xarray.DataArray): The angle of the grid in RADIANS
Returns:
Tuple[xarray.DataArray, xarray.DataArray]: The rotated u and v components of the velocity.
"""

# express velocity in the complex plan
vel = u + v * 1j
# rotate velocity using grid angle theta
vel = vel * np.exp(1j * radian_angle)

# From here you can easily get the rotated u, v, or the magnitude/direction of the currents:
u = np.real(vel)
v = np.imag(vel)

return u, v

def rotate(self, u, v, radian_angle):
"""
Rotate the velocities to the grid orientation.
Args:
Expand Down Expand Up @@ -2974,7 +3008,7 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG
coords,
self.outfolder
/ f"weights/bilinear_velocity_weights_{self.orientation}.nc",
method="nearest_s2d",
method="bilinear",
)

regridded = regridder(
Expand All @@ -2990,7 +3024,6 @@ def regrid_velocity_tracers(self, rotational_method=rot.RotationMethod.GIVEN_ANG
regridded[self.v],
radian_angle=np.radians(coords.angle.values),
)

elif rotational_method == rot.RotationMethod.EXPAND_GRID:

# Recalculate entire hgrid angles
Expand Down
10 changes: 8 additions & 2 deletions regional_mom6/regridding.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def create_regridder(
output_grid: xr.Dataset,
outfile: Path = None,
method: str = "bilinear",
locstream_out: bool = True,
periodic: bool = False,
) -> xe.Regridder:
"""
Basic Regridder for any forcing variables, this just wraps the xesmf regridder for a few defaults
Expand All @@ -220,6 +222,10 @@ def create_regridder(
The path to the output file for weights I believe, by default Path(".temp")
method : str, optional
The regridding method, by default "bilinear"
locstream_out : bool, optional
Whether to output the locstream, by default True
periodic : bool, optional
Whether the grid is periodic, by default False
Returns
-------
xe.Regridder
Expand All @@ -230,8 +236,8 @@ def create_regridder(
forcing_variables,
output_grid,
method=method,
locstream_out=True,
periodic=False,
locstream_out=locstream_out,
periodic=periodic,
filename=outfile,
reuse_weights=False,
)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_regridding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ def test_smoke_untested_funcs(get_curvilinear_hgrid, generate_silly_vt_dataset):


def test_fill_missing_data(generate_silly_vt_dataset):
"""
Only testing forward fill for now
"""
ds = generate_silly_vt_dataset
ds["temp"][0, 0, 6:10, 0] = np.nan

ds = rgd.fill_missing_data(ds, "silly_depth")
ds = rgd.fill_missing_data(ds, "silly_depth", fill="f")

assert (
ds["temp"][0, 0, 6:10, 0] == (ds["temp"][0, 0, 5, 0])
Expand Down

0 comments on commit cd5e408

Please sign in to comment.