Skip to content

Commit

Permalink
Update APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Jan 16, 2025
1 parent a24d9b8 commit f305e0c
Showing 1 changed file with 18 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
import com.nvidia.spark.rapids.jni.HyperLogLogPlusPlusHostUDF
import com.nvidia.spark.rapids.jni.HyperLogLogPlusPlusHostUDF.AggregationType

import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
Expand All @@ -45,35 +47,33 @@ case class CudfHLLPP(override val dataType: DataType,
cudf.Scalar.structFromColumnViews(cvs: _*)
}
} else {
input.reduce(
ReductionAggregation.hostUDF(
HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
HyperLogLogPlusPlusHostUDF.AggregationType.Reduction, precision)),
DType.STRUCT)
withResource(new HyperLogLogPlusPlusHostUDF(AggregationType.Reduction, precision)) { hll =>
input.reduce(ReductionAggregation.hostUDF(hll), DType.STRUCT)
}
}
}
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.hostUDF(
HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
HyperLogLogPlusPlusHostUDF.AggregationType.GroupBy, precision)
)
withResource(new HyperLogLogPlusPlusHostUDF(AggregationType.GroupBy, precision)) { hll =>
GroupByAggregation.hostUDF(hll)
}
override val name: String = "CudfHyperLogLogPlusPlus"
}

case class CudfMergeHLLPP(override val dataType: DataType,
precision: Int)
extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(input: cudf.ColumnVector) => input.reduce(
ReductionAggregation.hostUDF(
HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
HyperLogLogPlusPlusHostUDF.AggregationType.Reduction_MERGE, precision)),
DType.STRUCT)
(input: cudf.ColumnVector) => withResource(
new HyperLogLogPlusPlusHostUDF(AggregationType.ReductionMerge, precision)) { hll =>
input.reduce(ReductionAggregation.hostUDF(hll), DType.STRUCT)
}

override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.hostUDF(
HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
HyperLogLogPlusPlusHostUDF.AggregationType.GroupByMerge, precision)
)
withResource(
new HyperLogLogPlusPlusHostUDF(AggregationType.GroupByMerge, precision)) { hll =>
GroupByAggregation.hostUDF(hll)
}

override val name: String = "CudfMergeHyperLogLogPlusPlus"
}

Expand Down

0 comments on commit f305e0c

Please sign in to comment.