Skip to content

Commit

Permalink
Create gw_ml initialisation routine to read in net and normalisation …
Browse files Browse the repository at this point in the history
…variables from file.
  • Loading branch information
jatkinson1000 committed Aug 19, 2024
1 parent 9fd9e8f commit ab846d3
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/physics/cam/gw_drag.F90
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module gw_drag
use gw_common, only: GWBand
use gw_convect, only: BeresSourceDesc
use gw_front, only: CMSourceDesc
use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml

! Typical module header
implicit none
Expand Down Expand Up @@ -990,7 +991,7 @@ subroutine gw_init()
! Set up neccessary attributes if using ML scheme for convective drag
if ((gw_convect_dp_ml == 'on') .or. (gw_convect_dp_ml == 'bothon')) then
! Load the convective drag net from TorchScript file
call torch_model_load(gw_convect_dp_nn, gw_convect_dp_ml_net)
call gw_drag_convect_dp_ml_init(gw_convect_dp_ml_net, gw_convect_dp_ml_norms)
endif

if (use_gw_convect_sh) then
Expand Down
256 changes: 250 additions & 6 deletions src/physics/cam/gw_ml.F90
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,58 @@ module gw_ml

use gw_utils, only: r8
use ppgrid, only: pver
use spmd_utils, only: mpicom, mstrid=>masterprocid, masterproc, mpi_real8
use cam_abortutils, only: endrun

use ftorch

implicit none
private
save

public :: gw_drag_convect_dp_ml
public :: gw_drag_convect_dp_ml, gw_drag_convect_dp_ml_init

! Neural Net as read in by FTorch
type(torch_model) :: convect_net

! Means for normalisation
real(r8) :: utgw_mean(pver), vtgw_mean(pver)
real(r8) :: u_mean(pver), v_mean(pver)
real(r8) :: t_mean(pver)
real(r8) :: dse_mean(pver)
real(r8) :: nm_mean(pver)
real(r8) :: netdt_mean(pver)
real(r8) :: zm_mean(pver)
real(r8) :: rhoi_mean(pver+1)
real(r8) :: ps_mean
real(r8) :: lat_mean
real(r8) :: lon_mean
! Standard deviations for normalisation
real(r8) :: utgw_std(pver), vtgw_std(pver)
real(r8) :: u_std(pver), v_std(pver)
real(r8) :: t_std(pver)
real(r8) :: dse_std(pver)
real(r8) :: nm_std(pver)
real(r8) :: netdt_std(pver)
real(r8) :: zm_std(pver)
real(r8) :: rhoi_std(pver+1)
real(r8) :: ps_std
real(r8) :: lat_std
real(r8) :: lon_std

contains

!==========================================================================

subroutine gw_drag_convect_dp_ml(convect_net, &
ncol, dt, &
subroutine gw_drag_convect_dp_ml(ncol, dt, &
u, v, t, dse, nm, netdt, zm, rhoi, ps, lat, lon, &
utgw, vtgw)

! Take data from CAM, normalise and concatenate before passing it to the Torch neural
! net to calculate u and v tendencies.

! Neural Net as read in by FTorch
type(torch_model) :: convect_net



! Column dimension.
integer, intent(in) :: ncol
Expand All @@ -46,7 +75,7 @@ subroutine gw_drag_convect_dp_ml(convect_net, &
! Midpoint and interface Brunt-Vaisalla frequencies.
real(r8), intent(in) :: nm(ncol,pver)
! Heating rate due to convection.
real(r8), intent(in) :: netdt(:,:)
real(r8), intent(in) :: netdt(ncol,pver)
! Midpoint geopotential altitudes.
real(r8), intent(in) :: zm(ncol,pver)
! Interface densities.
Expand Down Expand Up @@ -110,4 +139,219 @@ subroutine gw_drag_convect_dp_ml(convect_net, &

end subroutine gw_drag_convect_dp_ml


subroutine gw_drag_convect_dp_ml_init(neural_net_path, norms_path)

character(len=132), intent(in) :: neural_net_path ! Filepath to PyTorch Torchscript net
character(len=132), intent(in) :: norms_path ! Filepath to NetCDF normalisation weights

! Load the convective drag net from TorchScript file
call torch_model_load(convect_net, neural_net_path)
! read in normalisation weights
call read_norms(norms_path)

end subroutine gw_drag_convect_dp_ml_init


subroutine read_norms(norms_path)

use netcdf
use error_messages, only: handle_ncerr

character(len=132), intent(in) :: norms_path ! Filepath to NetCDF normalisation weights

integer :: ncid, varid, retva, ierr
character(len=*), parameter :: sub = 'gw_ml/F90 read_norms: '

! Load weights from file in master process then broadcast
if (masterproc) then
! Open the NetCDF file
call handle_ncerr( nf90_open(trim(norms_path), NF90_NOWRITE, ncid), &
"Error opening NetCDF norms file in gw_ml.F90")

! We do not need to read in dimensions here as we assume inputs match the grid.

! Read in variables (means and deviations).
call handle_ncerr( nf90_inq_varid(ncid, 'U_mean', varid), &
"Error getting U_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, u_mean), &
"Error getting U_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'U_std', varid), &
"Error getting U_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, u_std), &
"Error getting U_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'V_mean', varid), &
"Error getting V_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, v_mean), &
"Error getting V_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'V_std', varid), &
"Error getting V_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, v_std), &
"Error getting V_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'T_mean', varid), &
"Error getting T_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, t_mean), &
"Error getting t_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'T_std', varid), &
"Error getting T_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, t_std), &
"Error getting T_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'DSE_mean', varid), &
"Error getting DSE_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, dse_mean), &
"Error getting U_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'DSE_std', varid), &
"Error getting DSE_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, dse_std), &
"Error getting DSE_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'NMBV_mean', varid), &
"Error getting NMBV_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, nm_mean), &
"Error getting NMBV_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'NMBV_std', varid), &
"Error getting NMBV_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, nm_std), &
"Error getting NMBV_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'NETDT_mean', varid), &
"Error getting NETDT_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, netdt_mean), &
"Error getting NETDT_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'NETDT_std', varid), &
"Error getting NETDT_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, netdt_std), &
"Error getting NETDT_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'Z3_mean', varid), &
"Error getting Z3_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, zm_mean), &
"Error getting Z3_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'Z3_std', varid), &
"Error getting Z3_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, zm_std), &
"Error getting Z3_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'RHOI_mean', varid), &
"Error getting RHOI_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, rhoi_mean), &
"Error getting RHOI_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'RHOI_std', varid), &
"Error getting RHOI_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, rhoi_std), &
"Error getting RHOI_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'PS_mean', varid), &
"Error getting PS_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, ps_mean), &
"Error getting PS_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'PS_std', varid), &
"Error getting PS_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, ps_std), &
"Error getting PS_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'lat_mean', varid), &
"Error getting lat_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, lat_mean), &
"Error getting lat_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'lat_std', varid), &
"Error getting lat_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, lat_std), &
"Error getting lat_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'lon_mean', varid), &
"Error getting lon_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, lon_mean), &
"Error getting lon_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'lon_std', varid), &
"Error getting lon_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, lon_std), &
"Error getting lon_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'UTGWSPEC_mean', varid), &
"Error getting UTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, utgw_mean), &
"Error getting UTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'UTGWSPEC_std', varid), &
"Error getting UTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, utgw_std), &
"Error getting UTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90")

call handle_ncerr( nf90_inq_varid(ncid, 'VTGWSPEC_mean', varid), &
"Error getting VTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, vtgw_mean), &
"Error getting VTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_inq_varid(ncid, 'VTGWSPEC_std', varid), &
"Error getting VTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90")
call handle_ncerr( nf90_get_var(ncid, varid, vtgw_std), &
"Error getting VTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90")

endif

! Broadcast normalisation variables to other processes
call mpi_bcast(utgw_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: utgw_mean from gw_ml.F90")
call mpi_bcast(utgw_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: utgw_std from gw_ml.F90")

call mpi_bcast(vtgw_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: vtgw_mean from gw_ml.F90")
call mpi_bcast(vtgw_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: vtgw_std from gw_ml.F90")

call mpi_bcast(u_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: u_mean from gw_ml.F90")
call mpi_bcast(u_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: u_std from gw_ml.F90")

call mpi_bcast(v_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: v_mean from gw_ml.F90")
call mpi_bcast(v_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: v_std from gw_ml.F90")

call mpi_bcast(t_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: t_mean from gw_ml.F90")
call mpi_bcast(t_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: t_std from gw_ml.F90")

call mpi_bcast(dse_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: dse_mean from gw_ml.F90")
call mpi_bcast(dse_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: dse_std from gw_ml.F90")

call mpi_bcast(nm_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: nm_mean from gw_ml.F90")
call mpi_bcast(nm_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: nm_std from gw_ml.F90")

call mpi_bcast(zm_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: zm_mean from gw_ml.F90")
call mpi_bcast(zm_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: zm_std from gw_ml.F90")

call mpi_bcast(rhoi_mean, pver+1, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: rhoi_mean from gw_ml.F90")
call mpi_bcast(rhoi_std, pver+1, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: rhoi_std from gw_ml.F90")

call mpi_bcast(ps_mean, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: ps_mean from gw_ml.F90")
call mpi_bcast(ps_std, pver, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: ps_std from gw_ml.F90")

call mpi_bcast(lat_mean, 1, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lat_mean from gw_ml.F90")
call mpi_bcast(lat_std, 1, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lat_std from gw_ml.F90")

call mpi_bcast(lon_mean, 1, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lon_mean from gw_ml.F90")
call mpi_bcast(lon_std, 1, mpi_real8, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lon_std from gw_ml.F90")

end subroutine read_norms

end module gw_ml

0 comments on commit ab846d3

Please sign in to comment.