Skip to content

Commit

Permalink
[2020-u3-rls] Fix class voting for kNN (#962)
Browse files (browse the repository at this point in the history)
* Fix class voting for kNN

* Fix kNN kdtree examples
  • Loading branch information
Alexsandruss authored Sep 18, 2020
1 parent 6580da4 commit d148c71
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 79 deletions.
37 changes: 20 additions & 17 deletions algorithms/kernel/k_nearest_neighbors/bf_knn_impl.i
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public:
TlsMem<int, cpu> tlsIdx(outBlockSize);
TlsMem<FPType, cpu> tlsKDistances(inBlockSize * k);
TlsMem<int, cpu> tlsKIndexes(inBlockSize * k);
TlsMem<int, cpu> tlsVoting(nClasses);
TlsMem<FPType, cpu> tlsVoting(nClasses);

SafeStatus safeStat;

Expand Down Expand Up @@ -150,7 +150,7 @@ protected:
FPType * trainLabel, const NumericTable * trainTable, const NumericTable * testTable,
NumericTable * testLabelTable, NumericTable * indicesTable, NumericTable * distancesTable,
TlsMem<FPType, cpu> & tlsDistances, TlsMem<int, cpu> & tlsIdx, TlsMem<FPType, cpu> & tlsKDistances,
TlsMem<int, cpu> & tlsKIndexes, TlsMem<int, cpu> & tlsVoting, size_t nOuterBlocks)
TlsMem<int, cpu> & tlsKIndexes, TlsMem<FPType, cpu> & tlsVoting, size_t nOuterBlocks)
{
const size_t inBlockSize = trainBlockSize;
const size_t inRows = nTrain;
Expand Down Expand Up @@ -265,7 +265,7 @@ protected:
DAAL_CHECK_BLOCK_STATUS(testLabelRows);
int * testLabel = testLabelRows.get();

int * voting = tlsVoting.local();
FPType * voting = tlsVoting.local();
DAAL_CHECK_MALLOC(voting);

if (voteWeights == VoteWeights::voteUniform)
Expand Down Expand Up @@ -351,7 +351,7 @@ protected:
}

services::Status uniformWeightedVoting(const size_t nClasses, const size_t k, const size_t n, const size_t nTrain, int * indices,
const FPType * trainLabel, int * testLabel, int * classWeights)
const FPType * trainLabel, int * testLabel, FPType * classWeights)
{
for (size_t i = 0; i < n; ++i)
{
Expand Down Expand Up @@ -380,36 +380,39 @@ protected:
}

services::Status distanceWeightedVoting(const size_t nClasses, const size_t k, const size_t n, const size_t nTrain, FPType * distances,
int * indices, const FPType * trainLabel, int * testLabel, int * classWeights)
int * indices, const FPType * trainLabel, int * testLabel, FPType * classWeights)
{
const FPType epsilon = daal::services::internal::EpsilonVal<FPType>::get();
bool isContainZero = false;
for (size_t i = 0; i < k * n; ++i)
{
if (distances[i] < epsilon)
{
isContainZero = true;
break;
}
}

for (size_t i = 0; i < n; ++i)
{
bool isContainZero = false;
for (size_t j = 0; j < k * n; ++j)
{
if (distances[j] < epsilon)
{
isContainZero = true;
break;
}
}
for (size_t j = 0; j < nClasses; ++j)
{
classWeights[j] = 0;
}
for (size_t j = 0; j < k; ++j)
if (isContainZero)
{
if (isContainZero)
for (size_t j = 0; j < k; ++j)
{
if (distances[i] < epsilon)
{
const int label = static_cast<int>(trainLabel[indices[i * k + j]]);
classWeights[label] += 1;
}
}
else
}
else
{
for (size_t j = 0; j < k; ++j)
{
const int label = static_cast<int>(trainLabel[indices[i * k + j]]);
classWeights[label] += 1 / distances[i * k + j];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu> : publi
services::Status predict(algorithmFpType * predictedClass, const Heap<GlobalNeighbors<algorithmFpType, cpu>, cpu> & heap,
const NumericTable * labels, size_t k, VoteWeights voteWeights, const NumericTable * modelIndices,
data_management::BlockDescriptor<algorithmFpType> & indices,
data_management::BlockDescriptor<algorithmFpType> & distances, size_t index);
data_management::BlockDescriptor<algorithmFpType> & distances, size_t index, const size_t nClasses);
};

} // namespace internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,18 +282,24 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
typedef daal::internal::Math<algorithmFpType, cpu> Math;

size_t k;
size_t nClasses;
VoteWeights voteWeights = voteUniform;
DAAL_UINT64 resultsToEvaluate = classifier::computeClassLabels;

{
auto par1 = dynamic_cast<const kdtree_knn_classification::interface1::Parameter *>(par);
if (par1) k = par1->k;
if (par1)
{
k = par1->k;
nClasses = par1->nClasses;
}

auto par2 = dynamic_cast<const kdtree_knn_classification::interface2::Parameter *>(par);
if (par2)
{
k = par2->k;
resultsToEvaluate = par2->resultsToEvaluate;
nClasses = par2->nClasses;
}

const auto par3 = dynamic_cast<const kdtree_knn_classification::interface3::Parameter *>(par);
Expand All @@ -302,6 +308,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
k = par3->k;
voteWeights = par3->voteWeights;
resultsToEvaluate = par3->resultsToEvaluate;
nClasses = par3->nClasses;
}

if (par1 == NULL && par2 == NULL && par3 == NULL) return Status(ErrorNullParameterNotSupported);
Expand Down Expand Up @@ -408,7 +415,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
{
findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data,
isHomogenSOA, soa_arrays);
s = predict(&(dy[i * yColumnCount]), local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i);
s = predict(&(dy[i * yColumnCount]), local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
DAAL_CHECK_STATUS_THR(s)
}

Expand All @@ -421,7 +428,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
{
findNearestNeighbors(&dx[i * xColumnCount], local->heap, local->stack, k, radius, kdTreeTable, rootTreeNodeIndex, data,
isHomogenSOA, soa_arrays);
s = predict(nullptr, local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i);
s = predict(nullptr, local->heap, labels, k, voteWeights, modelIndices, indicesBD, distancesBD, i, nClasses);
DAAL_CHECK_STATUS_THR(s)
}
}
Expand Down Expand Up @@ -599,7 +606,7 @@ template <typename algorithmFpType, CpuType cpu>
services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::predict(
algorithmFpType * predictedClass, const Heap<GlobalNeighbors<algorithmFpType, cpu>, cpu> & heap, const NumericTable * labels, size_t k,
VoteWeights voteWeights, const NumericTable * modelIndices, data_management::BlockDescriptor<algorithmFpType> & indices,
data_management::BlockDescriptor<algorithmFpType> & distances, size_t index)
data_management::BlockDescriptor<algorithmFpType> & distances, size_t index, const size_t nClasses)
{
typedef daal::internal::Math<algorithmFpType, cpu> Math;

Expand Down Expand Up @@ -661,39 +668,29 @@ services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, c
DAAL_ASSERT(predictedClass);

data_management::BlockDescriptor<algorithmFpType> labelBD;
algorithmFpType * classes = static_cast<algorithmFpType *>(daal::services::internal::service_malloc<algorithmFpType, cpu>(heapSize));
algorithmFpType * classes = static_cast<algorithmFpType *>(daal::services::internal::service_malloc<algorithmFpType, cpu>(heapSize));
algorithmFpType * classWeights = static_cast<algorithmFpType *>(daal::services::internal::service_malloc<algorithmFpType, cpu>(nClasses));
DAAL_CHECK_MALLOC(classWeights);
DAAL_CHECK_MALLOC(classes);

for (size_t i = 0; i < nClasses; ++i)
{
classWeights[i] = 0;
}

for (size_t i = 0; i < heapSize; ++i)
{
const_cast<NumericTable *>(labels)->getBlockOfColumnValues(0, heap[i].index, 1, readOnly, labelBD);
classes[i] = *(labelBD.getBlockPtr());
const_cast<NumericTable *>(labels)->releaseBlockOfColumnValues(labelBD);
}
daal::algorithms::internal::qSort<algorithmFpType, cpu>(heapSize, classes);
algorithmFpType currentClass = classes[0];
algorithmFpType winnerClass = currentClass;

if (voteWeights == voteUniform)
{
size_t currentWeight = 1;
size_t winnerWeight = currentWeight;
for (size_t i = 1; i < heapSize; ++i)
for (size_t i = 0; i < heapSize; ++i)
{
if (classes[i] == currentClass)
{
if ((++currentWeight) > winnerWeight)
{
winnerWeight = currentWeight;
winnerClass = currentClass;
}
}
else
{
currentWeight = 1;
currentClass = classes[i];
}
classWeights[(size_t)(classes[i])] += 1;
}
*predictedClass = winnerClass;
}
else
{
Expand All @@ -714,55 +711,37 @@ services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, c

if (isContainZero)
{
size_t currentWeight = (heap[0].distance <= epsilon);
size_t winnerWeight = currentWeight;
for (size_t i = 1; i < heapSize; ++i)
for (size_t i = 0; i < heapSize; ++i)
{
if (classes[i] == currentClass)
{
currentWeight += (heap[i].distance <= epsilon);
}
else
if (heap[i].distance <= epsilon)
{
currentWeight = (heap[i].distance <= epsilon);
currentClass = classes[i];
}

if (currentWeight > winnerWeight)
{
winnerWeight = currentWeight;
winnerClass = currentClass;
classWeights[(size_t)(classes[i])] += 1;
}
}
*predictedClass = winnerClass;
}
else
{
algorithmFpType currentWeight = Math::sSqrt(1.0 / heap[0].distance);
algorithmFpType winnerWeight = currentWeight;
for (size_t i = 1; i < heapSize; ++i)
for (size_t i = 0; i < heapSize; ++i)
{
if (classes[i] == currentClass)
{
currentWeight += Math::sSqrt(1.0 / heap[i].distance);
}
else
{
currentWeight = Math::sSqrt(1.0 / heap[i].distance);
currentClass = classes[i];
}

if (currentWeight > winnerWeight)
{
winnerWeight = currentWeight;
winnerClass = currentClass;
}
classWeights[(size_t)(classes[i])] += Math::sSqrt(1 / heap[i].distance);
}
*predictedClass = winnerClass;
}
}

algorithmFpType maxWeightClass = 0;
algorithmFpType maxWeight = 0;
for (size_t i = 0; i < nClasses; ++i)
{
if (classWeights[i] > maxWeight)
{
maxWeight = classWeights[i];
maxWeightClass = i;
}
}
*predictedClass = maxWeightClass;

service_free<algorithmFpType, cpu>(classes);
service_free<algorithmFpType, cpu>(classWeights);
classes = nullptr;
}

Expand Down
1 change: 1 addition & 0 deletions examples/cpp/source/k_nearest_neighbors/kdtree_knn_dense_batch.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ void testModel()
/* Pass the testing data set and trained model to the algorithm */
algorithm.input.set(classifier::prediction::data, testData);
algorithm.input.set(classifier::prediction::model, trainingResult->get(classifier::training::model));
algorithm.parameter.nClasses = nClasses;

/* Compute prediction results */
algorithm.compute();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private static void testModel() {

kNearestNeighborsPredict.input.set(NumericTableInputId.data, testData);
kNearestNeighborsPredict.input.set(ModelInputId.model, model);
kNearestNeighborsPredict.parameter.setNClasses(nClasses);

/* Compute prediction results */
PredictionResult predictionResult = kNearestNeighborsPredict.compute();
Expand Down

0 comments on commit d148c71

Please sign in to comment.