Skip to content

Commit

Permalink
Merge pull request #37 from semi-h/basics
Browse files Browse the repository at this point in the history
Reduce the memory usage
  • Loading branch information
semi-h authored Feb 16, 2024
2 parents 38c8b98 + a51b274 commit 26ccadd
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 30 deletions.
9 changes: 5 additions & 4 deletions src/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ module m_base_backend
procedure(transeq_ders), deferred :: transeq_z
procedure(tds_solve), deferred :: tds_solve
procedure(reorder), deferred :: reorder
procedure(sum_yzintox), deferred :: sum_yzintox
procedure(sum_intox), deferred :: sum_yintox
procedure(sum_intox), deferred :: sum_zintox
procedure(vecadd), deferred :: vecadd
procedure(get_fields), deferred :: get_fields
procedure(set_fields), deferred :: set_fields
Expand Down Expand Up @@ -95,7 +96,7 @@ end subroutine reorder
end interface

abstract interface
subroutine sum_yzintox(self, u, u_y, u_z)
subroutine sum_intox(self, u, u_)
!! sum_intox is the shared interface of sum_yintox and sum_zintox: it
!! combines one directional field (u_) into the corresponding x
!! directional field (u).
import :: base_backend_t
Expand All @@ -104,8 +105,8 @@ subroutine sum_yzintox(self, u, u_y, u_z)

class(base_backend_t) :: self
class(field_t), intent(inout) :: u
class(field_t), intent(in) :: u_y, u_z
end subroutine sum_yzintox
class(field_t), intent(in) :: u_
end subroutine sum_intox
end interface

abstract interface
Expand Down
27 changes: 21 additions & 6 deletions src/cuda/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ module m_cuda_backend
procedure :: transeq_z => transeq_z_cuda
procedure :: tds_solve => tds_solve_cuda
procedure :: reorder => reorder_cuda
procedure :: sum_yzintox => sum_yzintox_cuda
procedure :: sum_yintox => sum_yintox_cuda
procedure :: sum_zintox => sum_zintox_cuda
procedure :: vecadd => vecadd_cuda
procedure :: set_fields => set_fields_cuda
procedure :: get_fields => get_fields_cuda
Expand Down Expand Up @@ -462,29 +463,43 @@ subroutine reorder_cuda(self, u_o, u_i, direction)

end subroutine reorder_cuda

subroutine sum_yzintox_cuda(self, u, u_y, u_z)
subroutine sum_yintox_cuda(self, u, u_y)
implicit none

class(cuda_backend_t) :: self
class(field_t), intent(inout) :: u
class(field_t), intent(in) :: u_y, u_z
class(field_t), intent(in) :: u_y

real(dp), device, pointer, dimension(:, :, :) :: u_d, u_y_d, u_z_d
real(dp), device, pointer, dimension(:, :, :) :: u_d, u_y_d
type(dim3) :: blocks, threads

select type(u); type is (cuda_field_t); u_d => u%data_d; end select
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select

blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
threads = dim3(SZ, SZ, 1)
call sum_yintox<<<blocks, threads>>>(u_d, u_y_d, self%nz_loc)

end subroutine sum_yintox_cuda

subroutine sum_zintox_cuda(self, u, u_z)
!! CUDA backend implementation bound to the sum_zintox procedure.
!! Resolves the device arrays of `u` and `u_z` and launches the
!! sum_zintox kernel on them.
!! The kernel itself is not shown here; presumably it adds the
!! z-directional data in `u_z` into the x-directional field `u`
!! (TODO confirm against the kernel source).
implicit none

class(cuda_backend_t) :: self
class(field_t), intent(inout) :: u
class(field_t), intent(in) :: u_z

! device-side views of the two fields, and the kernel launch configuration
real(dp), device, pointer, dimension(:, :, :) :: u_d, u_z_d
type(dim3) :: blocks, threads

! Only the cuda_field_t case is handled: if the dynamic type is anything
! else, no branch runs and the pointer is never assigned in this routine.
select type(u); type is (cuda_field_t); u_d => u%data_d; end select
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select

! Launch grid: nx_loc x (ny_loc/SZ) blocks of SZ threads each.
! NOTE(review): ny_loc/SZ looks like integer division, so ny_loc is
! presumably a multiple of SZ -- confirm.
blocks = dim3(self%nx_loc, self%ny_loc/SZ, 1)
threads = dim3(SZ, 1, 1)
call sum_zintox<<<blocks, threads>>>(u_d, u_z_d, self%nz_loc)

! NOTE(review): the next line closes a subroutine named sum_yzintox_cuda,
! which does not match this subroutine's name; it looks like a leftover
! of the pre-split routine -- confirm and drop.
end subroutine sum_yzintox_cuda
end subroutine sum_zintox_cuda

subroutine vecadd_cuda(self, a, x, b, y)
implicit none
Expand Down
18 changes: 14 additions & 4 deletions src/omp/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ module m_omp_backend
procedure :: transeq_z => transeq_z_omp
procedure :: tds_solve => tds_solve_omp
procedure :: reorder => reorder_omp
procedure :: sum_yzintox => sum_yzintox_omp
procedure :: sum_yintox => sum_yintox_omp
procedure :: sum_zintox => sum_zintox_omp
procedure :: vecadd => vecadd_omp
procedure :: set_fields => set_fields_omp
procedure :: get_fields => get_fields_omp
Expand Down Expand Up @@ -170,14 +171,23 @@ subroutine reorder_omp(self, u_, u, direction)

end subroutine reorder_omp

subroutine sum_yzintox_omp(self, u, u_y, u_z)
subroutine sum_yintox_omp(self, u, u_)
implicit none

class(omp_backend_t) :: self
class(field_t), intent(inout) :: u
class(field_t), intent(in) :: u_y, u_z
class(field_t), intent(in) :: u_

end subroutine sum_yzintox_omp
end subroutine sum_yintox_omp

subroutine sum_zintox_omp(self, u, u_)
!! OpenMP backend procedure bound to sum_zintox.
!! Currently a placeholder: the arguments are declared but the body
!! contains no executable statements, so `u` is returned unchanged.
implicit none

class(omp_backend_t) :: self
class(field_t), intent(inout) :: u  ! field that would receive the sum
class(field_t), intent(in) :: u_    ! field that would be added into u

! TODO: no implementation yet.
end subroutine sum_zintox_omp

subroutine vecadd_omp(self, a, x, b, y)
implicit none
Expand Down
40 changes: 24 additions & 16 deletions src/solver.f90
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ module m_solver

real(dp) :: dt, nu

class(field_t), pointer :: u, v, w, du, dv, dw
class(field_t), pointer :: u, v, w

class(base_backend_t), pointer :: backend
class(dirps_t), pointer :: xdirps, ydirps, zdirps
Expand Down Expand Up @@ -98,11 +98,6 @@ function init(backend, time_integrator, xdirps, ydirps, zdirps, globs) &
deallocate(u_init, v_init, w_init)
print*, 'initial conditions are set'

! Allocate fields for storing the RHS
solver%du => solver%backend%allocator%get_block()
solver%dv => solver%backend%allocator%get_block()
solver%dw => solver%backend%allocator%get_block()

nx = globs%nx_loc; ny = globs%ny_loc; nz = globs%nz_loc
dx = globs%dx; dy = globs%dy; dz = globs%dz

Expand Down Expand Up @@ -200,6 +195,14 @@ subroutine transeq(self, du, dv, dw, u, v, w)
call self%backend%allocator%release_block(v_y)
call self%backend%allocator%release_block(w_y)

call self%backend%sum_yintox(du, du_y)
call self%backend%sum_yintox(dv, dv_y)
call self%backend%sum_yintox(dw, dw_y)

call self%backend%allocator%release_block(du_y)
call self%backend%allocator%release_block(dv_y)
call self%backend%allocator%release_block(dw_y)

! just like in y direction, get some fields for the z derivatives.
u_z => self%backend%allocator%get_block()
v_z => self%backend%allocator%get_block()
Expand All @@ -222,14 +225,11 @@ subroutine transeq(self, du, dv, dw, u, v, w)
call self%backend%allocator%release_block(w_z)

! add the z-direction contributions into the x result arrays
call self%backend%sum_yzintox(du, du_y, du_z)
call self%backend%sum_yzintox(dv, dv_y, dv_z)
call self%backend%sum_yzintox(dw, dw_y, dw_z)
call self%backend%sum_zintox(du, du_z)
call self%backend%sum_zintox(dv, dv_z)
call self%backend%sum_zintox(dw, dw_z)

! release all the unnecessary blocks.
call self%backend%allocator%release_block(du_y)
call self%backend%allocator%release_block(dv_y)
call self%backend%allocator%release_block(dw_y)
call self%backend%allocator%release_block(du_z)
call self%backend%allocator%release_block(dv_z)
call self%backend%allocator%release_block(dw_z)
Expand Down Expand Up @@ -417,18 +417,26 @@ subroutine run(self, n_iter, u_out, v_out, w_out)
integer, intent(in) :: n_iter
real(dp), dimension(:, :, :), intent(inout) :: u_out, v_out, w_out

class(field_t), pointer :: div_u, pressure, dpdx, dpdy, dpdz
class(field_t), pointer :: du, dv, dw, div_u, pressure, dpdx, dpdy, dpdz

integer :: i

print*, 'start run'

do i = 1, n_iter
call self%transeq(self%du, self%dv, self%dw, self%u, self%v, self%w)
du => self%backend%allocator%get_block()
dv => self%backend%allocator%get_block()
dw => self%backend%allocator%get_block()

call self%transeq(du, dv, dw, self%u, self%v, self%w)

! time integration
call self%time_integrator%step(self%u, self%v, self%w, &
self%du, self%dv, self%dw, self%dt)
du, dv, dw, self%dt)

call self%backend%allocator%release_block(du)
call self%backend%allocator%release_block(dv)
call self%backend%allocator%release_block(dw)

! pressure
div_u => self%backend%allocator%get_block()
Expand Down Expand Up @@ -462,7 +470,7 @@ subroutine run(self, n_iter, u_out, v_out, w_out)
print*, 'run end'

call self%backend%get_fields( &
u_out, v_out, w_out, self%du, self%dv, self%dw &
u_out, v_out, w_out, self%u, self%v, self%w &
)

end subroutine run
Expand Down

0 comments on commit 26ccadd

Please sign in to comment.