diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala index 0fc287e11bf..c7662b61f31 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala @@ -24,6 +24,8 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression import com.nvidia.spark.rapids.jni.HyperLogLogPlusPlusHostUDF +import com.nvidia.spark.rapids.jni.HyperLogLogPlusPlusHostUDF.AggregationType + import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} @@ -45,18 +47,15 @@ case class CudfHLLPP(override val dataType: DataType, cudf.Scalar.structFromColumnViews(cvs: _*) } } else { - input.reduce( - ReductionAggregation.hostUDF( - HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF( - HyperLogLogPlusPlusHostUDF.AggregationType.Reduction, precision)), - DType.STRUCT) + withResource(new HyperLogLogPlusPlusHostUDF(AggregationType.Reduction, precision)) { hll => + input.reduce(ReductionAggregation.hostUDF(hll), DType.STRUCT) + } } } override lazy val groupByAggregate: GroupByAggregation = - GroupByAggregation.hostUDF( - HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF( - HyperLogLogPlusPlusHostUDF.AggregationType.GroupBy, precision) - ) + withResource(new HyperLogLogPlusPlusHostUDF(AggregationType.GroupBy, precision)) { hll => + GroupByAggregation.hostUDF(hll) + } override val name: String = "CudfHyperLogLogPlusPlus" } @@ -64,16 +63,17 @@ case class CudfMergeHLLPP(override val dataType: DataType, precision: Int) extends CudfAggregate { override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = - (input: cudf.ColumnVector) => input.reduce( - ReductionAggregation.hostUDF( - HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF( - HyperLogLogPlusPlusHostUDF.AggregationType.Reduction_MERGE, precision)), - DType.STRUCT) + (input: cudf.ColumnVector) => withResource( + new HyperLogLogPlusPlusHostUDF(AggregationType.ReductionMerge, precision)) { hll => + input.reduce(ReductionAggregation.hostUDF(hll), DType.STRUCT) + } + override lazy val groupByAggregate: GroupByAggregation = - GroupByAggregation.hostUDF( - HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF( - HyperLogLogPlusPlusHostUDF.AggregationType.GroupByMerge, precision) - ) + withResource( + new HyperLogLogPlusPlusHostUDF(AggregationType.GroupByMerge, precision)) { hll => + GroupByAggregation.hostUDF(hll) + } + override val name: String = "CudfMergeHyperLogLogPlusPlus" }