diff --git a/.github/workflows/maven.yml b/.github/workflows/maven.yml index 583a1c4dac3..7aeb0e45c38 100644 --- a/.github/workflows/maven.yml +++ b/.github/workflows/maven.yml @@ -133,8 +133,12 @@ jobs: run: | SPARK_BINARY_VERSION=${{ matrix.spark }} SPARK_MAJOR_VERSION=${SPARK_BINARY_VERSION%%.*} + SPARK_MODULE_NAME=$SPARK_MAJOR_VERSION + if [[ $SPARK_MAJOR_VERSION == "3" ]]; then + SPARK_MODULE_NAME="3-4" + fi PROFILES="-Pgoogle-mirror,spark-${{ matrix.spark }}" - TEST_MODULES="client-spark/common,client-spark/spark-${SPARK_MAJOR_VERSION},client-spark/spark-${SPARK_MAJOR_VERSION}-shaded,tests/spark-it" + TEST_MODULES="client-spark/common,client-spark/spark-${SPARK_MODULE_NAME},client-spark/spark-${SPARK_MAJOR_VERSION}-columnar-common,client-spark/spark-${SPARK_MAJOR_VERSION}-shaded,tests/spark-it" build/mvn $PROFILES -pl $TEST_MODULES -am clean install -DskipTests build/mvn $PROFILES -pl $TEST_MODULES -Dspark.shuffle.sort.io.plugin.class=${{ matrix.shuffle-plugin-class }} test - name: Upload test log diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java index f00fe063e13..9b959a4b895 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java @@ -19,14 +19,12 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; -import org.apache.spark.scheduler.DAGScheduler; public class SparkCommonUtils { public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException { + int DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4; int maxStageAttempts = - conf.getInt( - "spark.stage.maxConsecutiveAttempts", - DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS()); + conf.getInt("spark.stage.maxConsecutiveAttempts", DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS); // In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in Spark 3, it has been // changed to TASK_MAX_FAILURES. The default value for both is consistently set to 4. int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4); diff --git a/client-spark/spark-3/pom.xml b/client-spark/spark-3-4/pom.xml similarity index 96% rename from client-spark/spark-3/pom.xml rename to client-spark/spark-3-4/pom.xml index 4acacbed0b7..8ddfd6f2267 100644 --- a/client-spark/spark-3/pom.xml +++ b/client-spark/spark-3-4/pom.xml @@ -24,9 +24,9 @@ ../../pom.xml - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} jar - Celeborn Client for Spark 3 + Celeborn Client for Spark 3 and 4 diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java similarity index 100% rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java similarity index 100% rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java similarity index 100% rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java similarity index 100% rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java similarity index 100% rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/SparkVersionUtil.scala similarity index 100% rename from client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/SparkVersionUtil.scala diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala similarity index 100% rename from client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala similarity index 100% rename from client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala similarity index 100% rename from client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala similarity index 100% rename from client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java similarity index 100% rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java similarity index 100% rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java similarity index 100% rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java similarity index 100% rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java similarity index 100% rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java diff --git a/client-spark/spark-3/src/test/resources/log4j.properties b/client-spark/spark-3-4/src/test/resources/log4j.properties similarity index 100% rename from client-spark/spark-3/src/test/resources/log4j.properties rename to client-spark/spark-3-4/src/test/resources/log4j.properties diff --git a/client-spark/spark-3/src/test/resources/log4j2-test.xml b/client-spark/spark-3-4/src/test/resources/log4j2-test.xml similarity index 100% rename from client-spark/spark-3/src/test/resources/log4j2-test.xml rename to client-spark/spark-3-4/src/test/resources/log4j2-test.xml diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala b/client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala similarity index 100% rename from client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala rename to client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala diff --git a/client-spark/spark-3-columnar-common/pom.xml b/client-spark/spark-3-columnar-common/pom.xml index 6b4b3c280a1..39a36633e4b 100644 --- a/client-spark/spark-3-columnar-common/pom.xml +++ b/client-spark/spark-3-columnar-common/pom.xml @@ -31,7 +31,7 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} diff --git a/client-spark/spark-3-columnar-shuffle/pom.xml b/client-spark/spark-3-columnar-shuffle/pom.xml index 5f082443738..0a628cc3d33 100644 --- a/client-spark/spark-3-columnar-shuffle/pom.xml +++ b/client-spark/spark-3-columnar-shuffle/pom.xml @@ -49,7 +49,7 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml index d3d59cb87a8..9e2c921ab12 100644 --- a/client-spark/spark-3-shaded/pom.xml +++ b/client-spark/spark-3-shaded/pom.xml @@ -31,7 +31,7 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} diff --git a/client-spark/spark-4-columnar-shuffle/pom.xml b/client-spark/spark-4-columnar-shuffle/pom.xml new file mode 100644 index 00000000000..bd806f2fb9b --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/pom.xml @@ -0,0 +1,68 @@ + + + + 4.0.0 + + org.apache.celeborn + celeborn-parent_${scala.binary.version} + ${project.version} + ../../pom.xml + + + celeborn-spark-4-columnar-shuffle_${scala.binary.version} + jar + Celeborn Client for Spark 4 Columnar Shuffle + + + + org.apache.celeborn + celeborn-spark-3-columnar-common_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.celeborn + celeborn-client_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.celeborn + celeborn-client-spark-3_${scala.binary.version} + ${project.version} + test-jar + test + + + org.mockito + mockito-core + test + + + org.mockito + mockito-inline + test + + + diff --git a/client-spark/spark-4-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java b/client-spark/spark-4-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java new file mode 100644 index 00000000000..b09b1306c90 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.io.IOException; + +import scala.Product2; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.Private; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.UnsafeRowSerializer; +import org.apache.spark.sql.execution.columnar.CelebornBatchBuilder; +import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchBuilder; +import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchCodeGenBuild; +import org.apache.spark.sql.execution.metric.SQLMetric; +import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.common.CelebornConf; + +@Private +public class ColumnarHashBasedShuffleWriter extends HashBasedShuffleWriter { + + private static final Logger logger = + LoggerFactory.getLogger(ColumnarHashBasedShuffleWriter.class); + + private final int stageId; + private final int shuffleId; + private final CelebornBatchBuilder[] celebornBatchBuilders; + private final StructType schema; + private final Serializer depSerializer; + private final boolean isColumnarShuffle; + private final int columnarShuffleBatchSize; + private final boolean columnarShuffleCodeGenEnabled; + private final boolean columnarShuffleDictionaryEnabled; + private final double columnarShuffleDictionaryMaxFactor; + + public ColumnarHashBasedShuffleWriter( + int shuffleId, + CelebornShuffleHandle handle, + TaskContext taskContext, + CelebornConf conf, + ShuffleClient client, + ShuffleWriteMetricsReporter metrics, + SendBufferPool sendBufferPool) + throws IOException { + super(shuffleId, handle, taskContext, conf, client, metrics, sendBufferPool); + columnarShuffleBatchSize = conf.columnarShuffleBatchSize(); + columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled(); + columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled(); + columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor(); + ShuffleDependency shuffleDependency = handle.dependency(); + this.stageId = taskContext.stageId(); + this.shuffleId = shuffleDependency.shuffleId(); + this.schema = CustomShuffleDependencyUtils.getSchema(shuffleDependency); + this.depSerializer = handle.dependency().serializer(); + this.celebornBatchBuilders = + new CelebornBatchBuilder[handle.dependency().partitioner().numPartitions()]; + this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema); + } + + @Override + protected void fastWrite0(scala.collection.Iterator iterator) + throws IOException, InterruptedException { + if (isColumnarShuffle) { + logger.info("Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId); + fastColumnarWrite0(iterator); + } else { + super.fastWrite0(iterator); + } + } + + private void fastColumnarWrite0(scala.collection.Iterator iterator) throws IOException { + final scala.collection.Iterator> records = iterator; + + SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) depSerializer); + while (records.hasNext()) { + final Product2 record = records.next(); + final int partitionId = record._1(); + final UnsafeRow row = record._2(); + + if (celebornBatchBuilders[partitionId] == null) { + CelebornBatchBuilder columnBuilders; + if (columnarShuffleCodeGenEnabled && !columnarShuffleDictionaryEnabled) { + columnBuilders = + new CelebornColumnarBatchCodeGenBuild().create(schema, columnarShuffleBatchSize); + } else { + columnBuilders = + new CelebornColumnarBatchBuilder( + schema, + columnarShuffleBatchSize, + columnarShuffleDictionaryMaxFactor, + columnarShuffleDictionaryEnabled); + } + columnBuilders.newBuilders(); + celebornBatchBuilders[partitionId] = columnBuilders; + } + + celebornBatchBuilders[partitionId].writeRow(row); + if (celebornBatchBuilders[partitionId].getRowCnt() >= columnarShuffleBatchSize) { + byte[] arr = celebornBatchBuilders[partitionId].buildColumnBytes(); + pushGiantRecord(partitionId, arr, arr.length); + if (dataSize != null) { + dataSize.add(arr.length); + } + celebornBatchBuilders[partitionId].newBuilders(); + } + tmpRecordsWritten++; + } + } + + @Override + protected void closeWrite() throws IOException { + if (canUseFastWrite() && isColumnarShuffle) { + closeColumnarWrite(); + } else { + super.closeWrite(); + } + } + + private void closeColumnarWrite() throws IOException { + SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) depSerializer); + for (int i = 0; i < celebornBatchBuilders.length; i++) { + final CelebornBatchBuilder builders = celebornBatchBuilders[i]; + if (builders != null && builders.getRowCnt() > 0) { + byte[] buffers = builders.buildColumnBytes(); + if (dataSize != null) { + dataSize.add(buffers.length); + } + mergeData(i, buffers, 0, buffers.length); + // free buffer + celebornBatchBuilders[i] = null; + } + } + } + + @VisibleForTesting + public boolean isColumnarShuffle() { + return isColumnarShuffle; + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala new file mode 100644 index 00000000000..fd888fb9dc1 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.{ShuffleDependency, TaskContext} +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.sql.execution.UnsafeRowSerializer +import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} + +import org.apache.celeborn.common.CelebornConf + +class CelebornColumnarShuffleReader[K, C]( + handle: CelebornShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + startMapIndex: Int = 0, + endMapIndex: Int = Int.MaxValue, + context: TaskContext, + conf: CelebornConf, + metrics: ShuffleReadMetricsReporter, + shuffleIdTracker: ExecutorShuffleIdTracker) + extends CelebornShuffleReader[K, C]( + handle, + startPartition, + endPartition, + startMapIndex, + endMapIndex, + context, + conf, + metrics, + shuffleIdTracker) { + + override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { + val schema = CustomShuffleDependencyUtils.getSchema(dep) + if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) { + logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.") + val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer]) + new CelebornColumnarBatchSerializer( + schema, + conf.columnarShuffleOffHeapEnabled, + dataSize).newInstance() + } else { + super.newSerializerInstance(dep) + } + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala new file mode 100644 index 00000000000..ce003919440 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.sql.execution.vectorized.WritableColumnVector +import org.apache.spark.sql.types._ + +trait CelebornColumnAccessor { + initialize() + + protected def initialize(): Unit + + def hasNext: Boolean + + def extractTo(row: InternalRow, ordinal: Int): Unit + + def extractToColumnVector(columnVector: WritableColumnVector, ordinal: Int): Unit + + protected def underlyingBuffer: ByteBuffer +} + +abstract class CelebornBasicColumnAccessor[JvmType]( + protected val buffer: ByteBuffer, + protected val columnType: CelebornColumnType[JvmType]) + extends CelebornColumnAccessor { + + protected def initialize(): Unit = {} + + override def hasNext: Boolean = buffer.hasRemaining + + override def extractTo(row: InternalRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) + } + + override def extractToColumnVector(columnVector: WritableColumnVector, ordinal: Int): Unit = { + val length = buffer.getInt() + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + columnVector.putByteArray(ordinal, bytes) + } + + def extractSingle(row: InternalRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } + + protected def underlyingBuffer: ByteBuffer = buffer +} + +abstract class CelebornNativeColumnAccessor[T <: PhysicalDataType]( + override protected val buffer: ByteBuffer, + override protected val columnType: NativeCelebornColumnType[T]) + extends CelebornBasicColumnAccessor(buffer, columnType) + with CelebornNullableColumnAccessor + with CelebornCompressibleColumnAccessor[T] + +class CelebornBooleanColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_BOOLEAN) + +class CelebornByteColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_BYTE) + +class CelebornShortColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_SHORT) + +class CelebornIntColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_INT) + +class CelebornLongColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_LONG) + +class CelebornFloatColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_FLOAT) + +class CelebornDoubleColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_DOUBLE) + +class CelebornStringColumnAccessor(buffer: ByteBuffer) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_STRING) + +class CelebornCompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends CelebornNativeColumnAccessor(buffer, CELEBORN_COMPACT_DECIMAL(dataType)) + +class CelebornDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends CelebornBasicColumnAccessor[Decimal](buffer, CELEBORN_LARGE_DECIMAL(dataType)) + with CelebornNullableColumnAccessor + +private[sql] object CelebornColumnAccessor { + + def apply(dataType: DataType, buffer: ByteBuffer): CelebornColumnAccessor = { + val buf = buffer.order(ByteOrder.nativeOrder) + + dataType match { + case BooleanType => new CelebornBooleanColumnAccessor(buf) + case ByteType => new CelebornByteColumnAccessor(buf) + case ShortType => new CelebornShortColumnAccessor(buf) + case IntegerType => new CelebornIntColumnAccessor(buf) + case LongType => new CelebornLongColumnAccessor(buf) + case FloatType => new CelebornFloatColumnAccessor(buf) + case DoubleType => new CelebornDoubleColumnAccessor(buf) + case StringType => new CelebornStringColumnAccessor(buf) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CelebornCompactDecimalColumnAccessor(buf, dt) + case dt: DecimalType => new CelebornDecimalColumnAccessor(buf, dt) + case other => throw new Exception(s"not support type: $other") + } + } + + def decompress( + columnAccessor: CelebornColumnAccessor, + columnVector: WritableColumnVector, + numRows: Int): Unit = { + columnAccessor match { + case nativeAccessor: CelebornNativeColumnAccessor[_] => + nativeAccessor.decompress(columnVector, numRows) + case _: CelebornDecimalColumnAccessor => + (0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _)) + case _ => + throw new RuntimeException("Not support non-primitive type now") + } + } + + def decompress( + array: Array[Byte], + columnVector: WritableColumnVector, + dataType: DataType, + numRows: Int): Unit = { + val byteBuffer = ByteBuffer.wrap(array) + val columnAccessor = CelebornColumnAccessor(dataType, byteBuffer) + decompress(columnAccessor, columnVector, numRows) + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala new file mode 100644 index 00000000000..7d25449981d --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.columnar.CelebornColumnBuilder.ensureFreeSpace +import org.apache.spark.sql.types._ + +trait CelebornColumnBuilder { + + /** + * Initializes with an approximate lower bound on the expected number of elements in this column. + */ + def initialize( + rowCnt: Int, + columnName: String = "", + encodingEnabled: Boolean = false): Unit + + /** + * Appends `row(ordinal)` to the column builder. + */ + def appendFrom(row: InternalRow, ordinal: Int): Unit + + /** + * Column statistics information + */ + def columnStats: CelebornColumnStats + + /** + * Returns the final columnar byte buffer. + */ + def build(): ByteBuffer + + def getTotalSize: Long = 0 +} + +class CelebornBasicColumnBuilder[JvmType]( + val columnStats: CelebornColumnStats, + val columnType: CelebornColumnType[JvmType]) + extends CelebornColumnBuilder { + + protected var columnName: String = _ + + protected var buffer: ByteBuffer = _ + + override def initialize( + rowCnt: Int, + columnName: String = "", + encodingEnabled: Boolean = false): Unit = { + + this.columnName = columnName + + buffer = ByteBuffer.allocate(rowCnt * columnType.defaultSize) + buffer.order(ByteOrder.nativeOrder()) + } + + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) + columnType.append(row, ordinal, buffer) + } + + override def build(): ByteBuffer = { + if (buffer.capacity() > buffer.position() * 1.1) { + // trim the buffer + buffer = ByteBuffer + .allocate(buffer.position()) + .order(ByteOrder.nativeOrder()) + .put(buffer.array(), 0, buffer.position()) + } + buffer.flip().asInstanceOf[ByteBuffer] + } +} + +abstract class CelebornComplexColumnBuilder[JvmType]( + columnStats: CelebornColumnStats, + columnType: CelebornColumnType[JvmType]) + extends CelebornBasicColumnBuilder[JvmType](columnStats, columnType) + with CelebornNullableColumnBuilder + +abstract class CelebornNativeColumnBuilder[T <: PhysicalDataType]( + override val columnStats: CelebornColumnStats, + override val columnType: NativeCelebornColumnType[T]) + extends CelebornBasicColumnBuilder[T#InternalType](columnStats, columnType) + with CelebornNullableColumnBuilder + with AllCelebornCompressionSchemes + with CelebornCompressibleColumnBuilder[T] + +class CelebornBooleanColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornBooleanColumnStats, CELEBORN_BOOLEAN) + +class CelebornByteColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornByteColumnStats, CELEBORN_BYTE) + +class CelebornShortColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornShortColumnStats, CELEBORN_SHORT) + +class CelebornIntColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornIntColumnStats, CELEBORN_INT) + +class CelebornLongColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornLongColumnStats, CELEBORN_LONG) + +class CelebornFloatColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornFloatColumnStats, CELEBORN_FLOAT) + +class CelebornDoubleColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornDoubleColumnStats, CELEBORN_DOUBLE) + +class CelebornStringColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornStringColumnStats, CELEBORN_STRING) + +class CelebornCompactMiniDecimalColumnBuilder(dataType: DecimalType) + extends CelebornNativeColumnBuilder( + new CelebornDecimalColumnStats(dataType), + CELEBORN_COMPACT_MINI_DECIMAL(dataType)) + +class CelebornCompactDecimalColumnBuilder(dataType: DecimalType) + extends CelebornNativeColumnBuilder( + new CelebornDecimalColumnStats(dataType), + CELEBORN_COMPACT_DECIMAL(dataType)) + +class CelebornDecimalColumnBuilder(dataType: DecimalType) + extends CelebornComplexColumnBuilder( + new CelebornDecimalColumnStats(dataType), + CELEBORN_LARGE_DECIMAL(dataType)) + +class CelebornBooleanCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornBooleanColumnStats, CELEBORN_BOOLEAN) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_BOOLEAN.actualSize(row, ordinal)) + CELEBORN_BOOLEAN.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornByteCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornByteColumnStats, CELEBORN_BYTE) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_BYTE.actualSize(row, ordinal)) + CELEBORN_BYTE.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornShortCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornShortColumnStats, CELEBORN_SHORT) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_SHORT.actualSize(row, ordinal)) + CELEBORN_SHORT.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornIntCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornIntColumnStats, CELEBORN_INT) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_INT.actualSize(row, ordinal)) + CELEBORN_INT.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornLongCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornLongColumnStats, CELEBORN_LONG) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_LONG.actualSize(row, ordinal)) + CELEBORN_LONG.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornFloatCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornFloatColumnStats, CELEBORN_FLOAT) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_FLOAT.actualSize(row, ordinal)) + CELEBORN_FLOAT.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornDoubleCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornDoubleColumnStats, CELEBORN_DOUBLE) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_DOUBLE.actualSize(row, ordinal)) + CELEBORN_DOUBLE.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornStringCodeGenColumnBuilder + extends CelebornNativeColumnBuilder(new CelebornStringColumnStats, CELEBORN_STRING) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_STRING.actualSize(row, ordinal)) + CELEBORN_STRING.append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornCompactDecimalCodeGenColumnBuilder(dataType: DecimalType) + extends CelebornNativeColumnBuilder( + new CelebornDecimalColumnStats(dataType), + CELEBORN_COMPACT_DECIMAL(dataType)) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_COMPACT_DECIMAL(dataType).actualSize(row, ordinal)) + CELEBORN_COMPACT_DECIMAL(dataType).append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornCompactMiniDecimalCodeGenColumnBuilder(dataType: DecimalType) + extends CelebornNativeColumnBuilder( + new CelebornDecimalColumnStats(dataType), + CELEBORN_COMPACT_MINI_DECIMAL(dataType)) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = + ensureFreeSpace(buffer, CELEBORN_COMPACT_MINI_DECIMAL(dataType).actualSize(row, ordinal)) + CELEBORN_COMPACT_MINI_DECIMAL(dataType).append(row, ordinal, buffer) + } + pos += 1 + } +} + +class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType) + extends CelebornComplexColumnBuilder( + new CelebornDecimalColumnStats(dataType), + CELEBORN_LARGE_DECIMAL(dataType)) { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + buffer = ensureFreeSpace(buffer, CELEBORN_LARGE_DECIMAL(dataType).actualSize(row, ordinal)) + CELEBORN_LARGE_DECIMAL(dataType).append(row, ordinal, buffer) + } + pos += 1 + } +} + +object CelebornColumnBuilder { + + def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = { + if (orig.remaining >= size) { + orig + } else { + // grow in steps of initial size + val capacity = orig.capacity() + val newSize = capacity + size.max(capacity) + val pos = orig.position() + + ByteBuffer + .allocate(newSize) + .order(ByteOrder.nativeOrder()) + .put(orig.array(), 0, pos) + } + } + + def apply( + dataType: DataType, + rowCnt: Int, + columnName: String, + encodingEnabled: Boolean, + encoder: Encoder[_ <: PhysicalDataType]): CelebornColumnBuilder = { + val builder: CelebornColumnBuilder = dataType match { + case ByteType => new CelebornByteColumnBuilder + case BooleanType => new CelebornBooleanColumnBuilder + case ShortType => new CelebornShortColumnBuilder + case IntegerType => + val builder = new CelebornIntColumnBuilder + builder.init(encoder.asInstanceOf[Encoder[PhysicalIntegerType.type]]) + builder + case LongType => + val builder = new CelebornLongColumnBuilder + builder.init(encoder.asInstanceOf[Encoder[PhysicalLongType.type]]) + builder + case FloatType => new CelebornFloatColumnBuilder + case DoubleType => new CelebornDoubleColumnBuilder + case StringType => + val builder = new CelebornStringColumnBuilder + builder.init(encoder.asInstanceOf[Encoder[PhysicalStringType]]) + builder + case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => + new CelebornCompactMiniDecimalColumnBuilder(dt) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CelebornCompactDecimalColumnBuilder(dt) + case dt: DecimalType => new CelebornDecimalColumnBuilder(dt) + case other => + throw new Exception(s"Unsupported type: $other") + } + + builder.initialize(rowCnt, columnName, encodingEnabled) + builder + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala new file mode 100644 index 00000000000..b0b9f61db9a --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Used to collect statistical information when building in-memory columns. + * + * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` + * brings significant performance penalty. + */ +sealed private[columnar] trait CelebornColumnStats extends Serializable { + protected var count = 0 + protected var nullCount = 0 + private[columnar] var sizeInBytes = 0L + + /** + * Gathers statistics information from `row(ordinal)`. + */ + def gatherStats(row: InternalRow, ordinal: Int): Unit + + /** + * Gathers statistics information on `null`. + */ + def gatherNullStats(): Unit = { + nullCount += 1 + // 4 bytes for null position + sizeInBytes += 4 + count += 1 + } + + /** + * Column statistics represented as an array, currently including closed lower bound, closed + * upper bound and null count. + */ + def collectedStatistics: Array[Any] +} + +final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnStats { + protected var upper = false + protected var lower = true + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getBoolean(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Boolean): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_BOOLEAN.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornByteColumnStats extends CelebornColumnStats { + protected var upper = Byte.MinValue + protected var lower = Byte.MaxValue + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getByte(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Byte): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_BYTE.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornShortColumnStats extends CelebornColumnStats { + protected var upper = Short.MinValue + protected var lower = Short.MaxValue + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getShort(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Short): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_SHORT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats { + protected var upper = Int.MinValue + protected var lower = Int.MaxValue + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getInt(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Int): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_INT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornLongColumnStats extends CelebornColumnStats { + protected var upper = Long.MinValue + protected var lower = Long.MaxValue + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getLong(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Long): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_LONG.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornFloatColumnStats extends CelebornColumnStats { + protected var upper = Float.MinValue + protected var lower = Float.MaxValue + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getFloat(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Float): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_FLOAT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnStats { + protected var upper = Double.MinValue + protected var lower = Double.MaxValue + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getDouble(ordinal) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Double): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += CELEBORN_DOUBLE.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats { + protected var upper: UTF8String = _ + protected var lower: UTF8String = _ + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getUTF8String(ordinal) + val size = CELEBORN_STRING.actualSize(row, ordinal) + gatherValueStats(value, size) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: UTF8String, size: Int): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value.clone() + if (lower == null || value.compareTo(lower) < 0) lower = value.clone() + sizeInBytes += size + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} + +final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int) + extends CelebornColumnStats { + def this(dt: DecimalType) = this(dt.precision, dt.scale) + + protected var upper: Decimal = _ + protected var lower: Decimal = _ + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + val value = row.getDecimal(ordinal, precision, scale) + gatherValueStats(value) + } else { + gatherNullStats() + } + } + + def gatherValueStats(value: Decimal): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + if (precision <= Decimal.MAX_INT_DIGITS) { + sizeInBytes += 4 + } else if (precision <= Decimal.MAX_LONG_DIGITS) { + sizeInBytes += 8 + } else { + sizeInBytes += (4 + value.toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1) + } + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala new file mode 100644 index 00000000000..706078684c5 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala @@ -0,0 +1,663 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.math.{BigDecimal, BigInteger} +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.UTF8String + +/** + * A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order. + * + * Note: There is not much difference between ByteBuffer.getByte/getShort and + * Unsafe.getByte/getShort, so we do not have helper methods for them. + * + * The unrolling (building columnar cache) is already slow, putLong/putDouble will not help much, + * so we do not have helper methods for them. + * + * WARNING: This only works with HeapByteBuffer + */ +private[columnar] object ByteBufferHelper { + def getShort(buffer: ByteBuffer): Short = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.getShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getInt(buffer: ByteBuffer): Int = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getLong(buffer: ByteBuffer): Long = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getFloat(buffer: ByteBuffer): Float = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getDouble(buffer: ByteBuffer): Double = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def putShort(buffer: ByteBuffer, value: Short): Unit = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.putShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putInt(buffer: ByteBuffer, value: Int): Unit = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.putInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putLong(buffer: ByteBuffer, value: Long): Unit = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.putLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def copyMemory(src: ByteBuffer, dst: ByteBuffer, len: Int): Unit = { + val srcPos = src.position() + val dstPos = dst.position() + src.position(srcPos + len) + dst.position(dstPos + len) + Platform.copyMemory( + src.array(), + Platform.BYTE_ARRAY_OFFSET + srcPos, + dst.array(), + Platform.BYTE_ARRAY_OFFSET + dstPos, + len) + } +} + +/** + * An abstract class that represents type of a column. Used to append/extract Java objects into/from + * the underlying [[ByteBuffer]] of a column. + * + * @tparam JvmType Underlying Java type to represent the elements. + */ +sealed abstract private[columnar] class CelebornColumnType[JvmType] { + + // The catalyst physical data type of this column. + def dataType: PhysicalDataType + + // Default size in bytes for one element of type T (e.g. 4 for `Int`). + def defaultSize: Int + + /** + * Extracts a value out of the buffer at the buffer's current position. + */ + def extract(buffer: ByteBuffer): JvmType + + /** + * Extracts a value out of the buffer at the buffer's current position and stores in + * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever + * possible. + */ + def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + setField(row, ordinal, extract(buffer)) + } + + /** + * Appends the given value v of type T into the given ByteBuffer. + */ + def append(v: JvmType, buffer: ByteBuffer): Unit + + /** + * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this + * method to avoid boxing/unboxing costs whenever possible. + */ + def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + append(getField(row, ordinal), buffer) + } + + /** + * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable + * length types such as byte arrays and strings. + */ + def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize + + /** + * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs + * whenever possible. + */ + def getField(row: InternalRow, ordinal: Int): JvmType + + /** + * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing + * costs whenever possible. + */ + def setField(row: InternalRow, ordinal: Int, value: JvmType): Unit + + /** + * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid + * boxing/unboxing costs whenever possible. + */ + def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int): Unit = { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } + + /** + * Creates a duplicated copy of the value. + */ + def clone(v: JvmType): JvmType = v + + override def toString: String = getClass.getSimpleName.stripSuffix("$") +} + +abstract private[columnar] class NativeCelebornColumnType[T <: PhysicalDataType]( + val dataType: T, + val defaultSize: Int) + extends CelebornColumnType[T#InternalType] {} + +private[columnar] object CELEBORN_INT extends NativeCelebornColumnType(PhysicalIntegerType, 4) { + override def append(v: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(row.getInt(ordinal)) + } + + override def extract(buffer: ByteBuffer): Int = { + ByteBufferHelper.getInt(buffer) + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setInt(ordinal, ByteBufferHelper.getInt(buffer)) + } + + override def setField(row: InternalRow, ordinal: Int, value: Int): Unit = { + row.setInt(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setInt(toOrdinal, from.getInt(fromOrdinal)) + } +} + +private[columnar] object CELEBORN_LONG extends NativeCelebornColumnType(PhysicalLongType, 8) { + override def append(v: Long, buffer: ByteBuffer): Unit = { + buffer.putLong(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putLong(row.getLong(ordinal)) + } + + override def extract(buffer: ByteBuffer): Long = { + ByteBufferHelper.getLong(buffer) + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } + + override def setField(row: InternalRow, ordinal: Int, value: Long): Unit = { + row.setLong(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setLong(toOrdinal, from.getLong(fromOrdinal)) + } +} + +private[columnar] object CELEBORN_FLOAT extends NativeCelebornColumnType(PhysicalFloatType, 4) { + override def append(v: Float, buffer: ByteBuffer): Unit = { + buffer.putFloat(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putFloat(row.getFloat(ordinal)) + } + + override def extract(buffer: ByteBuffer): Float = { + ByteBufferHelper.getFloat(buffer) + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer)) + } + + override def setField(row: InternalRow, ordinal: Int, value: Float): Unit = { + row.setFloat(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) + } +} + +private[columnar] object CELEBORN_DOUBLE extends NativeCelebornColumnType(PhysicalDoubleType, 8) { + override def append(v: Double, buffer: ByteBuffer): Unit = { + buffer.putDouble(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putDouble(row.getDouble(ordinal)) + } + + override def extract(buffer: ByteBuffer): Double = { + ByteBufferHelper.getDouble(buffer) + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer)) + } + + override def setField(row: InternalRow, ordinal: Int, value: Double): Unit = { + row.setDouble(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) + } +} + +private[columnar] object CELEBORN_BOOLEAN extends NativeCelebornColumnType(PhysicalBooleanType, 1) { + override def append(v: Boolean, buffer: ByteBuffer): Unit = { + buffer.put(if (v) 1: Byte else 0: Byte) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) + } + + override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1 + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setBoolean(ordinal, buffer.get() == 1) + } + + override def setField(row: InternalRow, ordinal: Int, value: Boolean): Unit = { + row.setBoolean(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) + } +} + +private[columnar] object CELEBORN_BYTE extends NativeCelebornColumnType(PhysicalByteType, 1) { + override def append(v: Byte, buffer: ByteBuffer): Unit = { + buffer.put(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(row.getByte(ordinal)) + } + + override def extract(buffer: ByteBuffer): Byte = { + buffer.get() + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setByte(ordinal, buffer.get()) + } + + override def setField(row: InternalRow, ordinal: Int, value: Byte): Unit = { + row.setByte(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setByte(toOrdinal, from.getByte(fromOrdinal)) + } +} + +private[columnar] object CELEBORN_SHORT extends NativeCelebornColumnType(PhysicalShortType, 2) { + override def append(v: Short, buffer: ByteBuffer): Unit = { + buffer.putShort(v) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putShort(row.getShort(ordinal)) + } + + override def extract(buffer: ByteBuffer): Short = { + buffer.getShort() + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row.setShort(ordinal, buffer.getShort()) + } + + override def setField(row: InternalRow, ordinal: Int, value: Short): Unit = { + row.setShort(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + to.setShort(toOrdinal, from.getShort(fromOrdinal)) + } +} + +/** + * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper + * objects. + */ +private[columnar] trait DirectCopyCelebornColumnType[JvmType] extends CelebornColumnType[JvmType] { + + // copy the bytes from ByteBuffer to UnsafeRow + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + row match { + case r: MutableUnsafeRow => + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + r.writer.write( + ordinal, + buffer.array(), + buffer.arrayOffset() + cursor, + numBytes) + case _ => + setField(row, ordinal, extract(buffer)) + } + } + + // copy the bytes from UnsafeRow to ByteBuffer + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + row match { + case r: UnsafeRow => + r.writeFieldTo(ordinal, buffer) + case _ => + super.append(row, ordinal, buffer) + } + } +} + +private[columnar] object CELEBORN_STRING + extends NativeCelebornColumnType( + PhysicalStringType(SqlApiConf.get.defaultStringType.collationId), + 8) + with DirectCopyCelebornColumnType[UTF8String] { + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getUTF8String(ordinal).numBytes() + 4 + } + + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + buffer.putInt(v.numBytes()) + v.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UTF8String = { + val length = buffer.getInt() + val cursor = buffer.position() + buffer.position(cursor + length) + UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) + } + + override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = { + row match { + case r: MutableUnsafeRow => + r.writer.write(ordinal, value) + case _ => + row.update(ordinal, value.clone()) + } + } + + override def getField(row: InternalRow, ordinal: Int): UTF8String = { + row.getUTF8String(ordinal) + } + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } + + override def clone(v: UTF8String): UTF8String = v.clone() + +} + +private[columnar] case class CELEBORN_COMPACT_DECIMAL(precision: Int, scale: Int) + extends NativeCelebornColumnType(PhysicalDecimalType(precision, scale), 8) { + + override def extract(buffer: ByteBuffer): Decimal = { + Decimal(ByteBufferHelper.getLong(buffer), precision, scale) + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + // copy it as Long + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + override def append(v: Decimal, buffer: ByteBuffer): Unit = { + buffer.putLong(v.toUnscaledLong) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + // copy it as Long + buffer.putLong(row.getLong(ordinal)) + } else { + append(getField(row, ordinal), buffer) + } + } + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } +} + +private[columnar] case class CELEBORN_COMPACT_MINI_DECIMAL(precision: Int, scale: Int) + extends NativeCelebornColumnType(PhysicalDecimalType(precision, scale), 4) { + + override def extract(buffer: ByteBuffer): Decimal = { + Decimal(ByteBufferHelper.getInt(buffer), precision, scale) + } + + override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + // copy it as Long + row.setInt(ordinal, ByteBufferHelper.getInt(buffer)) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + override def append(v: Decimal, buffer: ByteBuffer): Unit = { + buffer.putInt(v.toInt) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + // copy it as Long + buffer.putInt(row.getInt(ordinal)) + } else { + append(getField(row, ordinal), buffer) + } + } + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def copyField( + from: InternalRow, + fromOrdinal: Int, + to: InternalRow, + toOrdinal: Int): Unit = { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } +} + +private[columnar] object CELEBORN_COMPACT_DECIMAL { + def apply(dt: DecimalType): CELEBORN_COMPACT_DECIMAL = { + CELEBORN_COMPACT_DECIMAL(dt.precision, dt.scale) + } +} + +private[columnar] object CELEBORN_COMPACT_MINI_DECIMAL { + def apply(dt: DecimalType): CELEBORN_COMPACT_MINI_DECIMAL = { + CELEBORN_COMPACT_MINI_DECIMAL(dt.precision, dt.scale) + } +} + +sealed abstract private[columnar] class ByteArrayCelebornColumnType[JvmType](val defaultSize: Int) + extends CelebornColumnType[JvmType] with DirectCopyCelebornColumnType[JvmType] { + + def serialize(value: JvmType): Array[Byte] + def deserialize(bytes: Array[Byte]): JvmType + + override def append(v: JvmType, buffer: ByteBuffer): Unit = { + val bytes = serialize(v) + buffer.putInt(bytes.length).put(bytes, 0, bytes.length) + } + + override def extract(buffer: ByteBuffer): JvmType = { + val length = buffer.getInt() + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + deserialize(bytes) + } +} + +private[columnar] case class CELEBORN_LARGE_DECIMAL(precision: Int, scale: Int) + extends ByteArrayCelebornColumnType[Decimal](12) { + + override val dataType: PhysicalDecimalType = PhysicalDecimalType(precision, scale) + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1 + } + + override def serialize(value: Decimal): Array[Byte] = { + value.toJavaBigDecimal.unscaledValue().toByteArray + } + + override def deserialize(bytes: Array[Byte]): Decimal = { + val javaDecimal = new BigDecimal(new BigInteger(bytes), scale) + Decimal.apply(javaDecimal, precision, scale) + } +} + +private[columnar] object CELEBORN_LARGE_DECIMAL { + def apply(dt: DecimalType): CELEBORN_LARGE_DECIMAL = { + CELEBORN_LARGE_DECIMAL(dt.precision, dt.scale) + } +} + +private[columnar] object CelebornColumnType { + def apply(dataType: DataType): CelebornColumnType[_] = { + dataType match { + case BooleanType => CELEBORN_BOOLEAN + case ByteType => CELEBORN_BYTE + case ShortType => CELEBORN_SHORT + case IntegerType => CELEBORN_INT + case LongType => CELEBORN_LONG + case FloatType => CELEBORN_FLOAT + case DoubleType => CELEBORN_DOUBLE + case StringType => CELEBORN_STRING + case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => + CELEBORN_COMPACT_MINI_DECIMAL(dt) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + CELEBORN_COMPACT_DECIMAL(dt) + case dt: DecimalType => CELEBORN_LARGE_DECIMAL(dt) + case other => throw new Exception(s"Unsupported type: ${other.catalogString}") + } + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala new file mode 100644 index 00000000000..4685e997736 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.ByteArrayOutputStream + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.sql.types._ + +class CelebornColumnarBatchBuilder( + schema: StructType, + batchSize: Int = 0, + maxDictFactor: Double, + encodingEnabled: Boolean = false) extends CelebornBatchBuilder { + var rowCnt = 0 + + private val typeConversion + : PartialFunction[DataType, NativeCelebornColumnType[_ <: PhysicalDataType]] = { + case IntegerType => CELEBORN_INT + case LongType => CELEBORN_LONG + case StringType => CELEBORN_STRING + case BooleanType => CELEBORN_BOOLEAN + case ShortType => CELEBORN_SHORT + case ByteType => CELEBORN_BYTE + case FloatType => CELEBORN_FLOAT + case DoubleType => CELEBORN_DOUBLE + case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => + CELEBORN_COMPACT_MINI_DECIMAL(dt) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => CELEBORN_COMPACT_DECIMAL(dt) + case _ => null + } + + private val encodersArr: Array[Encoder[_ <: PhysicalDataType]] = schema.map { attribute => + val nativeColumnType = typeConversion(attribute.dataType) + if (nativeColumnType == null) { + null + } else { + if (encodingEnabled && CelebornDictionaryEncoding.supports(nativeColumnType)) { + CelebornDictionaryEncoding.MAX_DICT_SIZE = + Math.min(Short.MaxValue, batchSize * maxDictFactor).toShort + CelebornDictionaryEncoding.encoder(nativeColumnType) + } else { + CelebornPassThrough.encoder(nativeColumnType) + } + } + }.toArray + + var columnBuilders: Array[CelebornColumnBuilder] = _ + + def newBuilders(): Unit = { + rowCnt = 0 + var i = -1 + columnBuilders = schema.map { attribute => + i += 1 + encodersArr(i) match { + case encoder: CelebornDictionaryEncoding.CelebornEncoder[_] if !encoder.overflow => + encoder.cleanBatch + case _ => + } + CelebornColumnBuilder( + attribute.dataType, + batchSize, + attribute.name, + encodingEnabled, + encodersArr(i)) + }.toArray + } + + def buildColumnBytes(): Array[Byte] = { + val giantBuffer = new ByteArrayOutputStream + val rowCntBytes = int2ByteArray(rowCnt) + giantBuffer.write(rowCntBytes) + val builderLen = columnBuilders.length + var i = 0 + while (i < builderLen) { + val builder = columnBuilders(i) + val buffers = builder.build() + val bytes = JavaUtils.bufferToArray(buffers) + val columnBuilderBytes = int2ByteArray(bytes.length) + giantBuffer.write(columnBuilderBytes) + giantBuffer.write(bytes) + i += 1 + } + giantBuffer.toByteArray + } + + def writeRow(row: InternalRow): Unit = { + var i = 0 + while (i < row.numFields) { + columnBuilders(i).appendFrom(row, i) + i += 1 + } + rowCnt += 1 + } + + def getRowCnt: Int = rowCnt +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala new file mode 100644 index 00000000000..e510e645259 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +import scala.collection.mutable + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection.newCodeGenContext +import org.apache.spark.sql.types._ +import org.slf4j.LoggerFactory + +class CelebornColumnarBatchCodeGenBuild { + + private val logger = LoggerFactory.getLogger(classOf[CelebornColumnarBatchCodeGenBuild]) + + def create(schema: StructType, batchSize: Int): CelebornBatchBuilder = { + val ctx = newCodeGenContext() + val codes = genCode(schema, batchSize) + val codeBody = + s""" + | + |public java.lang.Object generate(Object[] references) { + | return new SpecificCelebornColumnarBatchBuilder(references); + |} + | + |class SpecificCelebornColumnarBatchBuilder extends ${classOf[CelebornBatchBuilder].getName} { + | + | private Object[] references; + | int rowCnt = 0; + | ${codes._1} + | + | public SpecificCelebornColumnarBatchBuilder(Object[] references) { + | this.references = references; + | } + | + | public void newBuilders() throws Exception { + | rowCnt = 0; + | ${codes._2} + | } + | + | public byte[] buildColumnBytes() throws Exception { + | ${classOf[ByteArrayOutputStream].getName} giantBuffer = new ${classOf[ + ByteArrayOutputStream].getName}(); + | byte[] rowCntBytes = int2ByteArray(rowCnt); + | giantBuffer.write(rowCntBytes); + | ${codes._3} + | return giantBuffer.toByteArray(); + | } + | + | public void writeRow(InternalRow row) throws Exception { + | ${codes._4} + | rowCnt += 1; + | } + | + | public int getRowCnt() { + | return rowCnt; + | } + |} + """.stripMargin + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + + logger.debug(s"\n${CodeFormatter.format(code)}") + + val (clazz, _) = CodeGenerator.compile(code) + + clazz.generate(ctx.references.toArray).asInstanceOf[CelebornBatchBuilder] + } + + /** + * @return (code to define columnBuilder, code to instantiate columnBuilder, + * code to build column bytes, code to write row to columnBuilder) + */ + def genCode(schema: StructType, batchSize: Int): ( + mutable.StringBuilder, + mutable.StringBuilder, + mutable.StringBuilder, + mutable.StringBuilder) = { + val initCode = new mutable.StringBuilder() + val buildCode = new mutable.StringBuilder() + val writeCode = new mutable.StringBuilder() + val writeRowCode = new mutable.StringBuilder() + for (index <- schema.indices) { + schema.fields(index).dataType match { + case ByteType => + initCode.append( + s""" + | ${classOf[CelebornByteCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornByteCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case BooleanType => + initCode.append( + s""" + | ${classOf[CelebornBooleanCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornBooleanCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case ShortType => + initCode.append( + s""" + | ${classOf[CelebornShortCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornShortCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case IntegerType => + initCode.append( + s""" + | ${classOf[CelebornIntCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornIntCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case LongType => + initCode.append( + s""" + | ${classOf[CelebornLongCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornLongCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case FloatType => + initCode.append( + s""" + | ${classOf[CelebornFloatCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornFloatCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case DoubleType => + initCode.append( + s""" + | ${classOf[CelebornDoubleCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornDoubleCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case StringType => + initCode.append( + s""" + | ${classOf[CelebornStringCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornStringCodeGenColumnBuilder].getName}(); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS => + initCode.append( + s""" + | ${classOf[CelebornCompactMiniDecimalCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = + | new ${classOf[CelebornCompactMiniDecimalCodeGenColumnBuilder].getName}( + | new ${classOf[DecimalType].getName}(${dt.precision}, ${dt.scale})); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + initCode.append( + s""" + | ${classOf[CelebornCompactDecimalCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = + | new ${classOf[CelebornCompactDecimalCodeGenColumnBuilder].getName} + | (new ${classOf[DecimalType].getName}(${dt.precision}, ${dt.scale})); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + case dt: DecimalType => + initCode.append( + s""" + | ${classOf[CelebornDecimalCodeGenColumnBuilder].getName} b$index; + """.stripMargin) + buildCode.append( + s""" + | b$index = new ${classOf[CelebornDecimalCodeGenColumnBuilder].getName} + | (new ${classOf[DecimalType].getName}(${dt.precision}, ${dt.scale})); + | b$index.initialize($batchSize, "${schema.fields(index).name}", false); + """.stripMargin) + writeCode.append(genWriteCode(index)) + writeRowCode.append( + s""" + | b$index.appendFrom(row, $index); + """.stripMargin) + } + } + (initCode, buildCode, writeCode, writeRowCode) + } + + def genWriteCode(index: Int): String = { + s""" + | ${classOf[ByteBuffer].getName} buffers$index = b$index.build(); + | byte[] bytes$index = ${classOf[JavaUtils].getName}.bufferToArray(buffers$index); + | byte[] columnBuilderBytes$index = int2ByteArray(bytes$index.length); + | giantBuffer.write(columnBuilderBytes$index); + | giantBuffer.write(bytes$index); + """.stripMargin + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala new file mode 100644 index 00000000000..f9c08a0f6a1 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io._ +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import com.google.common.io.ByteStreams +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +class CelebornColumnarBatchSerializer( + schema: StructType, + offHeapColumnVectorEnabled: Boolean, + dataSize: SQLMetric = null) extends Serializer with Serializable { + override def newInstance(): SerializerInstance = + new CelebornColumnarBatchSerializerInstance( + schema, + offHeapColumnVectorEnabled, + dataSize) + override def supportsRelocationOfSerializedObjects: Boolean = true +} + +class CelebornColumnarBatchSerializerInstance( + schema: StructType, + offHeapColumnVectorEnabled: Boolean, + dataSize: SQLMetric) extends SerializerInstance { + + override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { + private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + private[this] val dOut: DataOutputStream = + new DataOutputStream(new BufferedOutputStream(out)) + + override def writeValue[T: ClassTag](value: T): SerializationStream = { + val row = value.asInstanceOf[UnsafeRow] + if (dataSize != null) { + dataSize.add(row.getSizeInBytes) + } + dOut.writeInt(row.getSizeInBytes) + row.writeToStream(dOut, writeBuffer) + this + } + + override def writeKey[T: ClassTag](key: T): SerializationStream = { + assert(null == key || key.isInstanceOf[Int]) + this + } + + override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { + throw new UnsupportedOperationException + } + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + throw new UnsupportedOperationException + } + + override def flush(): Unit = { + dOut.flush() + } + + override def close(): Unit = { + writeBuffer = null + dOut.close() + } + } + + private val toUnsafe: UnsafeProjection = + UnsafeProjection.create(schema.fields.map(f => f.dataType)) + + override def deserializeStream(in: InputStream): DeserializationStream = { + val numFields = schema.fields.length + new DeserializationStream { + val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) + val EOF: Int = -1 + var colBuffer: Array[Byte] = new Array[Byte](1024) + var numRows: Int = readSize() + var rowIter: Iterator[InternalRow] = if (numRows != EOF) nextBatch() else Iterator.empty + + override def asKeyValueIterator: Iterator[(Int, InternalRow)] = { + new Iterator[(Int, InternalRow)] { + + override def hasNext: Boolean = rowIter.hasNext || { + if (numRows != EOF) { + rowIter = nextBatch() + true + } else { + false + } + } + + override def next(): (Int, InternalRow) = { + (0, rowIter.next()) + } + } + } + + override def asIterator: Iterator[Any] = { + throw new UnsupportedOperationException + } + + override def readObject[T: ClassTag](): T = { + throw new UnsupportedOperationException + } + + def nextBatch(): Iterator[InternalRow] = { + val columnVectors = + if (!offHeapColumnVectorEnabled) { + OnHeapColumnVector.allocateColumns(numRows, schema) + } else { + OffHeapColumnVector.allocateColumns(numRows, schema) + } + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) + columnarBatch.setNumRows(numRows) + + for (i <- 0 until numFields) { + val colLen: Int = readSize() + if (colBuffer.length < colLen) { + colBuffer = new Array[Byte](colLen) + } + ByteStreams.readFully(dIn, colBuffer, 0, colLen) + CelebornColumnAccessor.decompress( + colBuffer, + columnarBatch.column(i).asInstanceOf[WritableColumnVector], + schema.fields(i).dataType, + numRows) + } + numRows = readSize() + columnarBatch.rowIterator().asScala.map(toUnsafe) + } + + def readSize(): Int = + try { + dIn.readInt() + } catch { + case _: EOFException => + dIn.close() + EOF + } + + override def close(): Unit = { + dIn.close() + } + } + } + + override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnAccessor.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnAccessor.scala new file mode 100644 index 00000000000..5d633221271 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnAccessor.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.sql.execution.vectorized.WritableColumnVector + +trait CelebornCompressibleColumnAccessor[T <: PhysicalDataType] extends CelebornColumnAccessor { + this: CelebornNativeColumnAccessor[T] => + + private var decoder: Decoder[T] = _ + + abstract override protected def initialize(): Unit = { + super.initialize() + decoder = CelebornCompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType) + } + + abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext + + override def extractSingle(row: InternalRow, ordinal: Int): Unit = { + decoder.next(row, ordinal) + } + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = + decoder.decompress(columnVector, capacity) +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala new file mode 100644 index 00000000000..e721f262dc0 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.unsafe.Platform + +trait CelebornCompressibleColumnBuilder[T <: PhysicalDataType] + extends CelebornColumnBuilder with Logging { + + this: CelebornNativeColumnBuilder[T] with WithCelebornCompressionSchemes => + + private var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType) + + def init(encoder: Encoder[T]): Unit = { + compressionEncoder = encoder + } + + abstract override def initialize( + rowCnt: Int, + columnName: String, + encodingEnabled: Boolean): Unit = { + super.initialize(rowCnt, columnName, encodingEnabled) + } + + // The various compression schemes, while saving memory use, cause all of the data within + // the row to become unaligned, thus causing crashes. Until a way of fixing the compression + // is found to also allow aligned accesses this must be disabled for SPARK. + + protected def isWorthCompressing(encoder: Encoder[T]) = { + CelebornCompressibleColumnBuilder.unaligned && encoder.compressionRatio < 0.8 + } + + private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { + compressionEncoder.gatherCompressibilityStats(row, ordinal) + } + + abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + super.appendFrom(row, ordinal) + if (!row.isNullAt(ordinal)) { + gatherCompressibilityStats(row, ordinal) + } + } + + override def build(): ByteBuffer = { + val nonNullBuffer = buildNonNulls() + val encoder: Encoder[T] = { + if (isWorthCompressing(compressionEncoder)) compressionEncoder + else CelebornPassThrough.encoder(columnType) + } + + // Header = null count + null positions + val headerSize = 4 + nulls.limit() + val compressedSize = + if (encoder.compressedSize == 0) { + nonNullBuffer.remaining() + } else { + encoder.compressedSize + } + + val compressedBuffer = ByteBuffer + // Reserves 4 bytes for compression scheme ID + .allocate(headerSize + 4 + compressedSize) + .order(ByteOrder.nativeOrder) + // Write the header + .putInt(nullCount) + .put(nulls) + + logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(nonNullBuffer, compressedBuffer) + } + + override def getTotalSize: Long = { + val encoder: Encoder[T] = { + if (isWorthCompressing(compressionEncoder)) compressionEncoder + else CelebornPassThrough.encoder(columnType) + } + if (encoder.compressedSize == 0) { + 4 + 4 + columnStats.sizeInBytes + } else { + 4 + 4 * nullCount + 4 + encoder.compressedSize + } + } +} + +object CelebornCompressibleColumnBuilder { + val unaligned = Platform.unaligned() +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala new file mode 100644 index 00000000000..7e25956674c --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.PhysicalDataType +import org.apache.spark.sql.execution.vectorized.WritableColumnVector + +trait Encoder[T <: PhysicalDataType] { + def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} + + def compressedSize: Int + + def uncompressedSize: Int + + def compressionRatio: Double = { + if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 + } + + def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer +} + +trait Decoder[T <: PhysicalDataType] { + def next(row: InternalRow, ordinal: Int): Unit + + def hasNext: Boolean + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit +} + +trait CelebornCompressionScheme { + def typeId: Int + + def supports(columnType: CelebornColumnType[_]): Boolean + + def encoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T]): Encoder[T] + + def decoder[T <: PhysicalDataType]( + buffer: ByteBuffer, + columnType: NativeCelebornColumnType[T]): Decoder[T] +} + +trait WithCelebornCompressionSchemes { + def schemes: Seq[CelebornCompressionScheme] +} + +trait AllCelebornCompressionSchemes extends WithCelebornCompressionSchemes { + override val schemes: Seq[CelebornCompressionScheme] = CelebornCompressionScheme.all +} + +object CelebornCompressionScheme { + val all: Seq[CelebornCompressionScheme] = + Seq(CelebornPassThrough, CelebornDictionaryEncoding) + + private val typeIdToScheme = all.map(scheme => scheme.typeId -> scheme).toMap + + def apply(typeId: Int): CelebornCompressionScheme = { + typeIdToScheme.getOrElse( + typeId, + throw new UnsupportedOperationException(s"Unrecognized compression scheme type ID: $typeId")) + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala new file mode 100644 index 00000000000..9b17626c314 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.ByteBuffer +import java.nio.ByteOrder + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.vectorized.WritableColumnVector +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +case object CelebornPassThrough extends CelebornCompressionScheme { + override val typeId = 0 + + override def supports(columnType: CelebornColumnType[_]): Boolean = true + + override def encoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T]) + : Encoder[T] = { + new this.CelebornEncoder[T]() + } + + override def decoder[T <: PhysicalDataType]( + buffer: ByteBuffer, + columnType: NativeCelebornColumnType[T]): Decoder[T] = { + new this.CelebornDecoder(buffer, columnType) + } + + class CelebornEncoder[T <: PhysicalDataType]() + extends Encoder[T] { + override def uncompressedSize: Int = 0 + + override def compressedSize: Int = 0 + + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { + // Writes compression type ID and copies raw contents + to.putInt(CelebornPassThrough.typeId).put(from).rewind() + to + } + } + + class CelebornDecoder[T <: PhysicalDataType]( + buffer: ByteBuffer, + columnType: NativeCelebornColumnType[T]) + extends Decoder[T] { + + override def next(row: InternalRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } + + override def hasNext: Boolean = buffer.hasRemaining + + private def putBooleans( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + for (i <- 0 until len) { + columnVector.putBoolean(pos + i, buffer.get(bufferPos + i) != 0) + } + } + + private def putBytes( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putBytes(pos, len, buffer.array, bufferPos) + } + + private def putShorts( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putShorts(pos, len, buffer.array, bufferPos) + } + + private def putInts( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putInts(pos, len, buffer.array, bufferPos) + } + + private def putLongs( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putLongs(pos, len, buffer.array, bufferPos) + } + + private def putFloats( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putFloats(pos, len, buffer.array, bufferPos) + } + + private def putDoubles( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putDoubles(pos, len, buffer.array, bufferPos) + } + + private def putByteArray( + columnVector: WritableColumnVector, + pos: Int, + bufferPos: Int, + len: Int): Unit = { + columnVector.putByteArray(pos, buffer.array, bufferPos, len) + } + + private def decompressPrimitive( + columnVector: WritableColumnVector, + rowCnt: Int, + unitSize: Int, + putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else rowCnt + var valueIndex = 0 + var seenNulls = 0 + var bufferPos = buffer.position() + while (valueIndex < rowCnt) { + if (valueIndex != nextNullIndex) { + val len = nextNullIndex - valueIndex + assert(len * unitSize.toLong < Int.MaxValue) + putFunction(columnVector, valueIndex, bufferPos, len) + bufferPos += len * unitSize + valueIndex += len + } else { + seenNulls += 1 + nextNullIndex = + if (seenNulls < nullCount) { + ByteBufferHelper.getInt(nullsBuffer) + } else { + rowCnt + } + columnVector.putNull(valueIndex) + valueIndex += 1 + } + } + } + + private def decompressString( + columnVector: WritableColumnVector, + rowCnt: Int, + putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else rowCnt + var valueIndex = 0 + var seenNulls = 0 + while (valueIndex < rowCnt) { + if (valueIndex != nextNullIndex) { + val len = nextNullIndex - valueIndex + for (index <- valueIndex until nextNullIndex) { + val length = buffer.getInt() + val cursor = buffer.position() + buffer.position(cursor + length) + putFunction(columnVector, index, buffer.arrayOffset() + cursor, length) + } + valueIndex += len + } else { + seenNulls += 1 + nextNullIndex = + if (seenNulls < nullCount) { + ByteBufferHelper.getInt(nullsBuffer) + } else { + rowCnt + } + columnVector.putNull(valueIndex) + valueIndex += 1 + } + } + } + + private def decompressDecimal( + columnVector: WritableColumnVector, + rowCnt: Int, + precision: Int): Unit = { + if (precision <= Decimal.MAX_INT_DIGITS) decompressPrimitive(columnVector, rowCnt, 4, putInts) + else if (precision <= Decimal.MAX_LONG_DIGITS) { + decompressPrimitive(columnVector, rowCnt, 8, putLongs) + } else { + decompressString(columnVector, rowCnt, putByteArray) + } + } + + override def decompress(columnVector: WritableColumnVector, rowCnt: Int): Unit = { + columnType.dataType match { + case _: PhysicalBooleanType => + val unitSize = 1 + decompressPrimitive(columnVector, rowCnt, unitSize, putBooleans) + case _: PhysicalByteType => + val unitSize = 1 + decompressPrimitive(columnVector, rowCnt, unitSize, putBytes) + case _: PhysicalShortType => + val unitSize = 2 + decompressPrimitive(columnVector, rowCnt, unitSize, putShorts) + case _: PhysicalIntegerType => + val unitSize = 4 + decompressPrimitive(columnVector, rowCnt, unitSize, putInts) + case _: PhysicalLongType => + val unitSize = 8 + decompressPrimitive(columnVector, rowCnt, unitSize, putLongs) + case _: PhysicalFloatType => + val unitSize = 4 + decompressPrimitive(columnVector, rowCnt, unitSize, putFloats) + case _: PhysicalDoubleType => + val unitSize = 8 + decompressPrimitive(columnVector, rowCnt, unitSize, putDoubles) + case _: PhysicalStringType => + decompressString(columnVector, rowCnt, putByteArray) + case d: PhysicalDecimalType => + decompressDecimal(columnVector, rowCnt, d.precision) + } + } + } +} + +case object CelebornDictionaryEncoding extends CelebornCompressionScheme { + override val typeId = 1 + + // 32K unique values allowed + var MAX_DICT_SIZE: Short = Short.MaxValue + + override def decoder[T <: PhysicalDataType]( + buffer: ByteBuffer, + columnType: NativeCelebornColumnType[T]): Decoder[T] = { + new this.CelebornDecoder(buffer, columnType) + } + + override def encoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T]) + : Encoder[T] = { + new this.CelebornEncoder[T](columnType) + } + + override def supports(columnType: CelebornColumnType[_]): Boolean = columnType match { + case CELEBORN_INT | CELEBORN_LONG | CELEBORN_FLOAT | CELEBORN_DOUBLE | CELEBORN_STRING => true + case _ => false + } + + class CelebornEncoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T]) + extends Encoder[T] { + // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary + // overflows. + private var _uncompressedSize = 0 + + // If the number of distinct elements is too large, we discard the use of dictionary encoding + // and set the overflow flag to true. + var overflow = false + + // Total number of elements. + private var count = 0 + + def cleanBatch(): Unit = { + count = 0 + _uncompressedSize = 0 + } + + // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. + private val values = new mutable.ArrayBuffer[T#InternalType](1024) + + // The dictionary that maps a value to the encoded short integer. + private val dictionary = new java.util.HashMap[Any, Short](1024) + + // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int` + // to store dictionary element count. + private var dictionarySize = 4 + + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { + if (!overflow) { + val value = columnType.getField(row, ordinal) + val actualSize = columnType.actualSize(row, ordinal) + count += 1 + _uncompressedSize += actualSize + if (!dictionary.containsKey(value)) { + if (dictionary.size < MAX_DICT_SIZE) { + val clone = columnType.clone(value) + values += clone + dictionarySize += actualSize + dictionary.put(clone, dictionary.size.toShort) + } else { + overflow = true + values.clear() + dictionary.clear() + } + } + } + } + + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { + to.putInt(CelebornDictionaryEncoding.typeId) + .putInt(dictionary.size) + + var i = 0 + while (i < values.length) { + columnType.append(values(i), to) + i += 1 + } + + while (from.hasRemaining) { + to.putShort(dictionary.get(columnType.extract(from))) + } + + to.rewind() + to + } + + override def uncompressedSize: Int = _uncompressedSize + + // 2 is the data size after(short type) dictionary encoding + override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2 + } + + class CelebornDecoder[T <: PhysicalDataType]( + buffer: ByteBuffer, + columnType: NativeCelebornColumnType[T]) + extends Decoder[T] { + private val elementNum: Int = ByteBufferHelper.getInt(buffer) + private val dictionary: Array[Any] = new Array[Any](elementNum) + private var intDictionary: Array[Int] = _ + private var longDictionary: Array[Long] = _ + private var floatDictionary: Array[Float] = _ + private var doubleDictionary: Array[Double] = _ + private var stringDictionary: Array[String] = _ + + columnType.dataType match { + case _: PhysicalIntegerType => + intDictionary = new Array[Int](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Int] + intDictionary(i) = v + dictionary(i) = v + } + case _: PhysicalLongType => + longDictionary = new Array[Long](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Long] + longDictionary(i) = v + dictionary(i) = v + } + case _: PhysicalFloatType => + floatDictionary = new Array[Float](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Float] + floatDictionary(i) = v + dictionary(i) = v + } + case _: PhysicalDoubleType => + doubleDictionary = new Array[Double](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Double] + doubleDictionary(i) = v + dictionary(i) = v + } + case _: PhysicalStringType => + stringDictionary = new Array[String](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[UTF8String] + stringDictionary(i) = v.toString + dictionary(i) = v + } + } + + override def next(row: InternalRow, ordinal: Int): Unit = { + columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType]) + } + + override def hasNext: Boolean = buffer.hasRemaining + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnType.dataType match { + case _: PhysicalIntegerType => + columnVector.setDictionary(new CelebornColumnDictionary(intDictionary)) + case _: PhysicalLongType => + columnVector.setDictionary(new CelebornColumnDictionary(longDictionary)) + case _: PhysicalFloatType => + columnVector.setDictionary(new CelebornColumnDictionary(floatDictionary)) + case _: PhysicalDoubleType => + columnVector.setDictionary(new CelebornColumnDictionary(doubleDictionary)) + case _: PhysicalStringType => + columnVector.setDictionary(new CelebornColumnDictionary(stringDictionary)) + case _ => throw new IllegalStateException("Not supported type in DictionaryEncoding.") + } + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnAccessor.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnAccessor.scala new file mode 100644 index 00000000000..c2a721b58cc --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnAccessor.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.vectorized.WritableColumnVector + +trait CelebornNullableColumnAccessor extends CelebornColumnAccessor { + private var nullsBuffer: ByteBuffer = _ + private var nullCount: Int = _ + private var seenNulls: Int = 0 + + private var nextNullIndex: Int = _ + private var pos: Int = 0 + + abstract override protected def initialize(): Unit = { + nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) + nullCount = ByteBufferHelper.getInt(nullsBuffer) + nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + pos = 0 + + underlyingBuffer.position(underlyingBuffer.position() + 4 + nullCount * 4) + super.initialize() + } + + abstract override def extractTo(row: InternalRow, ordinal: Int): Unit = { + if (pos == nextNullIndex) { + seenNulls += 1 + + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + + row.setNullAt(ordinal) + } else { + super.extractTo(row, ordinal) + } + + pos += 1 + } + + abstract override def extractToColumnVector( + columnVector: WritableColumnVector, + ordinal: Int): Unit = { + if (pos == nextNullIndex) { + seenNulls += 1 + + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + + columnVector.putNull(ordinal) + } else { + super.extractToColumnVector(columnVector, ordinal) + } + + pos += 1 + } + + abstract override def hasNext: Boolean = seenNulls < nullCount || super.hasNext +} diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnBuilder.scala new file mode 100644 index 00000000000..320df89f0a2 --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnBuilder.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.catalyst.InternalRow + +trait CelebornNullableColumnBuilder extends CelebornColumnBuilder { + protected var nulls: ByteBuffer = _ + protected var nullCount: Int = _ + var pos: Int = _ + + abstract override def initialize( + rowCnt: Int, + columnName: String, + encodingEnabled: Boolean): Unit = { + + nulls = ByteBuffer.allocate(1024) + nulls.order(ByteOrder.nativeOrder()) + pos = 0 + nullCount = 0 + super.initialize(rowCnt, columnName, encodingEnabled) + } + + abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4) + nulls.putInt(pos) + nullCount += 1 + } else { + super.appendFrom(row, ordinal) + } + pos += 1 + } + + abstract override def build(): ByteBuffer = { + val nonNulls = super.build() + val nullDataLen = nulls.position() + + nulls.limit(nullDataLen) + nulls.rewind() + + val buffer = ByteBuffer + .allocate(4 + nullDataLen + nonNulls.remaining()) + .order(ByteOrder.nativeOrder()) + .putInt(nullCount) + .put(nulls) + .put(nonNulls) + + buffer.rewind() + buffer + } + + protected def buildNonNulls(): ByteBuffer = { + nulls.limit(nulls.position()).rewind() + super.build() + } + + override def getTotalSize: Long = { + 4 + columnStats.sizeInBytes + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java b/client-spark/spark-4-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java new file mode 100644 index 00000000000..92aee552fba --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.util.UUID; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.TaskContext; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.sql.execution.UnsafeRowSerializer; +import org.apache.spark.sql.execution.columnar.CelebornBatchBuilder; +import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializer; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.sql.types.StructType; +import org.junit.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import org.apache.celeborn.client.DummyShuffleClient; +import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.common.CelebornConf; + +public class ColumnarHashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase { + + private final StructType schema = + new StructType().add("key", IntegerType$.MODULE$).add("value", StringType$.MODULE$); + + @Test + public void createColumnarShuffleWriter() throws Exception { + Mockito.doReturn(new HashPartitioner(numPartitions)).when(dependency).partitioner(); + final CelebornConf conf = new CelebornConf(); + final File tempFile = new File(tempDir, UUID.randomUUID().toString()); + final DummyShuffleClient client = new DummyShuffleClient(conf, tempFile); + client.initReducePartitionMap(shuffleId, numPartitions, 1); + + // Create ColumnarHashBasedShuffleWriter with handle of which dependency has null schema. + Mockito.doReturn(new KryoSerializer(sparkConf)).when(dependency).serializer(); + ShuffleWriter writer = + createShuffleWriterWithoutSchema( + new CelebornShuffleHandle<>( + "appId", "host", 0, this.userIdentifier, 0, false, 10, this.dependency), + taskContext, + conf, + client, + metrics.shuffleWriteMetrics()); + assertTrue(writer instanceof ColumnarHashBasedShuffleWriter); + assertFalse(((ColumnarHashBasedShuffleWriter) writer).isColumnarShuffle()); + + // Create ColumnarHashBasedShuffleWriter with handle of which dependency has non-null schema. + Mockito.doReturn(new UnsafeRowSerializer(2, null)).when(dependency).serializer(); + writer = + createShuffleWriter( + new CelebornShuffleHandle<>( + "appId", "host", 0, this.userIdentifier, 0, false, 10, this.dependency), + taskContext, + conf, + client, + metrics.shuffleWriteMetrics()); + assertTrue(((ColumnarHashBasedShuffleWriter) writer).isColumnarShuffle()); + } + + @Override + protected SerializerInstance newSerializerInstance(Serializer serializer) { + if (serializer instanceof UnsafeRowSerializer + && CelebornBatchBuilder.supportsColumnarType(schema)) { + CelebornConf conf = new CelebornConf(); + return new CelebornColumnarBatchSerializer(schema, conf.columnarShuffleOffHeapEnabled(), null) + .newInstance(); + } else { + return serializer.newInstance(); + } + } + + @Override + protected ShuffleWriter createShuffleWriter( + CelebornShuffleHandle handle, + TaskContext context, + CelebornConf conf, + ShuffleClient client, + ShuffleWriteMetricsReporter metrics) { + try (MockedStatic utils = + Mockito.mockStatic(CustomShuffleDependencyUtils.class)) { + utils + .when(() -> CustomShuffleDependencyUtils.getSchema(handle.dependency())) + .thenReturn(schema); + return SparkUtils.createColumnarHashBasedShuffleWriter( + SparkUtils.celebornShuffleId(client, handle, context, true), + handle, + context, + conf, + client, + metrics, + SendBufferPool.get(1, 30, 60)); + } + } + + private ShuffleWriter createShuffleWriterWithoutSchema( + CelebornShuffleHandle handle, + TaskContext context, + CelebornConf conf, + ShuffleClient client, + ShuffleWriteMetricsReporter metrics) { + return SparkUtils.createColumnarHashBasedShuffleWriter( + SparkUtils.celebornShuffleId(client, handle, context, true), + handle, + context, + conf, + client, + metrics, + SendBufferPool.get(1, 30, 60)); + } +} diff --git a/client-spark/spark-4-columnar-shuffle/src/test/resources/log4j2-test.xml b/client-spark/spark-4-columnar-shuffle/src/test/resources/log4j2-test.xml new file mode 100644 index 00000000000..9adcdccfd0e --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/test/resources/log4j2-test.xml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala new file mode 100644 index 00000000000..d0f4462be3e --- /dev/null +++ b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} +import org.apache.spark.sql.execution.UnsafeRowSerializer +import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.junit.Test +import org.mockito.{MockedStatic, Mockito} + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.identity.UserIdentifier + +class CelebornColumnarShuffleReaderSuite { + + @Test + def createColumnarShuffleReader(): Unit = { + val handle = new CelebornShuffleHandle[Int, String, String]( + "appId", + "host", + 0, + new UserIdentifier("mock", "mock"), + 0, + false, + 10, + null) + + var shuffleClient: MockedStatic[ShuffleClient] = null + try { + val taskContext = Mockito.mock(classOf[TaskContext]) + Mockito.when(taskContext.stageAttemptNumber).thenReturn(0) + Mockito.when(taskContext.attemptNumber).thenReturn(0) + shuffleClient = Mockito.mockStatic(classOf[ShuffleClient]) + val shuffleReader = SparkUtils.createColumnarShuffleReader( + handle, + 0, + 10, + 0, + 10, + taskContext, + new CelebornConf(), + null, + new ExecutorShuffleIdTracker()) + assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) + } finally { + if (shuffleClient != null) { + shuffleClient.close() + } + } + } + + @Test + def columnarShuffleReaderNewSerializerInstance(): Unit = { + var shuffleClient: MockedStatic[ShuffleClient] = null + try { + val taskContext = Mockito.mock(classOf[TaskContext]) + Mockito.when(taskContext.stageAttemptNumber).thenReturn(0) + Mockito.when(taskContext.attemptNumber).thenReturn(0) + shuffleClient = Mockito.mockStatic(classOf[ShuffleClient]) + val shuffleReader = SparkUtils.createColumnarShuffleReader( + new CelebornShuffleHandle[Int, String, String]( + "appId", + "host", + 0, + new UserIdentifier("mock", "mock"), + 0, + false, + 10, + null), + 0, + 10, + 0, + 10, + taskContext, + new CelebornConf(), + null, + new ExecutorShuffleIdTracker()) + val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) + Mockito.when(shuffleDependency.shuffleId).thenReturn(0) + Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( + new SparkConf(false))) + + // CelebornColumnarShuffleReader creates new serializer instance with dependency which has null schema. + var serializerInstance = shuffleReader.newSerializerInstance(shuffleDependency) + assert(serializerInstance.getClass == classOf[KryoSerializerInstance]) + + // CelebornColumnarShuffleReader creates new serializer instance with dependency which has non-null schema. + val dependencyUtils = Mockito.mockStatic(classOf[CustomShuffleDependencyUtils]) + try { + dependencyUtils.when(() => + CustomShuffleDependencyUtils.getSchema(shuffleDependency)).thenReturn( + new StructType().add( + "key", + IntegerType).add("value", StringType)) + Mockito.when(shuffleDependency.serializer).thenReturn(new UnsafeRowSerializer(2, null)) + serializerInstance = shuffleReader.newSerializerInstance(shuffleDependency) + assert(serializerInstance.getClass == classOf[CelebornColumnarBatchSerializerInstance]) + } finally { + // The registration of CustomShuffleDependencyUtils static mocking must be deregistered. + dependencyUtils.close() + } + } finally { + if (shuffleClient != null) { + shuffleClient.close() + } + } + } +} diff --git a/client-spark/spark-4-shaded/pom.xml b/client-spark/spark-4-shaded/pom.xml new file mode 100644 index 00000000000..52ce16618b4 --- /dev/null +++ b/client-spark/spark-4-shaded/pom.xml @@ -0,0 +1,143 @@ + + + + 4.0.0 + + org.apache.celeborn + celeborn-parent_${scala.binary.version} + ${project.version} + ../../pom.xml + + + celeborn-client-spark-4-shaded_${scala.binary.version} + jar + Celeborn Shaded Client for Spark 4 + + + + org.apache.celeborn + celeborn-client-spark-3-4_${scala.binary.version} + ${project.version} + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + + com.google.protobuf + ${shading.prefix}.com.google.protobuf + + + com.google.common + ${shading.prefix}.com.google.common + + + com.google.thirdparty + ${shading.prefix}.com.google.thirdparty + + + io.netty + ${shading.prefix}.io.netty + + + org.apache.commons + ${shading.prefix}.org.apache.commons + + + org.roaringbitmap + ${shading.prefix}.org.roaringbitmap + + + + + org.apache.celeborn:* + com.google.protobuf:protobuf-java + com.google.guava:guava + com.google.guava:failureaccess + io.netty:* + org.apache.commons:commons-lang3 + org.roaringbitmap:RoaringBitmap + + + + + *:* + + **/*.proto + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + **/log4j.properties + META-INF/LICENSE.txt + META-INF/NOTICE.txt + LICENSE.txt + NOTICE.txt + + + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + ${maven.plugin.antrun.version} + + + rename-native-library + + run + + package + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/client-spark/spark-4-shaded/src/main/resources/META-INF/LICENSE b/client-spark/spark-4-shaded/src/main/resources/META-INF/LICENSE new file mode 100644 index 00000000000..f96af536a1b --- /dev/null +++ b/client-spark/spark-4-shaded/src/main/resources/META-INF/LICENSE @@ -0,0 +1,248 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +------------------------------------------------------------------------------------ +This project bundles the following dependencies under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt): + + +Apache License 2.0 +-------------------------------------- + +com.google.guava:failureaccess +com.google.guava:guava +io.netty:netty-all +io.netty:netty-buffer +io.netty:netty-codec +io.netty:netty-codec-dns +io.netty:netty-codec-haproxy +io.netty:netty-codec-http +io.netty:netty-codec-http2 +io.netty:netty-codec-memcache +io.netty:netty-codec-mqtt +io.netty:netty-codec-redis +io.netty:netty-codec-smtp +io.netty:netty-codec-socks +io.netty:netty-codec-stomp +io.netty:netty-codec-xml +io.netty:netty-common +io.netty:netty-handler +io.netty:netty-handler-proxy +io.netty:netty-resolver +io.netty:netty-resolver-dns +io.netty:netty-transport +io.netty:netty-transport-classes-epoll +io.netty:netty-transport-classes-kqueue +io.netty:netty-transport-native-epoll +io.netty:netty-transport-native-kqueue +io.netty:netty-transport-native-unix-common +io.netty:netty-transport-rxtx +io.netty:netty-transport-sctp +io.netty:netty-transport-udt +org.apache.commons:commons-lang3 +org.roaringbitmap:RoaringBitmap + + +BSD 3-clause +------------ +See licenses/LICENSE-protobuf.txt for details. +com.google.protobuf:protobuf-java \ No newline at end of file diff --git a/client-spark/spark-4-shaded/src/main/resources/META-INF/NOTICE b/client-spark/spark-4-shaded/src/main/resources/META-INF/NOTICE new file mode 100644 index 00000000000..c48952d00d9 --- /dev/null +++ b/client-spark/spark-4-shaded/src/main/resources/META-INF/NOTICE @@ -0,0 +1,46 @@ + +Apache Celeborn +Copyright 2022-2024 The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (https://www.apache.org/). + +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation + +Apache Kyuubi +Copyright 2021-2023 The Apache Software Foundation + +Apache Iceberg +Copyright 2017-2022 The Apache Software Foundation + +Apache Parquet MR +Copyright 2014-2024 The Apache Software Foundation + +This project includes code from Kite, developed at Cloudera, Inc. with +the following copyright notice: + +| Copyright 2013 Cloudera Inc. +| +| Licensed under the Apache License, Version 2.0 (the "License"); +| you may not use this file except in compliance with the License. +| You may obtain a copy of the License at +| +| http://www.apache.org/licenses/LICENSE-2.0 +| +| Unless required by applicable law or agreed to in writing, software +| distributed under the License is distributed on an "AS IS" BASIS, +| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +| See the License for the specific language governing permissions and +| limitations under the License. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= + +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2021 The Apache Software Foundation + diff --git a/client-spark/spark-4-shaded/src/main/resources/META-INF/licenses/LICENSE-protobuf.txt b/client-spark/spark-4-shaded/src/main/resources/META-INF/licenses/LICENSE-protobuf.txt new file mode 100644 index 00000000000..b4350ec83c7 --- /dev/null +++ b/client-spark/spark-4-shaded/src/main/resources/META-INF/licenses/LICENSE-protobuf.txt @@ -0,0 +1,42 @@ +This license applies to all parts of Protocol Buffers except the following: + + - Atomicops support for generic gcc, located in + src/google/protobuf/stubs/atomicops_internals_generic_gcc.h. + This file is copyrighted by Red Hat Inc. + + - Atomicops support for AIX/POWER, located in + src/google/protobuf/stubs/atomicops_internals_aix.h. + This file is copyrighted by Bloomberg Finance LP. + +Copyright 2014, Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Code generated by the Protocol Buffer compiler is owned by the owner +of the input file used when generating it. This code is not +standalone and requires a support library to be linked with it. This +support library is itself covered by the above license. \ No newline at end of file diff --git a/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala b/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala index 7391d0529ae..684aba222df 100644 --- a/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala +++ b/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala @@ -487,7 +487,7 @@ abstract class AbstractSource(conf: CelebornConf, role: String) sum } - override def getMetrics(): String = { + override def getMetrics: String = { var leftMetricsNum = metricsCapacity val sb = new mutable.StringBuilder leftMetricsNum = fillInnerMetricsSnapshot(getAndClearTimerMetrics(), leftMetricsNum, sb) diff --git a/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala b/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala index 1644e4e87b5..0f776f8df09 100644 --- a/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala @@ -33,7 +33,7 @@ class CelebornSourceSuite extends CelebornFunSuite { for (i <- 1 to 100) { mockSource.updateHistogram(histogram, 10) } - val res = mockSource.getMetrics() + val res = mockSource.getMetrics assert(res.contains("metrics_abc_Count")) } @@ -70,7 +70,7 @@ class CelebornSourceSuite extends CelebornFunSuite { mockSource.stopTimer("Timer1", "key1") mockSource.stopTimer("Timer2", "key2", user3) - val res = mockSource.getMetrics() + val res = mockSource.getMetrics var extraLabelsStr = extraLabels if (extraLabels.nonEmpty) { extraLabelsStr = extraLabels + "," diff --git a/dev/dependencies.sh b/dev/dependencies.sh index 625b12edf2f..9dd8930ce4c 100755 --- a/dev/dependencies.sh +++ b/dev/dependencies.sh @@ -169,7 +169,7 @@ case "$MODULE" in SBT_PROJECT="celeborn-client-spark-2" ;; "spark-3"*) # Match all versions starting with "spark-3" - MVN_MODULES="client-spark/spark-3" + MVN_MODULES="client-spark/spark-3-4" SBT_PROJECT="celeborn-client-spark-3" ;; "flink-1.14") diff --git a/dev/reformat b/dev/reformat index 17f834ee21e..85c4cde0869 100755 --- a/dev/reformat +++ b/dev/reformat @@ -32,6 +32,8 @@ else ${PROJECT_DIR}/build/mvn spotless:apply -Pflink-1.20 ${PROJECT_DIR}/build/mvn spotless:apply -Pspark-2.4 ${PROJECT_DIR}/build/mvn spotless:apply -Pspark-3.3 + ${PROJECT_DIR}/build/mvn spotless:apply -Pspark-3.5 + ${PROJECT_DIR}/build/mvn spotless:apply -Pspark-4.0 ${PROJECT_DIR}/build/mvn spotless:apply -Paws ${PROJECT_DIR}/build/mvn spotless:apply -Pmr ${PROJECT_DIR}/build/mvn spotless:apply -Ptez diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala index d6d8c5f0fb2..fe3cc6eb791 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala @@ -151,8 +151,8 @@ class WorkerResource extends ApiRequestContext { new HandleResponse().success(success).message(finalMsg) } - private def toWorkerEventType(enum: EventTypeEnum): WorkerEventType = { - enum match { + private def toWorkerEventType(eunmObj: EventTypeEnum): WorkerEventType = { + eunmObj match { case EventTypeEnum.NONE => WorkerEventType.None case EventTypeEnum.IMMEDIATELY => WorkerEventType.Immediately case EventTypeEnum.DECOMMISSION => WorkerEventType.Decommission diff --git a/pom.xml b/pom.xml index bdbd1f0c441..b8f2cfb5510 100644 --- a/pom.xml +++ b/pom.xml @@ -991,7 +991,6 @@ ${maven.plugin.scala.version} - -Ywarn-unused-import -unchecked -deprecation -feature @@ -1373,7 +1372,7 @@ spark-3.0 client-spark/common - client-spark/spark-3 + client-spark/spark-3-4 client-spark/spark-3-columnar-common client-spark/spark-3-columnar-shuffle client-spark/spark-3-shaded @@ -1393,7 +1392,7 @@ spark-3.1 client-spark/common - client-spark/spark-3 + client-spark/spark-3-4 client-spark/spark-3-columnar-common client-spark/spark-3-columnar-shuffle client-spark/spark-3-shaded @@ -1413,7 +1412,7 @@ spark-3.2 client-spark/common - client-spark/spark-3 + client-spark/spark-3-4 client-spark/spark-3-columnar-common client-spark/spark-3-columnar-shuffle client-spark/spark-3-shaded @@ -1432,7 +1431,7 @@ spark-3.3 client-spark/common - client-spark/spark-3 + client-spark/spark-3-4 client-spark/spark-3-columnar-common client-spark/spark-3-columnar-shuffle client-spark/spark-3-shaded @@ -1451,7 +1450,7 @@ spark-3.4 client-spark/common - client-spark/spark-3 + client-spark/spark-3-4 client-spark/spark-3-columnar-common client-spark/spark-3-columnar-shuffle client-spark/spark-3-shaded @@ -1470,7 +1469,7 @@ spark-3.5 client-spark/common - client-spark/spark-3 + client-spark/spark-3-4 client-spark/spark-3-columnar-common client-spark/spark-3.5-columnar-shuffle client-spark/spark-3-shaded @@ -1485,6 +1484,25 @@ + + spark-4.0 + + client-spark/common + client-spark/spark-3-4 + client-spark/spark-3-columnar-common + client-spark/spark-4-columnar-shuffle + client-spark/spark-4-shaded + tests/spark-it + + + 1.8.0 + 2.13.11 + 2.13 + 4.0.0-preview2 + 1.5.5-6 + + + jdk-8 diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index b51e29d061d..45aa86c5c4b 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -226,7 +226,8 @@ object CelebornCommonSettings { val SCALA_2_12_17 = "2.12.17" val SCALA_2_12_18 = "2.12.18" val scala213 = "2.13.5" - val ALL_SCALA_VERSIONS = Seq(SCALA_2_11_12, SCALA_2_12_10, SCALA_2_12_15, SCALA_2_12_17, SCALA_2_12_18, scala213) + val scala213_11 = "2.13.11" + val ALL_SCALA_VERSIONS = Seq(SCALA_2_11_12, SCALA_2_12_10, SCALA_2_12_15, SCALA_2_12_17, SCALA_2_12_18, scala213, scala213_11) val DEFAULT_SCALA_VERSION = SCALA_2_12_18 @@ -421,6 +422,7 @@ object Utils { case Some("spark-3.3") => Some(Spark33) case Some("spark-3.4") => Some(Spark34) case Some("spark-3.5") => Some(Spark35) + case Some("spark-4.0") => Some(Spark40) case _ => None } @@ -725,7 +727,7 @@ object Spark24 extends SparkClientProjects { object Spark30 extends SparkClientProjects { - val sparkClientProjectPath = "client-spark/spark-3" + val sparkClientProjectPath = "client-spark/spark-3-4" val sparkClientProjectName = "celeborn-client-spark-3" val sparkClientShadedProjectPath = "client-spark/spark-3-shaded" val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded" @@ -739,7 +741,7 @@ object Spark30 extends SparkClientProjects { object Spark31 extends SparkClientProjects { - val sparkClientProjectPath = "client-spark/spark-3" + val sparkClientProjectPath = "client-spark/spark-3-4" val sparkClientProjectName = "celeborn-client-spark-3" val sparkClientShadedProjectPath = "client-spark/spark-3-shaded" val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded" @@ -753,7 +755,7 @@ object Spark31 extends SparkClientProjects { object Spark32 extends SparkClientProjects { - val sparkClientProjectPath = "client-spark/spark-3" + val sparkClientProjectPath = "client-spark/spark-3-4" val sparkClientProjectName = "celeborn-client-spark-3" val sparkClientShadedProjectPath = "client-spark/spark-3-shaded" val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded" @@ -767,7 +769,7 @@ object Spark32 extends SparkClientProjects { object Spark33 extends SparkClientProjects { - val sparkClientProjectPath = "client-spark/spark-3" + val sparkClientProjectPath = "client-spark/spark-3-4" val sparkClientProjectName = "celeborn-client-spark-3" val sparkClientShadedProjectPath = "client-spark/spark-3-shaded" val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded" @@ -784,7 +786,7 @@ object Spark33 extends SparkClientProjects { object Spark34 extends SparkClientProjects { - val sparkClientProjectPath = "client-spark/spark-3" + val sparkClientProjectPath = "client-spark/spark-3-4" val sparkClientProjectName = "celeborn-client-spark-3" val sparkClientShadedProjectPath = "client-spark/spark-3-shaded" val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded" @@ -798,7 +800,7 @@ object Spark34 extends SparkClientProjects { object Spark35 extends SparkClientProjects { - val sparkClientProjectPath = "client-spark/spark-3" + val sparkClientProjectPath = "client-spark/spark-3-4" val sparkClientProjectName = "celeborn-client-spark-3" val sparkClientShadedProjectPath = "client-spark/spark-3-shaded" val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded" @@ -812,6 +814,23 @@ object Spark35 extends SparkClientProjects { override val sparkColumnarShuffleVersion: String = "3.5" } +object Spark40 extends SparkClientProjects { + + val sparkClientProjectPath = "client-spark/spark-3-4" + val sparkClientProjectName = "celeborn-client-spark-4" + val sparkClientShadedProjectPath = "client-spark/spark-4-shaded" + val sparkClientShadedProjectName = "celeborn-client-spark-4-shaded" + + val lz4JavaVersion = "1.8.0" + val sparkProjectScalaVersion = "2.13.11" + + val sparkVersion = "4.0.0-preview2" + val zstdJniVersion = "1.5.5-6" + val scalaBinaryVersion = "2.13" + + override val sparkColumnarShuffleVersion: String = "4" +} + trait SparkClientProjects { val sparkClientProjectPath: String diff --git a/tests/spark-it/pom.xml b/tests/spark-it/pom.xml index 842d7a8ed98..d66aed131ad 100644 --- a/tests/spark-it/pom.xml +++ b/tests/spark-it/pom.xml @@ -115,13 +115,13 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test @@ -133,13 +133,13 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test @@ -151,13 +151,13 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test @@ -169,13 +169,13 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test @@ -187,13 +187,13 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test @@ -205,13 +205,31 @@ org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test org.apache.celeborn - celeborn-client-spark-3_${scala.binary.version} + celeborn-client-spark-3-4_${scala.binary.version} + ${project.version} + test-jar + test + + + + + spark-4.0 + + + org.apache.celeborn + celeborn-client-spark-3-4_${scala.binary.version} + ${project.version} + test + + + org.apache.celeborn + celeborn-client-spark-3-4_${scala.binary.version} ${project.version} test-jar test diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index f460e768e3d..30a12535b59 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -690,7 +690,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler pushMergedDataCallback.onSuccess(StatusCode.SUCCESS) } } - case None => + case _ => if (replicaReason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) { pushMergedDataCallback.onSuccess( StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED) @@ -791,7 +791,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler } else { pushMergedDataCallback.onSuccess(StatusCode.SUCCESS) } - case None => + case _ => pushMergedDataCallback.onSuccess(StatusCode.SUCCESS) } case Failure(e) => pushMergedDataCallback.onFailure(e)