Skip to content

Commit

Permalink
Merge pull request #27 from Nanoseb/omp_transeq
Browse files Browse the repository at this point in the history
Implement transeq in omp backend
  • Loading branch information
semi-h authored Feb 19, 2024
2 parents 26ccadd + 33345c0 commit 00d0986
Show file tree
Hide file tree
Showing 8 changed files with 652 additions and 91 deletions.
134 changes: 126 additions & 8 deletions src/omp/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module m_omp_backend
use m_base_backend, only: base_backend_t
use m_common, only: dp, globs_t
use m_tdsops, only: dirps_t, tdsops_t
use m_omp_exec_dist, only: exec_dist_tds_compact, exec_dist_transeq_compact
use m_omp_sendrecv, only: sendrecv_fields

use m_omp_common, only: SZ

Expand Down Expand Up @@ -30,6 +32,7 @@ module m_omp_backend
procedure :: vecadd => vecadd_omp
procedure :: set_fields => set_fields_omp
procedure :: get_fields => get_fields_omp
procedure :: transeq_omp_dist
end type omp_backend_t

interface omp_backend_t
Expand Down Expand Up @@ -118,7 +121,7 @@ subroutine transeq_x_omp(self, du, dv, dw, u, v, w, dirps)
class(field_t), intent(in) :: u, v, w
type(dirps_t), intent(in) :: dirps

!call self%transeq_omp_dist(du, dv, dw, u, v, w, dirps)
call self%transeq_omp_dist(du, dv, dw, u, v, w, dirps)

end subroutine transeq_x_omp

Expand All @@ -131,7 +134,7 @@ subroutine transeq_y_omp(self, du, dv, dw, u, v, w, dirps)
type(dirps_t), intent(in) :: dirps

! u, v, w is reordered so that we pass v, u, w
!call self%transeq_omp_dist(dv, du, dw, v, u, w, dirps)
call self%transeq_omp_dist(dv, du, dw, v, u, w, dirps)

end subroutine transeq_y_omp

Expand All @@ -144,10 +147,86 @@ subroutine transeq_z_omp(self, du, dv, dw, u, v, w, dirps)
type(dirps_t), intent(in) :: dirps

! u, v, w is reordered so that we pass w, u, v
!call self%transeq_omp_dist(dw, du, dv, w, u, v, dirps)
call self%transeq_omp_dist(dw, du, dv, w, u, v, dirps)

end subroutine transeq_z_omp

!> Evaluate the three transport-equation right-hand sides along the
!! direction described by `dirps`.
!!
!! The first velocity argument (`u`) is the component aligned with the
!! direction being differentiated; callers reorder (u, v, w) accordingly.
!! Each component is computed by transeq_dist_component as
!!   rhs = -0.5*(u*dq/dx + d(u*q)/dx) + nu*d2q/dx2
!! with q = u, v, w in turn.
!!
!! @param self   OpenMP backend; supplies the halo/send/recv buffers,
!!               the block allocator and the viscosity `nu`.
!! @param du,dv,dw  Output right-hand sides for q = u, v, w.
!! @param u,v,w  Velocity components (u is the advecting one).
!! @param dirps  Directional operators and process-topology information.
subroutine transeq_omp_dist(self, du, dv, dw, u, v, w, dirps)
  implicit none

  class(omp_backend_t) :: self
  class(field_t), intent(inout) :: du, dv, dw
  class(field_t), intent(in) :: u, v, w
  type(dirps_t), intent(in) :: dirps

  ! Fills self%{u,v,w}_recv_{s,e} with the halo data of u, v and w.
  call transeq_halo_exchange(self, u, v, w, dirps)

  ! Every component must be differentiated with the halos of the field
  ! that is actually transported, so they are passed explicitly.
  call transeq_dist_component(self, du, u, u, &
                              dirps%der1st, dirps%der1st_sym, dirps%der2nd, &
                              dirps, self%u_recv_s, self%u_recv_e)
  call transeq_dist_component(self, dv, v, u, &
                              dirps%der1st_sym, dirps%der1st, &
                              dirps%der2nd_sym, &
                              dirps, self%v_recv_s, self%v_recv_e)
  call transeq_dist_component(self, dw, w, u, &
                              dirps%der1st_sym, dirps%der1st, &
                              dirps%der2nd_sym, &
                              dirps, self%w_recv_s, self%w_recv_e)

end subroutine transeq_omp_dist


!> Copy the first and last `n_halo` layers of u, v and w into the send
!! buffers of `self` and exchange them through sendrecv_fields, which is
!! given the matching self%{u,v,w}_recv_{s,e} buffers as its first two
!! arguments.
!!
!! @param self   OpenMP backend owning the send/recv buffers.
!! @param u,v,w  Fields whose boundary layers are exchanged.
!! @param dirps  Provides n, n_blocks, nproc, pprev and pnext.
subroutine transeq_halo_exchange(self, u, v, w, dirps)
  class(omp_backend_t) :: self
  class(field_t), intent(in) :: u, v, w
  type(dirps_t), intent(in) :: dirps
  integer :: n_halo

  ! TODO: don't hardcode n_halo
  n_halo = 4

  call copy_into_buffers(self%u_send_s, self%u_send_e, u%data, &
                         dirps%n, dirps%n_blocks)
  call copy_into_buffers(self%v_send_s, self%v_send_e, v%data, &
                         dirps%n, dirps%n_blocks)
  call copy_into_buffers(self%w_send_s, self%w_send_e, w%data, &
                         dirps%n, dirps%n_blocks)

  ! Message size: SZ entries per layer, n_halo layers, n_blocks blocks.
  call sendrecv_fields(self%u_recv_s, self%u_recv_e, &
                       self%u_send_s, self%u_send_e, &
                       SZ*n_halo*dirps%n_blocks, &
                       dirps%nproc, dirps%pprev, dirps%pnext)
  call sendrecv_fields(self%v_recv_s, self%v_recv_e, &
                       self%v_send_s, self%v_send_e, &
                       SZ*n_halo*dirps%n_blocks, &
                       dirps%nproc, dirps%pprev, dirps%pnext)
  call sendrecv_fields(self%w_recv_s, self%w_recv_e, &
                       self%w_send_s, self%w_send_e, &
                       SZ*n_halo*dirps%n_blocks, &
                       dirps%nproc, dirps%pprev, dirps%pnext)

end subroutine transeq_halo_exchange

!> Computes RHS_x^v following:
!!   rhs_x^v = -0.5*(u*dv/dx + duv/dx) + nu*d2v/dx2
!!
!! `v` is the transported field and `u` the advecting one. In
!! exec_dist_transeq_compact the first data/halo triplet is the field that
!! gets differentiated and the second one multiplies its first derivative,
!! so `v` is passed first and `u` second.
!!
!! @param self        OpenMP backend (buffers, allocator, nu).
!! @param rhs         Output right-hand side.
!! @param v           Transported field.
!! @param u           Advecting field; its halos are read from
!!                    self%u_recv_s / self%u_recv_e.
!! @param tdsops_du   First-derivative operator applied to `v`.
!! @param tdsops_dud  First-derivative operator applied to the product u*v.
!! @param tdsops_d2u  Second-derivative operator applied to `v`.
!! @param dirps       Provides nproc, pprev, pnext and n_blocks.
!! @param v_recv_s,v_recv_e  Optional halos of `v`. When either is omitted
!!                    self%v_recv_s / self%v_recv_e are used, which is only
!!                    right if `v` is the field stored there by
!!                    transeq_halo_exchange.
subroutine transeq_dist_component(self, rhs, v, u, tdsops_du, tdsops_dud, &
                                  tdsops_d2u, dirps, v_recv_s, v_recv_e)

  class(omp_backend_t) :: self
  class(field_t), intent(inout) :: rhs
  class(field_t), intent(in) :: u, v
  class(tdsops_t), intent(in) :: tdsops_du
  class(tdsops_t), intent(in) :: tdsops_dud
  class(tdsops_t), intent(in) :: tdsops_d2u
  type(dirps_t), intent(in) :: dirps
  real(dp), dimension(:, :, :), intent(in), optional :: v_recv_s, v_recv_e
  class(field_t), pointer :: du, d2u, dud

  ! Scratch blocks for the three derivatives; released before returning.
  du => self%allocator%get_block()
  dud => self%allocator%get_block()
  d2u => self%allocator%get_block()

  if (present(v_recv_s) .and. present(v_recv_e)) then
    call run_kernel(v_recv_s, v_recv_e)
  else
    call run_kernel(self%v_recv_s, self%v_recv_e)
  end if

  call self%allocator%release_block(du)
  call self%allocator%release_block(dud)
  call self%allocator%release_block(d2u)

contains

  !> Invoke the distributed kernel with the given halos for `v`.
  subroutine run_kernel(halo_s, halo_e)
    real(dp), dimension(:, :, :), intent(in) :: halo_s, halo_e

    call exec_dist_transeq_compact( &
      rhs%data, du%data, dud%data, d2u%data, &
      self%du_send_s, self%du_send_e, self%du_recv_s, self%du_recv_e, &
      self%dud_send_s, self%dud_send_e, self%dud_recv_s, self%dud_recv_e, &
      self%d2u_send_s, self%d2u_send_e, self%d2u_recv_s, self%d2u_recv_e, &
      v%data, halo_s, halo_e, &
      u%data, self%u_recv_s, self%u_recv_e, &
      tdsops_du, tdsops_dud, tdsops_d2u, self%nu, &
      dirps%nproc, dirps%pprev, dirps%pnext, dirps%n_blocks)

  end subroutine run_kernel

end subroutine transeq_dist_component

subroutine tds_solve_omp(self, du, u, dirps, tdsops)
implicit none

Expand All @@ -157,10 +236,36 @@ subroutine tds_solve_omp(self, du, u, dirps, tdsops)
type(dirps_t), intent(in) :: dirps
class(tdsops_t), intent(in) :: tdsops

!call self%tds_solve_dist(self, du, u, dirps, tdsops)
call tds_solve_dist(self, du, u, dirps, tdsops)

end subroutine tds_solve_omp

!> Apply the operator `tdsops` to `u` with the distributed compact kernel
!! and store the result in `du`.
!!
!! The boundary layers of `u` are first copied into self%u_send_s/e and
!! passed through sendrecv_fields together with self%u_recv_s/e; the recv
!! buffers are then handed to exec_dist_tds_compact along with the
!! self%du_send_* / self%du_recv_* work buffers.
!!
!! @param self    OpenMP backend owning the communication buffers.
!! @param du      Output field.
!! @param u       Input field.
!! @param dirps   Provides n, n_blocks, nproc, pprev and pnext.
!! @param tdsops  Operator forwarded to exec_dist_tds_compact.
subroutine tds_solve_dist(self, du, u, dirps, tdsops)
  implicit none

  class(omp_backend_t) :: self
  class(field_t), intent(inout) :: du
  class(field_t), intent(in) :: u
  type(dirps_t), intent(in) :: dirps
  class(tdsops_t), intent(in) :: tdsops

  ! TODO: don't hardcode n_halo
  integer, parameter :: n_halo = 4
  integer :: n_exchange

  ! SZ entries per layer, n_halo layers, for every block.
  n_exchange = SZ*n_halo*dirps%n_blocks

  call copy_into_buffers(self%u_send_s, self%u_send_e, u%data, &
                         dirps%n, dirps%n_blocks)

  ! halo exchange
  call sendrecv_fields(self%u_recv_s, self%u_recv_e, &
                       self%u_send_s, self%u_send_e, &
                       n_exchange, dirps%nproc, dirps%pprev, dirps%pnext)

  call exec_dist_tds_compact( &
    du%data, u%data, &
    self%u_recv_s, self%u_recv_e, &
    self%du_send_s, self%du_send_e, &
    self%du_recv_s, self%du_recv_e, &
    tdsops, dirps%nproc, dirps%pprev, dirps%pnext, dirps%n_blocks)

end subroutine tds_solve_dist

subroutine reorder_omp(self, u_, u, direction)
implicit none

Expand Down Expand Up @@ -200,15 +305,28 @@ subroutine vecadd_omp(self, a, x, b, y)

end subroutine vecadd_omp

!> Copy the first and last `n_halo` layers (second index) of `u` into the
!! start and end send buffers.
!!
!! @param u_send_s  Receives u(1:SZ, 1:n_halo, 1:n_blocks).
!! @param u_send_e  Receives u(1:SZ, n-n_halo+1:n, 1:n_blocks).
!! @param u         Source array.
!! @param n         Number of valid layers along the second index of `u`.
!! @param n_blocks  Number of blocks (third index) to copy.
subroutine copy_into_buffers(u_send_s, u_send_e, u, n, n_blocks)
  implicit none

  real(dp), dimension(:, :, :), intent(out) :: u_send_s, u_send_e
  real(dp), dimension(:, :, :), intent(in) :: u
  integer, intent(in) :: n
  integer, intent(in) :: n_blocks
  integer :: i, j, k
  ! Named constant rather than an initialised local: initialising a local
  ! in its declaration would give it the SAVE attribute.
  ! TODO: don't hardcode n_halo
  integer, parameter :: n_halo = 4

  !$omp parallel do
  do k = 1, n_blocks
    do j = 1, n_halo
      !$omp simd
      do i = 1, SZ
        u_send_s(i, j, k) = u(i, j, k)
        u_send_e(i, j, k) = u(i, n - n_halo + j, k)
      end do
      !$omp end simd
    end do
  end do
  !$omp end parallel do

end subroutine copy_into_buffers

Expand Down
129 changes: 129 additions & 0 deletions src/omp/exec_dist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,134 @@ subroutine exec_dist_tds_compact( &

end subroutine exec_dist_tds_compact


!> Distributed evaluation of one transport-equation right-hand side:
!!   rhs = -0.5*(v*du + dud) + nu*d2u
!! where du and d2u are obtained by applying tdsops_du and tdsops_d2u to
!! `u`, and dud by applying tdsops_dud to the pointwise product u*v.
!!
!! The work is done in three stages:
!!   1. per block, der_univ_dist is called for du, d2u and dud; it is given
!!      the *_send_s/e slices as arguments
!!      (presumably filling them with boundary data - confirm in
!!      der_univ_dist);
!!   2. the send buffers are passed through sendrecv_fields together with
!!      the matching *_recv_s/e buffers;
!!   3. per block, der_univ_subs is called with the received data and the
!!      final rhs is assembled.
!!
!! @param rhs                 Output right-hand side.
!! @param du,dud,d2u          Work/output arrays for the three derivatives.
!! @param *_send_*, *_recv_*  Communication scratch buffers (contents are
!!                            not meaningful to the caller afterwards).
!! @param u,u_recv_s,u_recv_e Field that is differentiated, and its halos.
!! @param v,v_recv_s,v_recv_e Field multiplying du and forming u*v, and
!!                            its halos.
!! @param tdsops_du,tdsops_dud,tdsops_d2u  Operators for du, dud and d2u.
!! @param nu                  Coefficient of the d2u term.
!! @param nproc,pprev,pnext   Forwarded unchanged to sendrecv_fields.
!! @param n_block             Number of blocks (third array index).
subroutine exec_dist_transeq_compact(&
rhs, du, dud, d2u, &
du_send_s, du_send_e, du_recv_s, du_recv_e, &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e, &
u, u_recv_s, u_recv_e, &
v, v_recv_s, v_recv_e, &
tdsops_du, tdsops_dud, tdsops_d2u, nu, nproc, pprev, pnext, n_block)

  implicit none

  ! du = d(u)
  real(dp), dimension(:, :, :), intent(out) :: rhs, du, dud, d2u

  ! The ones below are intent(out) just so that we can write data in them,
  ! not because we actually need the data they store later where this
  ! subroutine is called. We absolutely don't care about the data they
  ! pass back
  real(dp), dimension(:, :, :), intent(out) :: &
    du_send_s, du_send_e, du_recv_s, du_recv_e
  real(dp), dimension(:, :, :), intent(out) :: &
    dud_send_s, dud_send_e, dud_recv_s, dud_recv_e
  real(dp), dimension(:, :, :), intent(out) :: &
    d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e

  real(dp), dimension(:, :, :), intent(in) :: u, u_recv_s, u_recv_e
  real(dp), dimension(:, :, :), intent(in) :: v, v_recv_s, v_recv_e

  type(tdsops_t), intent(in) :: tdsops_du, tdsops_dud, tdsops_d2u

  ! Per-thread scratch for the product u*v and its halos.
  real(dp), dimension(:, :), allocatable :: ud, ud_recv_s, ud_recv_e
  real(dp), intent(in) :: nu
  integer, intent(in) :: nproc, pprev, pnext
  integer, intent(in) :: n_block

  integer :: n_data, n_halo
  integer :: k, i, j, n

  ! TODO: don't hardcode n_halo
  n_halo = 4
  n = tdsops_d2u%n
  n_data = SZ*n_block

  ! Allocated before the parallel region so that the private copies made
  ! by the private clause below start out allocated with these bounds.
  allocate(ud(SZ, n))
  allocate(ud_recv_e(SZ, n_halo))
  allocate(ud_recv_s(SZ, n_halo))

  !$omp parallel do private(ud, ud_recv_e, ud_recv_s)
  do k = 1, n_block
    call der_univ_dist( &
      du(:, :, k), du_send_s(:, :, k), du_send_e(:, :, k), u(:, :, k), &
      u_recv_s(:, :, k), u_recv_e(:, :, k), &
      tdsops_du%coeffs_s, tdsops_du%coeffs_e, tdsops_du%coeffs, tdsops_du%n, &
      tdsops_du%dist_fw, tdsops_du%dist_bw, tdsops_du%dist_af &
    )

    call der_univ_dist( &
      d2u(:, :, k), d2u_send_s(:, :, k), d2u_send_e(:, :, k), u(:, :, k), &
      u_recv_s(:, :, k), u_recv_e(:, :, k), &
      tdsops_d2u%coeffs_s, tdsops_d2u%coeffs_e, tdsops_d2u%coeffs, tdsops_d2u%n, &
      tdsops_d2u%dist_fw, tdsops_d2u%dist_bw, tdsops_d2u%dist_af &
    )

    ! Handle dud by locally generating u*v
    do j = 1, n
      !$omp simd
      do i = 1, SZ
        ud(i, j) = u(i, j, k) * v(i, j, k)
      end do
      !$omp end simd
    end do

    ! The halos of the product are the products of the halos.
    do j = 1, n_halo
      !$omp simd
      do i = 1, SZ
        ud_recv_s(i, j) = u_recv_s(i, j, k) * v_recv_s(i, j, k)
        ud_recv_e(i, j) = u_recv_e(i, j, k) * v_recv_e(i, j, k)
      end do
      !$omp end simd
    end do

    call der_univ_dist( &
      dud(:, :, k), dud_send_s(:, :, k), dud_send_e(:, :, k), ud(:, :), &
      ud_recv_s(:, :), ud_recv_e(:, :), &
      tdsops_dud%coeffs_s, tdsops_dud%coeffs_e, tdsops_dud%coeffs, tdsops_dud%n, &
      tdsops_dud%dist_fw, tdsops_dud%dist_bw, tdsops_dud%dist_af &
    )

  end do
  !$omp end parallel do

  ! halo exchange for 2x2 systems
  call sendrecv_fields(du_recv_s, du_recv_e, du_send_s, du_send_e, &
                       n_data, nproc, pprev, pnext)
  call sendrecv_fields(dud_recv_s, dud_recv_e, dud_send_s, dud_send_e, &
                       n_data, nproc, pprev, pnext)
  call sendrecv_fields(d2u_recv_s, d2u_recv_e, d2u_send_s, d2u_send_e, &
                       n_data, nproc, pprev, pnext)

  !$omp parallel do
  do k = 1, n_block
    call der_univ_subs(du(:, :, k), &
                       du_recv_s(:, :, k), du_recv_e(:, :, k), &
                       tdsops_du%n, tdsops_du%dist_sa, tdsops_du%dist_sc)

    call der_univ_subs(dud(:, :, k), &
                       dud_recv_s(:, :, k), dud_recv_e(:, :, k), &
                       tdsops_dud%n, tdsops_dud%dist_sa, tdsops_dud%dist_sc)

    call der_univ_subs(d2u(:, :, k), &
                       d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), &
                       tdsops_d2u%n, tdsops_d2u%dist_sa, tdsops_d2u%dist_sc)

    ! Assemble the right-hand side from the three derivatives.
    do j = 1, n
      !$omp simd
      do i = 1, SZ
        rhs(i, j, k) = -0.5_dp*(v(i, j, k)*du(i, j, k) + dud(i, j, k)) + nu*d2u(i, j, k)
      end do
      !$omp end simd
    end do

  end do
  !$omp end parallel do


end subroutine exec_dist_transeq_compact



end module m_omp_exec_dist

Loading

0 comments on commit 00d0986

Please sign in to comment.