Skip to content

Commit

Permalink
Add axpyz kernels (#810)
Browse files Browse the repository at this point in the history
This PR adds SeqVector and ParVector functions for computing: z = a * x + b * y
  • Loading branch information
victorapm authored Feb 3, 2023
1 parent d3f6b03 commit 53dfbe3
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/parcsr_mv/HYPRE_parcsr_vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ HYPRE_ParVectorScale( HYPRE_Complex value,
/*--------------------------------------------------------------------------
* HYPRE_ParVectorAxpy
*--------------------------------------------------------------------------*/

HYPRE_Int
HYPRE_ParVectorAxpy( HYPRE_Complex alpha,
HYPRE_ParVector x,
Expand Down
3 changes: 3 additions & 0 deletions src/parcsr_mv/_hypre_parcsr_mv.h
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,9 @@ hypre_ParVector *hypre_ParVectorCloneDeep_v2( hypre_ParVector *x,
HYPRE_Int hypre_ParVectorMigrate(hypre_ParVector *x, HYPRE_MemoryLocation memory_location);
HYPRE_Int hypre_ParVectorScale ( HYPRE_Complex alpha, hypre_ParVector *y );
HYPRE_Int hypre_ParVectorAxpy ( HYPRE_Complex alpha, hypre_ParVector *x, hypre_ParVector *y );
HYPRE_Int hypre_ParVectorAxpyz ( HYPRE_Complex alpha, hypre_ParVector *x,
HYPRE_Complex beta, hypre_ParVector *y,
hypre_ParVector *z );
HYPRE_Int hypre_ParVectorMassAxpy ( HYPRE_Complex *alpha, hypre_ParVector **x, hypre_ParVector *y,
HYPRE_Int k, HYPRE_Int unroll);
HYPRE_Real hypre_ParVectorInnerProd ( hypre_ParVector *x, hypre_ParVector *y );
Expand Down
19 changes: 19 additions & 0 deletions src/parcsr_mv/par_vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ hypre_ParVectorCopy( hypre_ParVector *x,
{
hypre_Vector *x_local = hypre_ParVectorLocalVector(x);
hypre_Vector *y_local = hypre_ParVectorLocalVector(y);

return hypre_SeqVectorCopy(x_local, y_local);
}

Expand Down Expand Up @@ -442,6 +443,24 @@ hypre_ParVectorAxpy( HYPRE_Complex alpha,
return hypre_SeqVectorAxpy(alpha, x_local, y_local);
}

/*--------------------------------------------------------------------------
* hypre_ParVectorAxpyz
*--------------------------------------------------------------------------*/

HYPRE_Int
hypre_ParVectorAxpyz( HYPRE_Complex alpha,
hypre_ParVector *x,
HYPRE_Complex beta,
hypre_ParVector *y,
hypre_ParVector *z )
{
hypre_Vector *x_local = hypre_ParVectorLocalVector(x);
hypre_Vector *y_local = hypre_ParVectorLocalVector(y);
hypre_Vector *z_local = hypre_ParVectorLocalVector(z);

return hypre_SeqVectorAxpyz(alpha, x_local, beta, y_local, z_local);
}

/*--------------------------------------------------------------------------
* hypre_ParVectorInnerProd
*--------------------------------------------------------------------------*/
Expand Down
3 changes: 3 additions & 0 deletions src/parcsr_mv/protos.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,9 @@ hypre_ParVector *hypre_ParVectorCloneDeep_v2( hypre_ParVector *x,
HYPRE_Int hypre_ParVectorMigrate(hypre_ParVector *x, HYPRE_MemoryLocation memory_location);
HYPRE_Int hypre_ParVectorScale ( HYPRE_Complex alpha, hypre_ParVector *y );
HYPRE_Int hypre_ParVectorAxpy ( HYPRE_Complex alpha, hypre_ParVector *x, hypre_ParVector *y );
HYPRE_Int hypre_ParVectorAxpyz ( HYPRE_Complex alpha, hypre_ParVector *x,
HYPRE_Complex beta, hypre_ParVector *y,
hypre_ParVector *z );
HYPRE_Int hypre_ParVectorMassAxpy ( HYPRE_Complex *alpha, hypre_ParVector **x, hypre_ParVector *y,
HYPRE_Int k, HYPRE_Int unroll);
HYPRE_Real hypre_ParVectorInnerProd ( hypre_ParVector *x, hypre_ParVector *y );
Expand Down
6 changes: 6 additions & 0 deletions src/seq_mv/protos.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ HYPRE_Int hypre_SeqVectorScaleDevice( HYPRE_Complex alpha, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpy ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpyHost ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpyDevice ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpyz ( HYPRE_Complex alpha, hypre_Vector *x,
HYPRE_Complex beta, hypre_Vector *y,
hypre_Vector *z );
HYPRE_Int hypre_SeqVectorAxpyzDevice ( HYPRE_Complex alpha, hypre_Vector *x,
HYPRE_Complex beta, hypre_Vector *y,
hypre_Vector *z );
HYPRE_Real hypre_SeqVectorInnerProd ( hypre_Vector *x, hypre_Vector *y );
HYPRE_Real hypre_SeqVectorInnerProdHost ( hypre_Vector *x, hypre_Vector *y );
HYPRE_Real hypre_SeqVectorInnerProdDevice ( hypre_Vector *x, hypre_Vector *y );
Expand Down
6 changes: 6 additions & 0 deletions src/seq_mv/seq_mv.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,12 @@ HYPRE_Int hypre_SeqVectorScaleDevice( HYPRE_Complex alpha, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpy ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpyHost ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpyDevice ( HYPRE_Complex alpha, hypre_Vector *x, hypre_Vector *y );
HYPRE_Int hypre_SeqVectorAxpyz ( HYPRE_Complex alpha, hypre_Vector *x,
HYPRE_Complex beta, hypre_Vector *y,
hypre_Vector *z );
HYPRE_Int hypre_SeqVectorAxpyzDevice ( HYPRE_Complex alpha, hypre_Vector *x,
HYPRE_Complex beta, hypre_Vector *y,
hypre_Vector *z );
HYPRE_Real hypre_SeqVectorInnerProd ( hypre_Vector *x, hypre_Vector *y );
HYPRE_Real hypre_SeqVectorInnerProdHost ( hypre_Vector *x, hypre_Vector *y );
HYPRE_Real hypre_SeqVectorInnerProdDevice ( hypre_Vector *x, hypre_Vector *y );
Expand Down
60 changes: 60 additions & 0 deletions src/seq_mv/vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,66 @@ hypre_SeqVectorAxpy( HYPRE_Complex alpha,
return hypre_error_flag;
}

/*--------------------------------------------------------------------------
* hypre_SeqVectorAxpyzHost
*--------------------------------------------------------------------------*/

HYPRE_Int
hypre_SeqVectorAxpyzHost( HYPRE_Complex alpha,
hypre_Vector *x,
HYPRE_Complex beta,
hypre_Vector *y,
hypre_Vector *z )
{
HYPRE_Complex *x_data = hypre_VectorData(x);
HYPRE_Complex *y_data = hypre_VectorData(y);
HYPRE_Complex *z_data = hypre_VectorData(z);

HYPRE_Int num_vectors = hypre_VectorNumVectors(x);
HYPRE_Int size = hypre_VectorSize(x);
HYPRE_Int total_size = size * num_vectors;
HYPRE_Int i;

#if defined(HYPRE_USING_OPENMP)
#pragma omp parallel for private(i) HYPRE_SMP_SCHEDULE
#endif
for (i = 0; i < total_size; i++)
{
z_data[i] = alpha * x_data[i] + beta * y_data[i];
}

return hypre_error_flag;
}

/*--------------------------------------------------------------------------
* hypre_SeqVectorAxpyz
*
* Computes z = a*x + b*y
*--------------------------------------------------------------------------*/

HYPRE_Int
hypre_SeqVectorAxpyz( HYPRE_Complex alpha,
hypre_Vector *x,
HYPRE_Complex beta,
hypre_Vector *y,
hypre_Vector *z )
{
#if defined(HYPRE_USING_GPU)
HYPRE_ExecutionPolicy exec = hypre_GetExecPolicy2( hypre_VectorMemoryLocation(x),
hypre_VectorMemoryLocation(y));
if (exec == HYPRE_EXEC_DEVICE)
{
hypre_SeqVectorAxpyzDevice(alpha, x, beta, y, z);
}
else
#endif
{
hypre_SeqVectorAxpyzHost(alpha, x, beta, y, z);
}

return hypre_error_flag;
}

/*--------------------------------------------------------------------------
* hypre_SeqVectorElmdivpyHost
*
Expand Down
37 changes: 37 additions & 0 deletions src/seq_mv/vector_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,43 @@ hypre_SeqVectorAxpyDevice( HYPRE_Complex alpha,
return hypre_error_flag;
}

/*--------------------------------------------------------------------------
* hypre_SeqVectorAxpyzDevice
*--------------------------------------------------------------------------*/

HYPRE_Int
hypre_SeqVectorAxpyzDevice( HYPRE_Complex alpha,
hypre_Vector *x,
HYPRE_Complex beta,
hypre_Vector *y,
hypre_Vector *z )
{
HYPRE_Complex *x_data = hypre_VectorData(x);
HYPRE_Complex *y_data = hypre_VectorData(y);
HYPRE_Complex *z_data = hypre_VectorData(z);

HYPRE_Int num_vectors = hypre_VectorNumVectors(x);
HYPRE_Int size = hypre_VectorSize(x);
HYPRE_Int total_size = size * num_vectors;

#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP) || defined(HYPRE_USING_SYCL)
hypreDevice_ComplexAxpyzn(total_size, x_data, y_data, z_data, alpha, beta);

#elif defined(HYPRE_USING_DEVICE_OPENMP)
HYPRE_Int i;

#pragma omp target teams distribute parallel for private(i) is_device_ptr(z_data, y_data, x_data)
for (i = 0; i < total_size; i++)
{
z_data[i] = alpha * x_data[i] + beta * z_data[i];
}
#endif

hypre_SyncComputeStream(hypre_handle());

return hypre_error_flag;
}

/*--------------------------------------------------------------------------
* hypre_SeqVectorElmdivpyDevice
*--------------------------------------------------------------------------*/
Expand Down
3 changes: 2 additions & 1 deletion src/utilities/_hypre_utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,6 @@ typedef void (*GPUMfreeFunc)(void *);
#endif

#endif

/******************************************************************************
* Copyright (c) 1998 Lawrence Livermore National Security, LLC and other
* HYPRE Project Developers. See the top-level COPYRIGHT file for details.
Expand Down Expand Up @@ -1868,6 +1867,8 @@ HYPRE_Int hypreDevice_IntAxpyn(HYPRE_Int *d_x, size_t n, HYPRE_Int *d_y, HYPRE_I
HYPRE_Int a);
HYPRE_Int hypreDevice_BigIntAxpyn(HYPRE_BigInt *d_x, size_t n, HYPRE_BigInt *d_y,
HYPRE_BigInt *d_z, HYPRE_BigInt a);
HYPRE_Int hypreDevice_ComplexAxpyzn(HYPRE_Int n, HYPRE_Complex *d_x, HYPRE_Complex *d_y,
HYPRE_Complex *d_z, HYPRE_Complex a, HYPRE_Complex b);
HYPRE_Int* hypreDevice_CsrRowPtrsToIndices(HYPRE_Int nrows, HYPRE_Int nnz, HYPRE_Int *d_row_ptr);
HYPRE_Int hypreDevice_CsrRowPtrsToIndices_v2(HYPRE_Int nrows, HYPRE_Int nnz, HYPRE_Int *d_row_ptr,
HYPRE_Int *d_row_ind);
Expand Down
48 changes: 30 additions & 18 deletions src/utilities/device_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -1413,37 +1413,35 @@ hypreDevice_GenScatterAdd( HYPRE_Real *x,
}

/*--------------------------------------------------------------------
* hypreGPUKernel_axpyn
* hypreGPUKernel_Axpyzn
*--------------------------------------------------------------------*/

template<typename T>
__global__ void
hypreGPUKernel_axpyn( hypre_DeviceItem &item,
T *x,
size_t n,
T *y,
T *z,
T a )
hypreGPUKernel_Axpyzn( hypre_DeviceItem &item,
HYPRE_Int n,
T *x,
T *y,
T *z,
T a,
T b )
{
HYPRE_Int i = hypre_gpu_get_grid_thread_id<1, 1>(item);

if (i < n)
{
z[i] = a * x[i] + y[i];
z[i] = a * x[i] + b * y[i];
}
}

/*--------------------------------------------------------------------
* hypreDevice_Axpyn
* hypreDevice_Axpyzn
*--------------------------------------------------------------------*/

template<typename T>
HYPRE_Int
hypreDevice_Axpyn(T *d_x, size_t n, T *d_y, T *d_z, T a)
hypreDevice_Axpyzn(HYPRE_Int n, T *d_x, T *d_y, T *d_z, T a, T b)
{
#if 0
HYPRE_THRUST_CALL( transform, d_x, d_x + n, d_y, d_z, a * _1 + _2 );
#else
if (n <= 0)
{
return hypre_error_flag;
Expand All @@ -1452,8 +1450,7 @@ hypreDevice_Axpyn(T *d_x, size_t n, T *d_y, T *d_z, T a)
dim3 bDim = hypre_GetDefaultDeviceBlockDimension();
dim3 gDim = hypre_GetDefaultDeviceGridDimension(n, "thread", bDim);

HYPRE_GPU_LAUNCH( hypreGPUKernel_axpyn, gDim, bDim, d_x, n, d_y, d_z, a );
#endif
HYPRE_GPU_LAUNCH( hypreGPUKernel_Axpyzn, gDim, bDim, n, d_x, d_y, d_z, a, b );

return hypre_error_flag;
}
Expand All @@ -1469,7 +1466,7 @@ hypreDevice_ComplexAxpyn( HYPRE_Complex *d_x,
HYPRE_Complex *d_z,
HYPRE_Complex a )
{
return hypreDevice_Axpyn(d_x, n, d_y, d_z, a);
return hypreDevice_Axpyzn((HYPRE_Int) n, d_x, d_y, d_z, a, 1.0);
}

/*--------------------------------------------------------------------
Expand All @@ -1483,7 +1480,7 @@ hypreDevice_IntAxpyn( HYPRE_Int *d_x,
HYPRE_Int *d_z,
HYPRE_Int a )
{
return hypreDevice_Axpyn(d_x, n, d_y, d_z, a);
return hypreDevice_Axpyzn((HYPRE_Int) n, d_x, d_y, d_z, a, 1);
}

/*--------------------------------------------------------------------
Expand All @@ -1497,7 +1494,22 @@ hypreDevice_BigIntAxpyn( HYPRE_BigInt *d_x,
HYPRE_BigInt *d_z,
HYPRE_BigInt a )
{
return hypreDevice_Axpyn(d_x, n, d_y, d_z, a);
return hypreDevice_Axpyzn((HYPRE_Int) n, d_x, d_y, d_z, a, 1);
}

/*--------------------------------------------------------------------
* hypreDevice_ComplexAxpyzn
*--------------------------------------------------------------------*/

HYPRE_Int
hypreDevice_ComplexAxpyzn( HYPRE_Int n,
HYPRE_Complex *d_x,
HYPRE_Complex *d_y,
HYPRE_Complex *d_z,
HYPRE_Complex a,
HYPRE_Complex b )
{
return hypreDevice_Axpyzn(n, d_x, d_y, d_z, a, b);
}

#if defined(HYPRE_USING_CURAND)
Expand Down
1 change: 0 additions & 1 deletion src/utilities/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,3 @@ typedef void (*GPUMfreeFunc)(void *);
#endif

#endif

2 changes: 2 additions & 0 deletions src/utilities/protos.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ HYPRE_Int hypreDevice_IntAxpyn(HYPRE_Int *d_x, size_t n, HYPRE_Int *d_y, HYPRE_I
HYPRE_Int a);
HYPRE_Int hypreDevice_BigIntAxpyn(HYPRE_BigInt *d_x, size_t n, HYPRE_BigInt *d_y,
HYPRE_BigInt *d_z, HYPRE_BigInt a);
HYPRE_Int hypreDevice_ComplexAxpyzn(HYPRE_Int n, HYPRE_Complex *d_x, HYPRE_Complex *d_y,
HYPRE_Complex *d_z, HYPRE_Complex a, HYPRE_Complex b);
HYPRE_Int* hypreDevice_CsrRowPtrsToIndices(HYPRE_Int nrows, HYPRE_Int nnz, HYPRE_Int *d_row_ptr);
HYPRE_Int hypreDevice_CsrRowPtrsToIndices_v2(HYPRE_Int nrows, HYPRE_Int nnz, HYPRE_Int *d_row_ptr,
HYPRE_Int *d_row_ind);
Expand Down

0 comments on commit 53dfbe3

Please sign in to comment.