Skip to content

Commit

Permalink
Fix argmax kernel for GPU SVM (#1270) (#1281)
Browse files Browse the repository at this point in the history
(cherry picked from commit 87ae894)

Co-authored-by: Kirill <[email protected]>
  • Loading branch information
mergify[bot] and PetrovKP authored Nov 16, 2020
1 parent dac5d4a commit 04c37c9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,13 @@ DECLARE_SOURCE_DAAL(
const int groupId = get_sub_group_id();
const int localId = get_local_id(0);
const int groupCount = get_num_sub_groups();
+ const int subGroupSize = get_sub_group_size();

algorithmFPType x = values[localId];
int indX = localId;

- algorithmFPType resMax;
- int resIndex;
- resMax = sub_group_reduce_max(x);
- resIndex = sub_group_reduce_min(resMax == x ? indX : INT_MAX);
+ algorithmFPType resMax = sub_group_reduce_max(x);
+ int resIndex = sub_group_reduce_min(resMax == x ? indX : INT_MAX);

if (localGroupId == 0)
{
Expand All @@ -73,6 +72,19 @@ DECLARE_SOURCE_DAAL(
resMax = sub_group_reduce_max(x);
resIndex = sub_group_reduce_min(resMax == x ? indX : INT_MAX);

+ for (int iGroup = subGroupSize; iGroup < groupCount; iGroup += subGroupSize)
+ {
+ x = localCache[iGroup + localGroupId].value;
+ indX = localCache[iGroup + localGroupId].index;
+
+ const algorithmFPType innerMax = sub_group_reduce_max(x);
+ if (innerMax > resMax)
+ {
+ resMax = innerMax;
+ resIndex = sub_group_reduce_min(resMax == x ? indX : INT_MAX);
+ }
+ }
+
if (localGroupId == 0)
{
result->value = resMax;
Expand All @@ -83,7 +95,7 @@ DECLARE_SOURCE_DAAL(
}

__kernel void smoKernel(const __global algorithmFPType * const y, const __global algorithmFPType * const kernelWsRows,
- const __global int * wsIndices, const uint nVectors, const __global algorithmFPType * grad, const algorithmFPType C,
+ const __global uint * wsIndices, const uint nVectors, const __global algorithmFPType * grad, const algorithmFPType C,
const algorithmFPType eps, const algorithmFPType tau, const uint maxInnerIteration, __global algorithmFPType * alpha,
__global algorithmFPType * deltaalpha, __global algorithmFPType * resinfo) {
const uint i = get_local_id(0);
Expand Down Expand Up @@ -117,7 +129,7 @@ DECLARE_SOURCE_DAAL(
__local algorithmFPType localEps;

uint iter = 0;
- for (; iter < maxInnerIteration; iter++)
+ for (; iter < maxInnerIteration; ++iter)
{
/* m(alpha) = min(grad[i]): i belongs to I_UP (alpha) */
objFunc[i] = isUpper(alphai, yi, C) ? -gradi : MIN_FLT;
Expand Down Expand Up @@ -194,7 +206,6 @@ DECLARE_SOURCE_DAAL(
barrier(CLK_LOCAL_MEM_FENCE);

const algorithmFPType delta = min(deltaBi, deltaBj);

if (i == Bi)
{
alphai = alphai + yi * delta;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ class SaveResultModel

bias = -0.5 * (ub + lb);
}

return status;
}

Expand Down

0 comments on commit 04c37c9

Please sign in to comment.