Skip to content

Commit

Permalink
Clean up testing
Browse files Browse the repository at this point in the history
  • Loading branch information
manishvenu committed Dec 21, 2024
1 parent 12ad67e commit 9b7da85
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 124 deletions.
142 changes: 142 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,62 @@ def generate_silly_vt_dataset():
return eastern_boundary


@pytest.fixture()
def generate_silly_ic_dataset():
def _generate_silly_ic_dataset(
longitude_extent,
latitude_extent,
resolution,
number_vertical_layers,
depth,
temp_dataarray_initial_condition,
):
nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)
silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {
"silly_lat": silly_lat,
"silly_lon": silly_lon,
"silly_depth": silly_depth,
}
# initial condition includes, temp, salt, eta, u, v
initial_cond = xr.Dataset(
{
"eta": xr.DataArray(
np.random.random((ny, nx)),
dims=["silly_lat", "silly_lon"],
coords={
"silly_lat": silly_lat,
"silly_lon": silly_lon,
},
),
"temp": temp_dataarray_initial_condition,
"salt": xr.DataArray(
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"u": xr.DataArray(
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"v": xr.DataArray(
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
}
)
return initial_cond

return _generate_silly_ic_dataset


@pytest.fixture()
def dummy_bathymetry_data():
latitude_extent = [16.0, 27]
Expand All @@ -130,3 +186,89 @@ def dummy_bathymetry_data():
)
bathymetry.name = "silly_depth"
return bathymetry


def temperature_dataarrays(
longitude_extent, latitude_extent, resolution, number_vertical_layers, depth
):

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {
"silly_lat": silly_lat,
"silly_lon": silly_lon,
"silly_depth": silly_depth,
}

toolpath_dir = "toolpath"
hgrid_type = "even_spacing"

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked = (
generate_temperature_arrays(nx, ny, number_vertical_layers)
)

temp_C = xr.DataArray(temp_in_C, dims=dims, coords=coords)
temp_K = xr.DataArray(temp_in_K, dims=dims, coords=coords)
temp_C_masked = xr.DataArray(temp_in_C_masked, dims=dims, coords=coords)
temp_K_masked = xr.DataArray(temp_in_K_masked, dims=dims, coords=coords)

maximum_temperature_in_C = np.max(temp_in_C)
return [temp_C, temp_C_masked, temp_K, temp_K_masked]


def number_of_gridpoints(longitude_extent, latitude_extent, resolution):
nx = int((longitude_extent[-1] - longitude_extent[0]) / resolution)
ny = int((latitude_extent[-1] - latitude_extent[0]) / resolution)

return nx, ny


def generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
):
nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

horizontal_buffer = 5

silly_lat = np.linspace(
latitude_extent[0] - horizontal_buffer,
latitude_extent[1] + horizontal_buffer,
ny,
)
silly_lon = np.linspace(
longitude_extent[0] - horizontal_buffer,
longitude_extent[1] + horizontal_buffer,
nx,
)
silly_depth = np.linspace(0, depth, number_vertical_layers)

return silly_lat, silly_lon, silly_depth


def generate_temperature_arrays(nx, ny, number_vertical_layers):

# temperatures close to 0 ᵒC
temp_in_C = np.random.randn(ny, nx, number_vertical_layers)

temp_in_C_masked = np.copy(temp_in_C)
if int(ny / 4 + 4) < ny - 1 and int(nx / 3 + 4) < nx + 1:
temp_in_C_masked[
int(ny / 3) : int(ny / 3 + 5), int(nx) : int(nx / 4 + 4), :
] = float("nan")
else:
raise ValueError("use bigger domain")

temp_in_K = np.copy(temp_in_C) + 273.15
temp_in_K_masked = np.copy(temp_in_C_masked) + 273.15

# ensure we didn't mask the minimum temperature
if np.nanmin(temp_in_C_masked) == np.min(temp_in_C):
return temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked
else:
return generate_temperature_arrays(nx, ny, number_vertical_layers)
142 changes: 18 additions & 124 deletions tests/test_expt_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import xarray as xr
import xesmf as xe
import dask
from .conftest import (
generate_temperature_arrays,
generate_silly_coords,
number_of_gridpoints,
temperature_dataarrays,
)

## Note:
## When creating test dataarrays we use 'silly' names for coordinates to
Expand Down Expand Up @@ -96,58 +102,6 @@ def test_setup_bathymetry(
bathymetry_file.unlink()


def number_of_gridpoints(longitude_extent, latitude_extent, resolution):
nx = int((longitude_extent[-1] - longitude_extent[0]) / resolution)
ny = int((latitude_extent[-1] - latitude_extent[0]) / resolution)

return nx, ny


def generate_temperature_arrays(nx, ny, number_vertical_layers):

# temperatures close to 0 ᵒC
temp_in_C = np.random.randn(ny, nx, number_vertical_layers)

temp_in_C_masked = np.copy(temp_in_C)
if int(ny / 4 + 4) < ny - 1 and int(nx / 3 + 4) < nx + 1:
temp_in_C_masked[
int(ny / 3) : int(ny / 3 + 5), int(nx) : int(nx / 4 + 4), :
] = float("nan")
else:
raise ValueError("use bigger domain")

temp_in_K = np.copy(temp_in_C) + 273.15
temp_in_K_masked = np.copy(temp_in_C_masked) + 273.15

# ensure we didn't mask the minimum temperature
if np.nanmin(temp_in_C_masked) == np.min(temp_in_C):
return temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked
else:
return generate_temperature_arrays(nx, ny, number_vertical_layers)


def generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
):
nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

horizontal_buffer = 5

silly_lat = np.linspace(
latitude_extent[0] - horizontal_buffer,
latitude_extent[1] + horizontal_buffer,
ny,
)
silly_lon = np.linspace(
longitude_extent[0] - horizontal_buffer,
longitude_extent[1] + horizontal_buffer,
nx,
)
silly_depth = np.linspace(0, depth, number_vertical_layers)

return silly_lat, silly_lon, silly_depth


longitude_extent = [-5, 3]
latitude_extent = (0, 10)
date_range = ["2003-01-01 00:00:00", "2003-01-01 00:00:00"]
Expand All @@ -156,35 +110,12 @@ def generate_silly_coords(
layer_thickness_ratio = 1
depth = 1000

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {"silly_lat": silly_lat, "silly_lon": silly_lon, "silly_depth": silly_depth}


toolpath_dir = "toolpath"
hgrid_type = "even_spacing"

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

temp_in_C, temp_in_C_masked, temp_in_K, temp_in_K_masked = generate_temperature_arrays(
nx, ny, number_vertical_layers
)

temp_C = xr.DataArray(temp_in_C, dims=dims, coords=coords)
temp_K = xr.DataArray(temp_in_K, dims=dims, coords=coords)
temp_C_masked = xr.DataArray(temp_in_C_masked, dims=dims, coords=coords)
temp_K_masked = xr.DataArray(temp_in_K_masked, dims=dims, coords=coords)

maximum_temperature_in_C = np.max(temp_in_C)


@pytest.mark.parametrize(
"temp_dataarray_initial_condition",
[temp_C, temp_C_masked, temp_K, temp_K_masked],
temperature_dataarrays(
longitude_extent, latitude_extent, resolution, number_vertical_layers, depth
),
)
@pytest.mark.parametrize(
(
Expand Down Expand Up @@ -224,20 +155,9 @@ def test_ocean_forcing(
hgrid_type,
temp_dataarray_initial_condition,
tmp_path,
generate_silly_ic_dataset,
):
dask.config.set(scheduler="single-threaded")

silly_lat, silly_lon, silly_depth = generate_silly_coords(
longitude_extent, latitude_extent, resolution, depth, number_vertical_layers
)

dims = ["silly_lat", "silly_lon", "silly_depth"]

coords = {
"silly_lat": silly_lat,
"silly_lon": silly_lon,
"silly_depth": silly_depth,
}
mom_run_dir = tmp_path / "rundir"
mom_input_dir = tmp_path / "inputdir"
expt = experiment(
Expand All @@ -254,42 +174,16 @@ def test_ocean_forcing(
hgrid_type=hgrid_type,
)

## Generate some initial condition to test on

nx, ny = number_of_gridpoints(longitude_extent, latitude_extent, resolution)

# initial condition includes, temp, salt, eta, u, v
initial_cond = xr.Dataset(
{
"eta": xr.DataArray(
np.random.random((ny, nx)),
dims=["silly_lat", "silly_lon"],
coords={
"silly_lat": silly_lat,
"silly_lon": silly_lon,
},
),
"temp": temp_dataarray_initial_condition,
"salt": xr.DataArray(
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"u": xr.DataArray(
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
"v": xr.DataArray(
np.random.random((ny, nx, number_vertical_layers)),
dims=dims,
coords=coords,
),
}
initial_cond = generate_silly_ic_dataset(
longitude_extent,
latitude_extent,
resolution,
number_vertical_layers,
depth,
temp_dataarray_initial_condition,
)

# Generate boundary forcing

initial_cond.to_netcdf(tmp_path / "ic_unprocessed")
initial_cond.close()
varnames = {
Expand All @@ -311,7 +205,7 @@ def test_ocean_forcing(

# ensure that temperature is in degrees C
assert np.nanmin(expt.ic_tracers["temp"]) < 100.0

maximum_temperature_in_C = np.max(temp_dataarray_initial_condition)
# max(temp) can be less maximum_temperature_in_C due to re-gridding
assert np.nanmax(expt.ic_tracers["temp"]) <= maximum_temperature_in_C

Expand Down

0 comments on commit 9b7da85

Please sign in to comment.