diff --git a/.github/workflows/maven.yml b/.github/workflows/maven.yml
index 583a1c4dac3..7aeb0e45c38 100644
--- a/.github/workflows/maven.yml
+++ b/.github/workflows/maven.yml
@@ -133,8 +133,12 @@ jobs:
run: |
SPARK_BINARY_VERSION=${{ matrix.spark }}
SPARK_MAJOR_VERSION=${SPARK_BINARY_VERSION%%.*}
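+ # client-spark/spark-3 is renamed to client-spark/spark-3-4 (shared by Spark 3 and Spark 4), so remap the module name here.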
+ SPARK_MODULE_NAME=$SPARK_MAJOR_VERSION
+ if [[ $SPARK_MAJOR_VERSION == "3" ]]; then
+ SPARK_MODULE_NAME="3-4"
+ fi
PROFILES="-Pgoogle-mirror,spark-${{ matrix.spark }}"
- TEST_MODULES="client-spark/common,client-spark/spark-${SPARK_MAJOR_VERSION},client-spark/spark-${SPARK_MAJOR_VERSION}-shaded,tests/spark-it"
+ TEST_MODULES="client-spark/common,client-spark/spark-${SPARK_MODULE_NAME},client-spark/spark-${SPARK_MAJOR_VERSION}-columnar-common,client-spark/spark-${SPARK_MAJOR_VERSION}-shaded,tests/spark-it"
build/mvn $PROFILES -pl $TEST_MODULES -am clean install -DskipTests
build/mvn $PROFILES -pl $TEST_MODULES -Dspark.shuffle.sort.io.plugin.class=${{ matrix.shuffle-plugin-class }} test
- name: Upload test log
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
index f00fe063e13..9b959a4b895 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -19,14 +19,12 @@
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
-import org.apache.spark.scheduler.DAGScheduler;
public class SparkCommonUtils {
public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException {
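+ // Same default as Spark's DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS, inlined here so DAGScheduler is no longer referenced.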
+ int DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4;
int maxStageAttempts =
- conf.getInt(
- "spark.stage.maxConsecutiveAttempts",
- DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
+ conf.getInt("spark.stage.maxConsecutiveAttempts", DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS);
// In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in Spark 3, it has been
// changed to TASK_MAX_FAILURES. The default value for both is consistently set to 4.
int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
diff --git a/client-spark/spark-3/pom.xml b/client-spark/spark-3-4/pom.xml
similarity index 96%
rename from client-spark/spark-3/pom.xml
rename to client-spark/spark-3-4/pom.xml
index 4acacbed0b7..8ddfd6f2267 100644
--- a/client-spark/spark-3/pom.xml
+++ b/client-spark/spark-3-4/pom.xml
@@ -24,9 +24,9 @@
     <relativePath>../../pom.xml</relativePath>
-  <artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+  <artifactId>celeborn-client-spark-3-4_${scala.binary.version}</artifactId>
   <packaging>jar</packaging>
-  <name>Celeborn Client for Spark 3</name>
+  <name>Celeborn Client for Spark 3 and 4</name>
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java
similarity index 100%
rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java
rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/CelebornShuffleDataIO.java
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
similarity index 100%
rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
similarity index 100%
rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
similarity index 100%
rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
similarity index 100%
rename from client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
rename to client-spark/spark-3-4/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/SparkVersionUtil.scala
similarity index 100%
rename from client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala
rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/SparkVersionUtil.scala
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala
similarity index 100%
rename from client-spark/spark-3/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala
rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/celeborn/ExceptionMakerHelper.scala
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
similarity index 100%
rename from client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
similarity index 100%
rename from client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
similarity index 100%
rename from client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
rename to client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
similarity index 100%
rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
similarity index 100%
rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
similarity index 100%
rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
similarity index 100%
rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java b/client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
similarity index 100%
rename from client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
rename to client-spark/spark-3-4/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
diff --git a/client-spark/spark-3/src/test/resources/log4j.properties b/client-spark/spark-3-4/src/test/resources/log4j.properties
similarity index 100%
rename from client-spark/spark-3/src/test/resources/log4j.properties
rename to client-spark/spark-3-4/src/test/resources/log4j.properties
diff --git a/client-spark/spark-3/src/test/resources/log4j2-test.xml b/client-spark/spark-3-4/src/test/resources/log4j2-test.xml
similarity index 100%
rename from client-spark/spark-3/src/test/resources/log4j2-test.xml
rename to client-spark/spark-3-4/src/test/resources/log4j2-test.xml
diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala b/client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
similarity index 100%
rename from client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
rename to client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
diff --git a/client-spark/spark-3-columnar-common/pom.xml b/client-spark/spark-3-columnar-common/pom.xml
index 6b4b3c280a1..39a36633e4b 100644
--- a/client-spark/spark-3-columnar-common/pom.xml
+++ b/client-spark/spark-3-columnar-common/pom.xml
@@ -31,7 +31,7 @@
       <groupId>org.apache.celeborn</groupId>
-      <artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+      <artifactId>celeborn-client-spark-3-4_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
diff --git a/client-spark/spark-3-columnar-shuffle/pom.xml b/client-spark/spark-3-columnar-shuffle/pom.xml
index 5f082443738..0a628cc3d33 100644
--- a/client-spark/spark-3-columnar-shuffle/pom.xml
+++ b/client-spark/spark-3-columnar-shuffle/pom.xml
@@ -49,7 +49,7 @@
       <groupId>org.apache.celeborn</groupId>
-      <artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+      <artifactId>celeborn-client-spark-3-4_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
       <type>test-jar</type>
       <scope>test</scope>
diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml
index d3d59cb87a8..9e2c921ab12 100644
--- a/client-spark/spark-3-shaded/pom.xml
+++ b/client-spark/spark-3-shaded/pom.xml
@@ -31,7 +31,7 @@
       <groupId>org.apache.celeborn</groupId>
-      <artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+      <artifactId>celeborn-client-spark-3-4_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
diff --git a/client-spark/spark-4-columnar-shuffle/pom.xml b/client-spark/spark-4-columnar-shuffle/pom.xml
new file mode 100644
index 00000000000..bd806f2fb9b
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/pom.xml
@@ -0,0 +1,68 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements. See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License. You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.celeborn</groupId>
+    <artifactId>celeborn-parent_${scala.binary.version}</artifactId>
+    <version>${project.version}</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>celeborn-spark-4-columnar-shuffle_${scala.binary.version}</artifactId>
+  <packaging>jar</packaging>
+  <name>Celeborn Client for Spark 4 Columnar Shuffle</name>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.celeborn</groupId>
+      <artifactId>celeborn-spark-3-columnar-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.celeborn</groupId>
+      <artifactId>celeborn-client_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.celeborn</groupId>
+      <artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-inline</artifactId>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+</project>
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java b/client-spark/spark-4-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
new file mode 100644
index 00000000000..b09b1306c90
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import scala.Product2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.TaskContext;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.execution.UnsafeRowSerializer;
+import org.apache.spark.sql.execution.columnar.CelebornBatchBuilder;
+import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchBuilder;
+import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchCodeGenBuild;
+import org.apache.spark.sql.execution.metric.SQLMetric;
+import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+@Private
+public class ColumnarHashBasedShuffleWriter extends HashBasedShuffleWriter {
+
+ private static final Logger logger =
+ LoggerFactory.getLogger(ColumnarHashBasedShuffleWriter.class);
+
+ private final int stageId;
+ private final int shuffleId;
+ private final CelebornBatchBuilder[] celebornBatchBuilders;
+ private final StructType schema;
+ private final Serializer depSerializer;
+ private final boolean isColumnarShuffle;
+ private final int columnarShuffleBatchSize;
+ private final boolean columnarShuffleCodeGenEnabled;
+ private final boolean columnarShuffleDictionaryEnabled;
+ private final double columnarShuffleDictionaryMaxFactor;
+
+ public ColumnarHashBasedShuffleWriter(
+ int shuffleId,
+ CelebornShuffleHandle handle,
+ TaskContext taskContext,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics,
+ SendBufferPool sendBufferPool)
+ throws IOException {
+ super(shuffleId, handle, taskContext, conf, client, metrics, sendBufferPool);
+ columnarShuffleBatchSize = conf.columnarShuffleBatchSize();
+ columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled();
+ columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled();
+ columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor();
+ ShuffleDependency<?, ?, ?> shuffleDependency = handle.dependency();
+ this.stageId = taskContext.stageId();
+ this.shuffleId = shuffleDependency.shuffleId();
+ this.schema = CustomShuffleDependencyUtils.getSchema(shuffleDependency);
+ this.depSerializer = handle.dependency().serializer();
+ this.celebornBatchBuilders =
+ new CelebornBatchBuilder[handle.dependency().partitioner().numPartitions()];
+ this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
+ }
+
+ @Override
+ protected void fastWrite0(scala.collection.Iterator iterator)
+ throws IOException, InterruptedException {
+ if (isColumnarShuffle) {
+ logger.info("Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId);
+ fastColumnarWrite0(iterator);
+ } else {
+ super.fastWrite0(iterator);
+ }
+ }
+
+ private void fastColumnarWrite0(scala.collection.Iterator iterator) throws IOException {
+ final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
+
+ SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) depSerializer);
+ while (records.hasNext()) {
+ final Product2<Integer, UnsafeRow> record = records.next();
+ final int partitionId = record._1();
+ final UnsafeRow row = record._2();
+
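+ // Lazily create a batch builder for a partition the first time a row for it arrives.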
+ if (celebornBatchBuilders[partitionId] == null) {
+ CelebornBatchBuilder columnBuilders;
+ if (columnarShuffleCodeGenEnabled && !columnarShuffleDictionaryEnabled) {
+ columnBuilders =
+ new CelebornColumnarBatchCodeGenBuild().create(schema, columnarShuffleBatchSize);
+ } else {
+ columnBuilders =
+ new CelebornColumnarBatchBuilder(
+ schema,
+ columnarShuffleBatchSize,
+ columnarShuffleDictionaryMaxFactor,
+ columnarShuffleDictionaryEnabled);
+ }
+ columnBuilders.newBuilders();
+ celebornBatchBuilders[partitionId] = columnBuilders;
+ }
+
+ celebornBatchBuilders[partitionId].writeRow(row);
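+ // A full batch is serialized and pushed as one record, then the builder is reset for the next batch.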
+ if (celebornBatchBuilders[partitionId].getRowCnt() >= columnarShuffleBatchSize) {
+ byte[] arr = celebornBatchBuilders[partitionId].buildColumnBytes();
+ pushGiantRecord(partitionId, arr, arr.length);
+ if (dataSize != null) {
+ dataSize.add(arr.length);
+ }
+ celebornBatchBuilders[partitionId].newBuilders();
+ }
+ tmpRecordsWritten++;
+ }
+ }
+
+ @Override
+ protected void closeWrite() throws IOException {
+ if (canUseFastWrite() && isColumnarShuffle) {
+ closeColumnarWrite();
+ } else {
+ super.closeWrite();
+ }
+ }
+
+ private void closeColumnarWrite() throws IOException {
+ SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) depSerializer);
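+ // Flush any partially filled batches that remain when the writer closes.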
+ for (int i = 0; i < celebornBatchBuilders.length; i++) {
+ final CelebornBatchBuilder builders = celebornBatchBuilders[i];
+ if (builders != null && builders.getRowCnt() > 0) {
+ byte[] buffers = builders.buildColumnBytes();
+ if (dataSize != null) {
+ dataSize.add(buffers.length);
+ }
+ mergeData(i, buffers, 0, buffers.length);
+ // free buffer
+ celebornBatchBuilders[i] = null;
+ }
+ }
+ }
+
+ @VisibleForTesting
+ public boolean isColumnarShuffle() {
+ return isColumnarShuffle;
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
new file mode 100644
index 00000000000..fd888fb9dc1
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn
+
+import org.apache.spark.{ShuffleDependency, TaskContext}
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.sql.execution.UnsafeRowSerializer
+import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer}
+
+import org.apache.celeborn.common.CelebornConf
+
+class CelebornColumnarShuffleReader[K, C](
+ handle: CelebornShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ startMapIndex: Int = 0,
+ endMapIndex: Int = Int.MaxValue,
+ context: TaskContext,
+ conf: CelebornConf,
+ metrics: ShuffleReadMetricsReporter,
+ shuffleIdTracker: ExecutorShuffleIdTracker)
+ extends CelebornShuffleReader[K, C](
+ handle,
+ startPartition,
+ endPartition,
+ startMapIndex,
+ endMapIndex,
+ context,
+ conf,
+ metrics,
+ shuffleIdTracker) {
+
+ override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
+ val schema = CustomShuffleDependencyUtils.getSchema(dep)
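+ // Use the columnar batch serializer only when the dependency schema is known and all of its field types are supported.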
+ if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) {
+ logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.")
+ val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer])
+ new CelebornColumnarBatchSerializer(
+ schema,
+ conf.columnarShuffleOffHeapEnabled,
+ dataSize).newInstance()
+ } else {
+ super.newSerializerInstance(dep)
+ }
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala
new file mode 100644
index 00000000000..ce003919440
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+import org.apache.spark.sql.types._
+
+trait CelebornColumnAccessor {
+ initialize()
+
+ protected def initialize(): Unit
+
+ def hasNext: Boolean
+
+ def extractTo(row: InternalRow, ordinal: Int): Unit
+
+ def extractToColumnVector(columnVector: WritableColumnVector, ordinal: Int): Unit
+
+ protected def underlyingBuffer: ByteBuffer
+}
+
+abstract class CelebornBasicColumnAccessor[JvmType](
+ protected val buffer: ByteBuffer,
+ protected val columnType: CelebornColumnType[JvmType])
+ extends CelebornColumnAccessor {
+
+ protected def initialize(): Unit = {}
+
+ override def hasNext: Boolean = buffer.hasRemaining
+
+ override def extractTo(row: InternalRow, ordinal: Int): Unit = {
+ extractSingle(row, ordinal)
+ }
+
+ override def extractToColumnVector(columnVector: WritableColumnVector, ordinal: Int): Unit = {
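+ // Base implementation reads a length-prefixed byte array: a 4-byte length followed by the payload.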
+ val length = buffer.getInt()
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ columnVector.putByteArray(ordinal, bytes)
+ }
+
+ def extractSingle(row: InternalRow, ordinal: Int): Unit = {
+ columnType.extract(buffer, row, ordinal)
+ }
+
+ protected def underlyingBuffer: ByteBuffer = buffer
+}
+
+abstract class CelebornNativeColumnAccessor[T <: PhysicalDataType](
+ override protected val buffer: ByteBuffer,
+ override protected val columnType: NativeCelebornColumnType[T])
+ extends CelebornBasicColumnAccessor(buffer, columnType)
+ with CelebornNullableColumnAccessor
+ with CelebornCompressibleColumnAccessor[T]
+
+class CelebornBooleanColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_BOOLEAN)
+
+class CelebornByteColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_BYTE)
+
+class CelebornShortColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_SHORT)
+
+class CelebornIntColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_INT)
+
+class CelebornLongColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_LONG)
+
+class CelebornFloatColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_FLOAT)
+
+class CelebornDoubleColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_DOUBLE)
+
+class CelebornStringColumnAccessor(buffer: ByteBuffer)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_STRING)
+
+class CelebornCompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
+ extends CelebornNativeColumnAccessor(buffer, CELEBORN_COMPACT_DECIMAL(dataType))
+
+class CelebornDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
+ extends CelebornBasicColumnAccessor[Decimal](buffer, CELEBORN_LARGE_DECIMAL(dataType))
+ with CelebornNullableColumnAccessor
+
+private[sql] object CelebornColumnAccessor {
+
+ def apply(dataType: DataType, buffer: ByteBuffer): CelebornColumnAccessor = {
+ val buf = buffer.order(ByteOrder.nativeOrder)
+
+ dataType match {
+ case BooleanType => new CelebornBooleanColumnAccessor(buf)
+ case ByteType => new CelebornByteColumnAccessor(buf)
+ case ShortType => new CelebornShortColumnAccessor(buf)
+ case IntegerType => new CelebornIntColumnAccessor(buf)
+ case LongType => new CelebornLongColumnAccessor(buf)
+ case FloatType => new CelebornFloatColumnAccessor(buf)
+ case DoubleType => new CelebornDoubleColumnAccessor(buf)
+ case StringType => new CelebornStringColumnAccessor(buf)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ new CelebornCompactDecimalColumnAccessor(buf, dt)
+ case dt: DecimalType => new CelebornDecimalColumnAccessor(buf, dt)
+ case other => throw new Exception(s"not support type: $other")
+ }
+ }
+
+ def decompress(
+ columnAccessor: CelebornColumnAccessor,
+ columnVector: WritableColumnVector,
+ numRows: Int): Unit = {
+ columnAccessor match {
+ case nativeAccessor: CelebornNativeColumnAccessor[_] =>
+ nativeAccessor.decompress(columnVector, numRows)
+ case _: CelebornDecimalColumnAccessor =>
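+ // Large decimals are stored uncompressed, so copy them into the vector one row at a time.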
+ (0 until numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _))
+ case _ =>
+ throw new RuntimeException("Not support non-primitive type now")
+ }
+ }
+
+ def decompress(
+ array: Array[Byte],
+ columnVector: WritableColumnVector,
+ dataType: DataType,
+ numRows: Int): Unit = {
+ val byteBuffer = ByteBuffer.wrap(array)
+ val columnAccessor = CelebornColumnAccessor(dataType, byteBuffer)
+ decompress(columnAccessor, columnVector, numRows)
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala
new file mode 100644
index 00000000000..7d25449981d
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.columnar.CelebornColumnBuilder.ensureFreeSpace
+import org.apache.spark.sql.types._
+
+trait CelebornColumnBuilder {
+
+ /**
+ * Initializes with an approximate lower bound on the expected number of elements in this column.
+ */
+ def initialize(
+ rowCnt: Int,
+ columnName: String = "",
+ encodingEnabled: Boolean = false): Unit
+
+ /**
+ * Appends `row(ordinal)` to the column builder.
+ */
+ def appendFrom(row: InternalRow, ordinal: Int): Unit
+
+ /**
+ * Column statistics information
+ */
+ def columnStats: CelebornColumnStats
+
+ /**
+ * Returns the final columnar byte buffer.
+ */
+ def build(): ByteBuffer
+
+ def getTotalSize: Long = 0
+}
+
+class CelebornBasicColumnBuilder[JvmType](
+ val columnStats: CelebornColumnStats,
+ val columnType: CelebornColumnType[JvmType])
+ extends CelebornColumnBuilder {
+
+ protected var columnName: String = _
+
+ protected var buffer: ByteBuffer = _
+
+ override def initialize(
+ rowCnt: Int,
+ columnName: String = "",
+ encodingEnabled: Boolean = false): Unit = {
+
+ this.columnName = columnName
+
+ buffer = ByteBuffer.allocate(rowCnt * columnType.defaultSize)
+ buffer.order(ByteOrder.nativeOrder())
+ }
+
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
+ columnType.append(row, ordinal, buffer)
+ }
+
+ override def build(): ByteBuffer = {
+ if (buffer.capacity() > buffer.position() * 1.1) {
+ // trim the buffer
+ buffer = ByteBuffer
+ .allocate(buffer.position())
+ .order(ByteOrder.nativeOrder())
+ .put(buffer.array(), 0, buffer.position())
+ }
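+ // flip() switches the buffer from write mode to read mode before it is handed back.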
+ buffer.flip().asInstanceOf[ByteBuffer]
+ }
+}
+
+abstract class CelebornComplexColumnBuilder[JvmType](
+ columnStats: CelebornColumnStats,
+ columnType: CelebornColumnType[JvmType])
+ extends CelebornBasicColumnBuilder[JvmType](columnStats, columnType)
+ with CelebornNullableColumnBuilder
+
+abstract class CelebornNativeColumnBuilder[T <: PhysicalDataType](
+ override val columnStats: CelebornColumnStats,
+ override val columnType: NativeCelebornColumnType[T])
+ extends CelebornBasicColumnBuilder[T#InternalType](columnStats, columnType)
+ with CelebornNullableColumnBuilder
+ with AllCelebornCompressionSchemes
+ with CelebornCompressibleColumnBuilder[T]
+
+class CelebornBooleanColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornBooleanColumnStats, CELEBORN_BOOLEAN)
+
+class CelebornByteColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornByteColumnStats, CELEBORN_BYTE)
+
+class CelebornShortColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornShortColumnStats, CELEBORN_SHORT)
+
+class CelebornIntColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornIntColumnStats, CELEBORN_INT)
+
+class CelebornLongColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornLongColumnStats, CELEBORN_LONG)
+
+class CelebornFloatColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornFloatColumnStats, CELEBORN_FLOAT)
+
+class CelebornDoubleColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornDoubleColumnStats, CELEBORN_DOUBLE)
+
+class CelebornStringColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornStringColumnStats, CELEBORN_STRING)
+
+class CelebornCompactMiniDecimalColumnBuilder(dataType: DecimalType)
+ extends CelebornNativeColumnBuilder(
+ new CelebornDecimalColumnStats(dataType),
+ CELEBORN_COMPACT_MINI_DECIMAL(dataType))
+
+class CelebornCompactDecimalColumnBuilder(dataType: DecimalType)
+ extends CelebornNativeColumnBuilder(
+ new CelebornDecimalColumnStats(dataType),
+ CELEBORN_COMPACT_DECIMAL(dataType))
+
+class CelebornDecimalColumnBuilder(dataType: DecimalType)
+ extends CelebornComplexColumnBuilder(
+ new CelebornDecimalColumnStats(dataType),
+ CELEBORN_LARGE_DECIMAL(dataType))
+
+class CelebornBooleanCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornBooleanColumnStats, CELEBORN_BOOLEAN) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_BOOLEAN.actualSize(row, ordinal))
+ CELEBORN_BOOLEAN.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornByteCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornByteColumnStats, CELEBORN_BYTE) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_BYTE.actualSize(row, ordinal))
+ CELEBORN_BYTE.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornShortCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornShortColumnStats, CELEBORN_SHORT) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_SHORT.actualSize(row, ordinal))
+ CELEBORN_SHORT.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornIntCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornIntColumnStats, CELEBORN_INT) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_INT.actualSize(row, ordinal))
+ CELEBORN_INT.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornLongCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornLongColumnStats, CELEBORN_LONG) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_LONG.actualSize(row, ordinal))
+ CELEBORN_LONG.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornFloatCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornFloatColumnStats, CELEBORN_FLOAT) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_FLOAT.actualSize(row, ordinal))
+ CELEBORN_FLOAT.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornDoubleCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornDoubleColumnStats, CELEBORN_DOUBLE) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_DOUBLE.actualSize(row, ordinal))
+ CELEBORN_DOUBLE.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornStringCodeGenColumnBuilder
+ extends CelebornNativeColumnBuilder(new CelebornStringColumnStats, CELEBORN_STRING) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_STRING.actualSize(row, ordinal))
+ CELEBORN_STRING.append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornCompactDecimalCodeGenColumnBuilder(dataType: DecimalType)
+ extends CelebornNativeColumnBuilder(
+ new CelebornDecimalColumnStats(dataType),
+ CELEBORN_COMPACT_DECIMAL(dataType)) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_COMPACT_DECIMAL(dataType).actualSize(row, ordinal))
+ CELEBORN_COMPACT_DECIMAL(dataType).append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornCompactMiniDecimalCodeGenColumnBuilder(dataType: DecimalType)
+ extends CelebornNativeColumnBuilder(
+ new CelebornDecimalColumnStats(dataType),
+ CELEBORN_COMPACT_MINI_DECIMAL(dataType)) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer =
+ ensureFreeSpace(buffer, CELEBORN_COMPACT_MINI_DECIMAL(dataType).actualSize(row, ordinal))
+ CELEBORN_COMPACT_MINI_DECIMAL(dataType).append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+class CelebornDecimalCodeGenColumnBuilder(dataType: DecimalType)
+ extends CelebornComplexColumnBuilder(
+ new CelebornDecimalColumnStats(dataType),
+ CELEBORN_LARGE_DECIMAL(dataType)) {
+ override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ buffer = ensureFreeSpace(buffer, CELEBORN_LARGE_DECIMAL(dataType).actualSize(row, ordinal))
+ CELEBORN_LARGE_DECIMAL(dataType).append(row, ordinal, buffer)
+ }
+ pos += 1
+ }
+}
+
+object CelebornColumnBuilder {
+
+ def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
+ if (orig.remaining >= size) {
+ orig
+ } else {
+ // grow in steps of initial size
+ val capacity = orig.capacity()
+ val newSize = capacity + size.max(capacity)
+ val pos = orig.position()
+
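+ // Copy everything written so far into the larger buffer; its position ends up right after the copied bytes so appends can continue.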
+ ByteBuffer
+ .allocate(newSize)
+ .order(ByteOrder.nativeOrder())
+ .put(orig.array(), 0, pos)
+ }
+ }
+
+ def apply(
+ dataType: DataType,
+ rowCnt: Int,
+ columnName: String,
+ encodingEnabled: Boolean,
+ encoder: Encoder[_ <: PhysicalDataType]): CelebornColumnBuilder = {
+ val builder: CelebornColumnBuilder = dataType match {
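+ // Only the int, long and string builders are wired to the supplied encoder via init(); the other types use plain builders.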
+ case ByteType => new CelebornByteColumnBuilder
+ case BooleanType => new CelebornBooleanColumnBuilder
+ case ShortType => new CelebornShortColumnBuilder
+ case IntegerType =>
+ val builder = new CelebornIntColumnBuilder
+ builder.init(encoder.asInstanceOf[Encoder[PhysicalIntegerType.type]])
+ builder
+ case LongType =>
+ val builder = new CelebornLongColumnBuilder
+ builder.init(encoder.asInstanceOf[Encoder[PhysicalLongType.type]])
+ builder
+ case FloatType => new CelebornFloatColumnBuilder
+ case DoubleType => new CelebornDoubleColumnBuilder
+ case StringType =>
+ val builder = new CelebornStringColumnBuilder
+ builder.init(encoder.asInstanceOf[Encoder[PhysicalStringType]])
+ builder
+ case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS =>
+ new CelebornCompactMiniDecimalColumnBuilder(dt)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ new CelebornCompactDecimalColumnBuilder(dt)
+ case dt: DecimalType => new CelebornDecimalColumnBuilder(dt)
+ case other =>
+ throw new Exception(s"Unsupported type: $other")
+ }
+
+ builder.initialize(rowCnt, columnName, encodingEnabled)
+ builder
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala
new file mode 100644
index 00000000000..b0b9f61db9a
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Used to collect statistical information when building in-memory columns.
+ *
+ * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]`
+ * brings significant performance penalty.
+ */
+sealed private[columnar] trait CelebornColumnStats extends Serializable {
+ protected var count = 0
+ protected var nullCount = 0
+ private[columnar] var sizeInBytes = 0L
+
+ /**
+ * Gathers statistics information from `row(ordinal)`.
+ */
+ def gatherStats(row: InternalRow, ordinal: Int): Unit
+
+ /**
+ * Gathers statistics information on `null`.
+ */
+ def gatherNullStats(): Unit = {
+ nullCount += 1
+ // 4 bytes for null position
+ sizeInBytes += 4
+ count += 1
+ }
+
+ /**
+ * Column statistics represented as an array, currently including closed lower bound, closed
+ * upper bound and null count.
+ */
+ def collectedStatistics: Array[Any]
+}
+
+final private[columnar] class CelebornBooleanColumnStats extends CelebornColumnStats {
+ protected var upper = false
+ protected var lower = true
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getBoolean(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Boolean): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_BOOLEAN.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornByteColumnStats extends CelebornColumnStats {
+ protected var upper = Byte.MinValue
+ protected var lower = Byte.MaxValue
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getByte(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Byte): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_BYTE.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornShortColumnStats extends CelebornColumnStats {
+ protected var upper = Short.MinValue
+ protected var lower = Short.MaxValue
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getShort(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Short): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_SHORT.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornIntColumnStats extends CelebornColumnStats {
+ protected var upper = Int.MinValue
+ protected var lower = Int.MaxValue
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getInt(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Int): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_INT.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornLongColumnStats extends CelebornColumnStats {
+ protected var upper = Long.MinValue
+ protected var lower = Long.MaxValue
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getLong(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Long): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_LONG.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornFloatColumnStats extends CelebornColumnStats {
+ protected var upper = Float.MinValue
+ protected var lower = Float.MaxValue
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getFloat(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Float): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_FLOAT.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornDoubleColumnStats extends CelebornColumnStats {
+ protected var upper = Double.MinValue
+ protected var lower = Double.MaxValue
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getDouble(ordinal)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Double): Unit = {
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ sizeInBytes += CELEBORN_DOUBLE.defaultSize
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornStringColumnStats extends CelebornColumnStats {
+ protected var upper: UTF8String = _
+ protected var lower: UTF8String = _
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getUTF8String(ordinal)
+ val size = CELEBORN_STRING.actualSize(row, ordinal)
+ gatherValueStats(value, size)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: UTF8String, size: Int): Unit = {
+ if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
+ if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
+ sizeInBytes += size
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+final private[columnar] class CelebornDecimalColumnStats(precision: Int, scale: Int)
+ extends CelebornColumnStats {
+ def this(dt: DecimalType) = this(dt.precision, dt.scale)
+
+ protected var upper: Decimal = _
+ protected var lower: Decimal = _
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getDecimal(ordinal, precision, scale)
+ gatherValueStats(value)
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ def gatherValueStats(value: Decimal): Unit = {
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ if (precision <= Decimal.MAX_INT_DIGITS) {
+ sizeInBytes += 4
+ } else if (precision <= Decimal.MAX_LONG_DIGITS) {
+ sizeInBytes += 8
+ } else {
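+ // 4-byte length prefix plus the minimal two's-complement byte encoding of the unscaled value.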
+ sizeInBytes += (4 + value.toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1)
+ }
+ count += 1
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala
new file mode 100644
index 00000000000..706078684c5
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala
@@ -0,0 +1,663 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.math.{BigDecimal, BigInteger}
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A helper class for fast reading of Int/Long/Float/Double from a ByteBuffer in native order.
+ *
+ * Note: There is not much difference between ByteBuffer.getByte/getShort and
+ * Unsafe.getByte/getShort, so we do not have helper methods for them.
+ *
+ * The unrolling (building columnar cache) is already slow, putLong/putDouble will not help much,
+ * so we do not have helper methods for them.
+ *
+ * WARNING: This only works with HeapByteBuffer
+ */
+private[columnar] object ByteBufferHelper {
+ def getShort(buffer: ByteBuffer): Short = {
+ val pos = buffer.position()
+ buffer.position(pos + 2)
+ Platform.getShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getInt(buffer: ByteBuffer): Int = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getLong(buffer: ByteBuffer): Long = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getFloat(buffer: ByteBuffer): Float = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getDouble(buffer: ByteBuffer): Double = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def putShort(buffer: ByteBuffer, value: Short): Unit = {
+ val pos = buffer.position()
+ buffer.position(pos + 2)
+ Platform.putShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value)
+ }
+
+ def putInt(buffer: ByteBuffer, value: Int): Unit = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.putInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value)
+ }
+
+ def putLong(buffer: ByteBuffer, value: Long): Unit = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.putLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value)
+ }
+
+ def copyMemory(src: ByteBuffer, dst: ByteBuffer, len: Int): Unit = {
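+ // Advance both buffers' positions by len, then copy directly between their backing heap arrays.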
+ val srcPos = src.position()
+ val dstPos = dst.position()
+ src.position(srcPos + len)
+ dst.position(dstPos + len)
+ Platform.copyMemory(
+ src.array(),
+ Platform.BYTE_ARRAY_OFFSET + srcPos,
+ dst.array(),
+ Platform.BYTE_ARRAY_OFFSET + dstPos,
+ len)
+ }
+}
+
+/**
+ * An abstract class that represents type of a column. Used to append/extract Java objects into/from
+ * the underlying [[ByteBuffer]] of a column.
+ *
+ * @tparam JvmType Underlying Java type to represent the elements.
+ */
+sealed abstract private[columnar] class CelebornColumnType[JvmType] {
+
+ // The catalyst physical data type of this column.
+ def dataType: PhysicalDataType
+
+ // Default size in bytes for one element of type T (e.g. 4 for `Int`).
+ def defaultSize: Int
+
+ /**
+ * Extracts a value out of the buffer at the buffer's current position.
+ */
+ def extract(buffer: ByteBuffer): JvmType
+
+ /**
+ * Extracts a value out of the buffer at the buffer's current position and stores in
+ * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever
+ * possible.
+ */
+ def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ setField(row, ordinal, extract(buffer))
+ }
+
+ /**
+ * Appends the given value v of type T into the given ByteBuffer.
+ */
+ def append(v: JvmType, buffer: ByteBuffer): Unit
+
+ /**
+ * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this
+ * method to avoid boxing/unboxing costs whenever possible.
+ */
+ def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ append(getField(row, ordinal), buffer)
+ }
+
+ /**
+ * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
+ * length types such as byte arrays and strings.
+ */
+ def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize
+
+ /**
+ * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
+ * whenever possible.
+ */
+ def getField(row: InternalRow, ordinal: Int): JvmType
+
+ /**
+ * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
+ * costs whenever possible.
+ */
+ def setField(row: InternalRow, ordinal: Int, value: JvmType): Unit
+
+ /**
+ * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
+ * boxing/unboxing costs whenever possible.
+ */
+ def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int): Unit = {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
+ }
+
+ /**
+ * Creates a duplicated copy of the value.
+ */
+ def clone(v: JvmType): JvmType = v
+
+ override def toString: String = getClass.getSimpleName.stripSuffix("$")
+}
+
+abstract private[columnar] class NativeCelebornColumnType[T <: PhysicalDataType](
+ val dataType: T,
+ val defaultSize: Int)
+ extends CelebornColumnType[T#InternalType] {}
+
+private[columnar] object CELEBORN_INT extends NativeCelebornColumnType(PhysicalIntegerType, 4) {
+ override def append(v: Int, buffer: ByteBuffer): Unit = {
+ buffer.putInt(v)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putInt(row.getInt(ordinal))
+ }
+
+ override def extract(buffer: ByteBuffer): Int = {
+ ByteBufferHelper.getInt(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setInt(ordinal, ByteBufferHelper.getInt(buffer))
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Int): Unit = {
+ row.setInt(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setInt(toOrdinal, from.getInt(fromOrdinal))
+ }
+}
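+
+// A minimal usage sketch: a value appended through a column type can be read back from the same
+// buffer, provided the buffer uses the native byte order that ByteBufferHelper assumes, e.g.
+//   val buf = java.nio.ByteBuffer.allocate(CELEBORN_INT.defaultSize)
+//     .order(java.nio.ByteOrder.nativeOrder())
+//   CELEBORN_INT.append(42, buf)
+//   buf.flip()
+//   assert(CELEBORN_INT.extract(buf) == 42)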
+
+private[columnar] object CELEBORN_LONG extends NativeCelebornColumnType(PhysicalLongType, 8) {
+ override def append(v: Long, buffer: ByteBuffer): Unit = {
+ buffer.putLong(v)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putLong(row.getLong(ordinal))
+ }
+
+ override def extract(buffer: ByteBuffer): Long = {
+ ByteBufferHelper.getLong(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Long): Unit = {
+ row.setLong(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setLong(toOrdinal, from.getLong(fromOrdinal))
+ }
+}
+
+private[columnar] object CELEBORN_FLOAT extends NativeCelebornColumnType(PhysicalFloatType, 4) {
+ override def append(v: Float, buffer: ByteBuffer): Unit = {
+ buffer.putFloat(v)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putFloat(row.getFloat(ordinal))
+ }
+
+ override def extract(buffer: ByteBuffer): Float = {
+ ByteBufferHelper.getFloat(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer))
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Float): Unit = {
+ row.setFloat(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setFloat(toOrdinal, from.getFloat(fromOrdinal))
+ }
+}
+
+private[columnar] object CELEBORN_DOUBLE extends NativeCelebornColumnType(PhysicalDoubleType, 8) {
+ override def append(v: Double, buffer: ByteBuffer): Unit = {
+ buffer.putDouble(v)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putDouble(row.getDouble(ordinal))
+ }
+
+ override def extract(buffer: ByteBuffer): Double = {
+ ByteBufferHelper.getDouble(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer))
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Double): Unit = {
+ row.setDouble(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setDouble(toOrdinal, from.getDouble(fromOrdinal))
+ }
+}
+
+private[columnar] object CELEBORN_BOOLEAN extends NativeCelebornColumnType(PhysicalBooleanType, 1) {
+ override def append(v: Boolean, buffer: ByteBuffer): Unit = {
+ buffer.put(if (v) 1: Byte else 0: Byte)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte)
+ }
+
+ override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setBoolean(ordinal, buffer.get() == 1)
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Boolean): Unit = {
+ row.setBoolean(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal))
+ }
+}
+
+private[columnar] object CELEBORN_BYTE extends NativeCelebornColumnType(PhysicalByteType, 1) {
+ override def append(v: Byte, buffer: ByteBuffer): Unit = {
+ buffer.put(v)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.put(row.getByte(ordinal))
+ }
+
+ override def extract(buffer: ByteBuffer): Byte = {
+ buffer.get()
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setByte(ordinal, buffer.get())
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Byte): Unit = {
+ row.setByte(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setByte(toOrdinal, from.getByte(fromOrdinal))
+ }
+}
+
+private[columnar] object CELEBORN_SHORT extends NativeCelebornColumnType(PhysicalShortType, 2) {
+ override def append(v: Short, buffer: ByteBuffer): Unit = {
+ buffer.putShort(v)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putShort(row.getShort(ordinal))
+ }
+
+ override def extract(buffer: ByteBuffer): Short = {
+ buffer.getShort()
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row.setShort(ordinal, buffer.getShort())
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Short): Unit = {
+ row.setShort(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal)
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ to.setShort(toOrdinal, from.getShort(fromOrdinal))
+ }
+}
+
+/**
+ * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper
+ * objects.
+ */
+private[columnar] trait DirectCopyCelebornColumnType[JvmType] extends CelebornColumnType[JvmType] {
+
+ // copy the bytes from ByteBuffer to UnsafeRow
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ row match {
+ case r: MutableUnsafeRow =>
+ val numBytes = buffer.getInt
+ val cursor = buffer.position()
+ buffer.position(cursor + numBytes)
+ r.writer.write(
+ ordinal,
+ buffer.array(),
+ buffer.arrayOffset() + cursor,
+ numBytes)
+ case _ =>
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
+ // copy the bytes from UnsafeRow to ByteBuffer
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ row match {
+ case r: UnsafeRow =>
+ r.writeFieldTo(ordinal, buffer)
+ case _ =>
+ super.append(row, ordinal, buffer)
+ }
+ }
+}
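+
+// For such types the buffer layout is simply a 4-byte length prefix followed by the raw bytes,
+// which is what lets the UnsafeRow writer above copy them without allocating wrapper objects.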
+
+private[columnar] object CELEBORN_STRING
+ extends NativeCelebornColumnType(
+ PhysicalStringType(SqlApiConf.get.defaultStringType.collationId),
+ 8)
+ with DirectCopyCelebornColumnType[UTF8String] {
+
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ row.getUTF8String(ordinal).numBytes() + 4
+ }
+
+ override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
+ buffer.putInt(v.numBytes())
+ v.writeTo(buffer)
+ }
+
+ override def extract(buffer: ByteBuffer): UTF8String = {
+ val length = buffer.getInt()
+ val cursor = buffer.position()
+ buffer.position(cursor + length)
+ UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length)
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = {
+ row match {
+ case r: MutableUnsafeRow =>
+ r.writer.write(ordinal, value)
+ case _ =>
+ row.update(ordinal, value.clone())
+ }
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): UTF8String = {
+ row.getUTF8String(ordinal)
+ }
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
+ }
+
+ override def clone(v: UTF8String): UTF8String = v.clone()
+
+}
+
+private[columnar] case class CELEBORN_COMPACT_DECIMAL(precision: Int, scale: Int)
+ extends NativeCelebornColumnType(PhysicalDecimalType(precision, scale), 8) {
+
+ override def extract(buffer: ByteBuffer): Decimal = {
+ Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ // copy it as Long
+ row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
+ } else {
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
+ override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+ buffer.putLong(v.toUnscaledLong)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ if (row.isInstanceOf[UnsafeRow]) {
+ // copy it as Long
+ buffer.putLong(row.getLong(ordinal))
+ } else {
+ append(getField(row, ordinal), buffer)
+ }
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Decimal = {
+ row.getDecimal(ordinal, precision, scale)
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = {
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
+ }
+}
+
+private[columnar] case class CELEBORN_COMPACT_MINI_DECIMAL(precision: Int, scale: Int)
+ extends NativeCelebornColumnType(PhysicalDecimalType(precision, scale), 4) {
+
+ override def extract(buffer: ByteBuffer): Decimal = {
+ Decimal(ByteBufferHelper.getInt(buffer), precision, scale)
+ }
+
+ override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ // copy it as Int
+ row.setInt(ordinal, ByteBufferHelper.getInt(buffer))
+ } else {
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
+ override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+ buffer.putInt(v.toInt)
+ }
+
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ if (row.isInstanceOf[UnsafeRow]) {
+ // copy it as Int
+ buffer.putInt(row.getInt(ordinal))
+ } else {
+ append(getField(row, ordinal), buffer)
+ }
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): Decimal = {
+ row.getDecimal(ordinal, precision, scale)
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = {
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def copyField(
+ from: InternalRow,
+ fromOrdinal: Int,
+ to: InternalRow,
+ toOrdinal: Int): Unit = {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
+ }
+}
+
+private[columnar] object CELEBORN_COMPACT_DECIMAL {
+ def apply(dt: DecimalType): CELEBORN_COMPACT_DECIMAL = {
+ CELEBORN_COMPACT_DECIMAL(dt.precision, dt.scale)
+ }
+}
+
+private[columnar] object CELEBORN_COMPACT_MINI_DECIMAL {
+ def apply(dt: DecimalType): CELEBORN_COMPACT_MINI_DECIMAL = {
+ CELEBORN_COMPACT_MINI_DECIMAL(dt.precision, dt.scale)
+ }
+}
+
+sealed abstract private[columnar] class ByteArrayCelebornColumnType[JvmType](val defaultSize: Int)
+ extends CelebornColumnType[JvmType] with DirectCopyCelebornColumnType[JvmType] {
+
+ def serialize(value: JvmType): Array[Byte]
+ def deserialize(bytes: Array[Byte]): JvmType
+
+ override def append(v: JvmType, buffer: ByteBuffer): Unit = {
+ val bytes = serialize(v)
+ buffer.putInt(bytes.length).put(bytes, 0, bytes.length)
+ }
+
+ override def extract(buffer: ByteBuffer): JvmType = {
+ val length = buffer.getInt()
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ deserialize(bytes)
+ }
+}
+
+private[columnar] case class CELEBORN_LARGE_DECIMAL(precision: Int, scale: Int)
+ extends ByteArrayCelebornColumnType[Decimal](12) {
+
+ override val dataType: PhysicalDecimalType = PhysicalDecimalType(precision, scale)
+
+ override def getField(row: InternalRow, ordinal: Int): Decimal = {
+ row.getDecimal(ordinal, precision, scale)
+ }
+
+ override def setField(row: InternalRow, ordinal: Int, value: Decimal): Unit = {
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def actualSize(row: InternalRow, ordinal: Int): Int = {
+ 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1
+ }
+
+ override def serialize(value: Decimal): Array[Byte] = {
+ value.toJavaBigDecimal.unscaledValue().toByteArray
+ }
+
+ override def deserialize(bytes: Array[Byte]): Decimal = {
+ val javaDecimal = new BigDecimal(new BigInteger(bytes), scale)
+ Decimal.apply(javaDecimal, precision, scale)
+ }
+}
+
+private[columnar] object CELEBORN_LARGE_DECIMAL {
+ def apply(dt: DecimalType): CELEBORN_LARGE_DECIMAL = {
+ CELEBORN_LARGE_DECIMAL(dt.precision, dt.scale)
+ }
+}
+
+private[columnar] object CelebornColumnType {
+ def apply(dataType: DataType): CelebornColumnType[_] = {
+ dataType match {
+ case BooleanType => CELEBORN_BOOLEAN
+ case ByteType => CELEBORN_BYTE
+ case ShortType => CELEBORN_SHORT
+ case IntegerType => CELEBORN_INT
+ case LongType => CELEBORN_LONG
+ case FloatType => CELEBORN_FLOAT
+ case DoubleType => CELEBORN_DOUBLE
+ case StringType => CELEBORN_STRING
+ case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS =>
+ CELEBORN_COMPACT_MINI_DECIMAL(dt)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ CELEBORN_COMPACT_DECIMAL(dt)
+ case dt: DecimalType => CELEBORN_LARGE_DECIMAL(dt)
+ case other => throw new Exception(s"Unsupported type: ${other.catalogString}")
+ }
+ }
+}
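+
+// For example, given Spark's decimal precision thresholds (9 digits for an Int, 18 for a Long),
+// DecimalType(9, 2) maps to CELEBORN_COMPACT_MINI_DECIMAL, DecimalType(18, 2) to
+// CELEBORN_COMPACT_DECIMAL, and a wider type such as DecimalType(38, 18) to
+// CELEBORN_LARGE_DECIMAL.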
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala
new file mode 100644
index 00000000000..4685e997736
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.types._
+
+class CelebornColumnarBatchBuilder(
+ schema: StructType,
+ batchSize: Int = 0,
+ maxDictFactor: Double,
+ encodingEnabled: Boolean = false) extends CelebornBatchBuilder {
+ var rowCnt = 0
+
+ private val typeConversion
+ : PartialFunction[DataType, NativeCelebornColumnType[_ <: PhysicalDataType]] = {
+ case IntegerType => CELEBORN_INT
+ case LongType => CELEBORN_LONG
+ case StringType => CELEBORN_STRING
+ case BooleanType => CELEBORN_BOOLEAN
+ case ShortType => CELEBORN_SHORT
+ case ByteType => CELEBORN_BYTE
+ case FloatType => CELEBORN_FLOAT
+ case DoubleType => CELEBORN_DOUBLE
+ case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS =>
+ CELEBORN_COMPACT_MINI_DECIMAL(dt)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => CELEBORN_COMPACT_DECIMAL(dt)
+ case _ => null
+ }
+
+ private val encodersArr: Array[Encoder[_ <: PhysicalDataType]] = schema.map { attribute =>
+ val nativeColumnType = typeConversion(attribute.dataType)
+ if (nativeColumnType == null) {
+ null
+ } else {
+ if (encodingEnabled && CelebornDictionaryEncoding.supports(nativeColumnType)) {
+ CelebornDictionaryEncoding.MAX_DICT_SIZE =
+ Math.min(Short.MaxValue, batchSize * maxDictFactor).toShort
+ CelebornDictionaryEncoding.encoder(nativeColumnType)
+ } else {
+ CelebornPassThrough.encoder(nativeColumnType)
+ }
+ }
+ }.toArray
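+
+ // Rough sizing example for the dictionary cap computed above: with batchSize = 10000 and
+ // maxDictFactor = 0.3, the dictionary is capped at min(32767, 3000) = 3000 distinct values
+ // per batch.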
+
+ var columnBuilders: Array[CelebornColumnBuilder] = _
+
+ def newBuilders(): Unit = {
+ rowCnt = 0
+ var i = -1
+ columnBuilders = schema.map { attribute =>
+ i += 1
+ encodersArr(i) match {
+ case encoder: CelebornDictionaryEncoding.CelebornEncoder[_] if !encoder.overflow =>
+ encoder.cleanBatch()
+ case _ =>
+ }
+ CelebornColumnBuilder(
+ attribute.dataType,
+ batchSize,
+ attribute.name,
+ encodingEnabled,
+ encodersArr(i))
+ }.toArray
+ }
+
+ def buildColumnBytes(): Array[Byte] = {
+ val giantBuffer = new ByteArrayOutputStream
+ val rowCntBytes = int2ByteArray(rowCnt)
+ giantBuffer.write(rowCntBytes)
+ val builderLen = columnBuilders.length
+ var i = 0
+ while (i < builderLen) {
+ val builder = columnBuilders(i)
+ val buffers = builder.build()
+ val bytes = JavaUtils.bufferToArray(buffers)
+ val columnBuilderBytes = int2ByteArray(bytes.length)
+ giantBuffer.write(columnBuilderBytes)
+ giantBuffer.write(bytes)
+ i += 1
+ }
+ giantBuffer.toByteArray
+ }
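+
+ // The array produced above is laid out as [rowCnt: 4 bytes] followed, for each column, by
+ // [serialized column length: 4 bytes][column bytes from its builder].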
+
+ def writeRow(row: InternalRow): Unit = {
+ var i = 0
+ while (i < row.numFields) {
+ columnBuilders(i).appendFrom(row, i)
+ i += 1
+ }
+ rowCnt += 1
+ }
+
+ def getRowCnt: Int = rowCnt
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala
new file mode 100644
index 00000000000..e510e645259
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection.newCodeGenContext
+import org.apache.spark.sql.types._
+import org.slf4j.LoggerFactory
+
+class CelebornColumnarBatchCodeGenBuild {
+
+ private val logger = LoggerFactory.getLogger(classOf[CelebornColumnarBatchCodeGenBuild])
+
+ def create(schema: StructType, batchSize: Int): CelebornBatchBuilder = {
+ val ctx = newCodeGenContext()
+ val codes = genCode(schema, batchSize)
+ val codeBody =
+ s"""
+ |
+ |public java.lang.Object generate(Object[] references) {
+ | return new SpecificCelebornColumnarBatchBuilder(references);
+ |}
+ |
+ |class SpecificCelebornColumnarBatchBuilder extends ${classOf[CelebornBatchBuilder].getName} {
+ |
+ | private Object[] references;
+ | int rowCnt = 0;
+ | ${codes._1}
+ |
+ | public SpecificCelebornColumnarBatchBuilder(Object[] references) {
+ | this.references = references;
+ | }
+ |
+ | public void newBuilders() throws Exception {
+ | rowCnt = 0;
+ | ${codes._2}
+ | }
+ |
+ | public byte[] buildColumnBytes() throws Exception {
+ | ${classOf[ByteArrayOutputStream].getName} giantBuffer = new ${classOf[
+ ByteArrayOutputStream].getName}();
+ | byte[] rowCntBytes = int2ByteArray(rowCnt);
+ | giantBuffer.write(rowCntBytes);
+ | ${codes._3}
+ | return giantBuffer.toByteArray();
+ | }
+ |
+ | public void writeRow(InternalRow row) throws Exception {
+ | ${codes._4}
+ | rowCnt += 1;
+ | }
+ |
+ | public int getRowCnt() {
+ | return rowCnt;
+ | }
+ |}
+ """.stripMargin
+
+ val code = CodeFormatter.stripOverlappingComments(
+ new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+
+ logger.debug(s"\n${CodeFormatter.format(code)}")
+
+ val (clazz, _) = CodeGenerator.compile(code)
+
+ clazz.generate(ctx.references.toArray).asInstanceOf[CelebornBatchBuilder]
+ }
+
+ /**
+ * @return (code to define columnBuilder, code to instantiate columnBuilder,
+ * code to build column bytes, code to write row to columnBuilder)
+ */
+ def genCode(schema: StructType, batchSize: Int): (
+ mutable.StringBuilder,
+ mutable.StringBuilder,
+ mutable.StringBuilder,
+ mutable.StringBuilder) = {
+ val initCode = new mutable.StringBuilder()
+ val buildCode = new mutable.StringBuilder()
+ val writeCode = new mutable.StringBuilder()
+ val writeRowCode = new mutable.StringBuilder()
+ for (index <- schema.indices) {
+ schema.fields(index).dataType match {
+ case ByteType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornByteCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornByteCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case BooleanType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornBooleanCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornBooleanCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case ShortType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornShortCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornShortCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case IntegerType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornIntCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornIntCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case LongType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornLongCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornLongCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case FloatType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornFloatCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornFloatCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case DoubleType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornDoubleCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornDoubleCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case StringType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornStringCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornStringCodeGenColumnBuilder].getName}();
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornCompactMiniDecimalCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index =
+ | new ${classOf[CelebornCompactMiniDecimalCodeGenColumnBuilder].getName}(
+ | new ${classOf[DecimalType].getName}(${dt.precision}, ${dt.scale}));
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornCompactDecimalCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index =
+ | new ${classOf[CelebornCompactDecimalCodeGenColumnBuilder].getName}
+ | (new ${classOf[DecimalType].getName}(${dt.precision}, ${dt.scale}));
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ case dt: DecimalType =>
+ initCode.append(
+ s"""
+ | ${classOf[CelebornDecimalCodeGenColumnBuilder].getName} b$index;
+ """.stripMargin)
+ buildCode.append(
+ s"""
+ | b$index = new ${classOf[CelebornDecimalCodeGenColumnBuilder].getName}
+ | (new ${classOf[DecimalType].getName}(${dt.precision}, ${dt.scale}));
+ | b$index.initialize($batchSize, "${schema.fields(index).name}", false);
+ """.stripMargin)
+ writeCode.append(genWriteCode(index))
+ writeRowCode.append(
+ s"""
+ | b$index.appendFrom(row, $index);
+ """.stripMargin)
+ }
+ }
+ (initCode, buildCode, writeCode, writeRowCode)
+ }
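+
+ // As a rough sketch, for a schema with a single IntegerType field named "id" the four
+ // fragments amount to:
+ //   define:      CelebornIntCodeGenColumnBuilder b0;
+ //   instantiate: b0 = new CelebornIntCodeGenColumnBuilder(); b0.initialize(batchSize, "id", false);
+ //   build:       the bytes of b0 written to giantBuffer (see genWriteCode below)
+ //   write row:   b0.appendFrom(row, 0);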
+
+ def genWriteCode(index: Int): String = {
+ s"""
+ | ${classOf[ByteBuffer].getName} buffers$index = b$index.build();
+ | byte[] bytes$index = ${classOf[JavaUtils].getName}.bufferToArray(buffers$index);
+ | byte[] columnBuilderBytes$index = int2ByteArray(bytes$index.length);
+ | giantBuffer.write(columnBuilderBytes$index);
+ | giantBuffer.write(bytes$index);
+ """.stripMargin
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala
new file mode 100644
index 00000000000..f9c08a0f6a1
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.io._
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+
+import com.google.common.io.ByteStreams
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+class CelebornColumnarBatchSerializer(
+ schema: StructType,
+ offHeapColumnVectorEnabled: Boolean,
+ dataSize: SQLMetric = null) extends Serializer with Serializable {
+ override def newInstance(): SerializerInstance =
+ new CelebornColumnarBatchSerializerInstance(
+ schema,
+ offHeapColumnVectorEnabled,
+ dataSize)
+ override def supportsRelocationOfSerializedObjects: Boolean = true
+}
+
+class CelebornColumnarBatchSerializerInstance(
+ schema: StructType,
+ offHeapColumnVectorEnabled: Boolean,
+ dataSize: SQLMetric) extends SerializerInstance {
+
+ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
+ private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+ private[this] val dOut: DataOutputStream =
+ new DataOutputStream(new BufferedOutputStream(out))
+
+ override def writeValue[T: ClassTag](value: T): SerializationStream = {
+ val row = value.asInstanceOf[UnsafeRow]
+ if (dataSize != null) {
+ dataSize.add(row.getSizeInBytes)
+ }
+ dOut.writeInt(row.getSizeInBytes)
+ row.writeToStream(dOut, writeBuffer)
+ this
+ }
+
+ override def writeKey[T: ClassTag](key: T): SerializationStream = {
+ assert(null == key || key.isInstanceOf[Int])
+ this
+ }
+
+ override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
+ throw new UnsupportedOperationException
+ }
+
+ override def writeObject[T: ClassTag](t: T): SerializationStream = {
+ throw new UnsupportedOperationException
+ }
+
+ override def flush(): Unit = {
+ dOut.flush()
+ }
+
+ override def close(): Unit = {
+ writeBuffer = null
+ dOut.close()
+ }
+ }
+
+ private val toUnsafe: UnsafeProjection =
+ UnsafeProjection.create(schema.fields.map(f => f.dataType))
+
+ override def deserializeStream(in: InputStream): DeserializationStream = {
+ val numFields = schema.fields.length
+ new DeserializationStream {
+ val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
+ val EOF: Int = -1
+ var colBuffer: Array[Byte] = new Array[Byte](1024)
+ var numRows: Int = readSize()
+ var rowIter: Iterator[InternalRow] = if (numRows != EOF) nextBatch() else Iterator.empty
+
+ override def asKeyValueIterator: Iterator[(Int, InternalRow)] = {
+ new Iterator[(Int, InternalRow)] {
+
+ override def hasNext: Boolean = rowIter.hasNext || {
+ if (numRows != EOF) {
+ rowIter = nextBatch()
+ true
+ } else {
+ false
+ }
+ }
+
+ override def next(): (Int, InternalRow) = {
+ (0, rowIter.next())
+ }
+ }
+ }
+
+ override def asIterator: Iterator[Any] = {
+ throw new UnsupportedOperationException
+ }
+
+ override def readObject[T: ClassTag](): T = {
+ throw new UnsupportedOperationException
+ }
+
+ def nextBatch(): Iterator[InternalRow] = {
+ val columnVectors =
+ if (!offHeapColumnVectorEnabled) {
+ OnHeapColumnVector.allocateColumns(numRows, schema)
+ } else {
+ OffHeapColumnVector.allocateColumns(numRows, schema)
+ }
+ val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]])
+ columnarBatch.setNumRows(numRows)
+
+ for (i <- 0 until numFields) {
+ val colLen: Int = readSize()
+ if (colBuffer.length < colLen) {
+ colBuffer = new Array[Byte](colLen)
+ }
+ ByteStreams.readFully(dIn, colBuffer, 0, colLen)
+ CelebornColumnAccessor.decompress(
+ colBuffer,
+ columnarBatch.column(i).asInstanceOf[WritableColumnVector],
+ schema.fields(i).dataType,
+ numRows)
+ }
+ numRows = readSize()
+ columnarBatch.rowIterator().asScala.map(toUnsafe)
+ }
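+
+ // Each serialized batch read here consists of [numRows: 4 bytes] followed, for every field,
+ // by [column length: 4 bytes][that many column bytes]; readSize() returning EOF marks the end
+ // of the stream.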
+
+ def readSize(): Int =
+ try {
+ dIn.readInt()
+ } catch {
+ case _: EOFException =>
+ dIn.close()
+ EOF
+ }
+
+ override def close(): Unit = {
+ dIn.close()
+ }
+ }
+ }
+
+ override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+ throw new UnsupportedOperationException
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+ throw new UnsupportedOperationException
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnAccessor.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnAccessor.scala
new file mode 100644
index 00000000000..5d633221271
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnAccessor.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+
+trait CelebornCompressibleColumnAccessor[T <: PhysicalDataType] extends CelebornColumnAccessor {
+ this: CelebornNativeColumnAccessor[T] =>
+
+ private var decoder: Decoder[T] = _
+
+ abstract override protected def initialize(): Unit = {
+ super.initialize()
+ decoder = CelebornCompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType)
+ }
+
+ abstract override def hasNext: Boolean = super.hasNext || decoder.hasNext
+
+ override def extractSingle(row: InternalRow, ordinal: Int): Unit = {
+ decoder.next(row, ordinal)
+ }
+
+ def decompress(columnVector: WritableColumnVector, capacity: Int): Unit =
+ decoder.decompress(columnVector, capacity)
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala
new file mode 100644
index 00000000000..e721f262dc0
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.unsafe.Platform
+
+trait CelebornCompressibleColumnBuilder[T <: PhysicalDataType]
+ extends CelebornColumnBuilder with Logging {
+
+ this: CelebornNativeColumnBuilder[T] with WithCelebornCompressionSchemes =>
+
+ private var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType)
+
+ def init(encoder: Encoder[T]): Unit = {
+ compressionEncoder = encoder
+ }
+
+ abstract override def initialize(
+ rowCnt: Int,
+ columnName: String,
+ encodingEnabled: Boolean): Unit = {
+ super.initialize(rowCnt, columnName, encodingEnabled)
+ }
+
+ // The various compression schemes, while saving memory, leave the data within the row
+ // unaligned, which causes crashes on platforms that require aligned accesses (e.g. SPARC).
+ // Until the compressed format can also guarantee aligned accesses, compression is disabled
+ // on such platforms (see the `unaligned` check below).
+
+ protected def isWorthCompressing(encoder: Encoder[T]) = {
+ CelebornCompressibleColumnBuilder.unaligned && encoder.compressionRatio < 0.8
+ }
+
+ private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
+ compressionEncoder.gatherCompressibilityStats(row, ordinal)
+ }
+
+ abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+ super.appendFrom(row, ordinal)
+ if (!row.isNullAt(ordinal)) {
+ gatherCompressibilityStats(row, ordinal)
+ }
+ }
+
+ override def build(): ByteBuffer = {
+ val nonNullBuffer = buildNonNulls()
+ val encoder: Encoder[T] = {
+ if (isWorthCompressing(compressionEncoder)) compressionEncoder
+ else CelebornPassThrough.encoder(columnType)
+ }
+
+ // Header = null count + null positions
+ val headerSize = 4 + nulls.limit()
+ val compressedSize =
+ if (encoder.compressedSize == 0) {
+ nonNullBuffer.remaining()
+ } else {
+ encoder.compressedSize
+ }
+
+ val compressedBuffer = ByteBuffer
+ // Reserves 4 bytes for compression scheme ID
+ .allocate(headerSize + 4 + compressedSize)
+ .order(ByteOrder.nativeOrder)
+ // Write the header
+ .putInt(nullCount)
+ .put(nulls)
+
+ logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
+ encoder.compress(nonNullBuffer, compressedBuffer)
+ }
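+
+ // The buffer returned above is laid out as [null count: 4 bytes][null positions: 4 bytes each]
+ // followed by [compression scheme type ID: 4 bytes][encoded non-null values].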
+
+ override def getTotalSize: Long = {
+ val encoder: Encoder[T] = {
+ if (isWorthCompressing(compressionEncoder)) compressionEncoder
+ else CelebornPassThrough.encoder(columnType)
+ }
+ if (encoder.compressedSize == 0) {
+ 4 + 4 + columnStats.sizeInBytes
+ } else {
+ 4 + 4 * nullCount + 4 + encoder.compressedSize
+ }
+ }
+}
+
+object CelebornCompressibleColumnBuilder {
+ val unaligned = Platform.unaligned()
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala
new file mode 100644
index 00000000000..7e25956674c
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+
+trait Encoder[T <: PhysicalDataType] {
+ def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {}
+
+ def compressedSize: Int
+
+ def uncompressedSize: Int
+
+ def compressionRatio: Double = {
+ if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
+ }
+
+ def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
+}
+
+trait Decoder[T <: PhysicalDataType] {
+ def next(row: InternalRow, ordinal: Int): Unit
+
+ def hasNext: Boolean
+
+ def decompress(columnVector: WritableColumnVector, capacity: Int): Unit
+}
+
+trait CelebornCompressionScheme {
+ def typeId: Int
+
+ def supports(columnType: CelebornColumnType[_]): Boolean
+
+ def encoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T]): Encoder[T]
+
+ def decoder[T <: PhysicalDataType](
+ buffer: ByteBuffer,
+ columnType: NativeCelebornColumnType[T]): Decoder[T]
+}
+
+trait WithCelebornCompressionSchemes {
+ def schemes: Seq[CelebornCompressionScheme]
+}
+
+trait AllCelebornCompressionSchemes extends WithCelebornCompressionSchemes {
+ override val schemes: Seq[CelebornCompressionScheme] = CelebornCompressionScheme.all
+}
+
+object CelebornCompressionScheme {
+ val all: Seq[CelebornCompressionScheme] =
+ Seq(CelebornPassThrough, CelebornDictionaryEncoding)
+
+ private val typeIdToScheme = all.map(scheme => scheme.typeId -> scheme).toMap
+
+ def apply(typeId: Int): CelebornCompressionScheme = {
+ typeIdToScheme.getOrElse(
+ typeId,
+ throw new UnsupportedOperationException(s"Unrecognized compression scheme type ID: $typeId"))
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
new file mode 100644
index 00000000000..9b17626c314
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+case object CelebornPassThrough extends CelebornCompressionScheme {
+ override val typeId = 0
+
+ override def supports(columnType: CelebornColumnType[_]): Boolean = true
+
+ override def encoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T])
+ : Encoder[T] = {
+ new this.CelebornEncoder[T]()
+ }
+
+ override def decoder[T <: PhysicalDataType](
+ buffer: ByteBuffer,
+ columnType: NativeCelebornColumnType[T]): Decoder[T] = {
+ new this.CelebornDecoder(buffer, columnType)
+ }
+
+ class CelebornEncoder[T <: PhysicalDataType]()
+ extends Encoder[T] {
+ override def uncompressedSize: Int = 0
+
+ override def compressedSize: Int = 0
+
+ override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
+ // Writes compression type ID and copies raw contents
+ to.putInt(CelebornPassThrough.typeId).put(from).rewind()
+ to
+ }
+ }
+
+ class CelebornDecoder[T <: PhysicalDataType](
+ buffer: ByteBuffer,
+ columnType: NativeCelebornColumnType[T])
+ extends Decoder[T] {
+
+ override def next(row: InternalRow, ordinal: Int): Unit = {
+ columnType.extract(buffer, row, ordinal)
+ }
+
+ override def hasNext: Boolean = buffer.hasRemaining
+
+ private def putBooleans(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ for (i <- 0 until len) {
+ columnVector.putBoolean(pos + i, buffer.get(bufferPos + i) != 0)
+ }
+ }
+
+ private def putBytes(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putBytes(pos, len, buffer.array, bufferPos)
+ }
+
+ private def putShorts(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putShorts(pos, len, buffer.array, bufferPos)
+ }
+
+ private def putInts(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putInts(pos, len, buffer.array, bufferPos)
+ }
+
+ private def putLongs(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putLongs(pos, len, buffer.array, bufferPos)
+ }
+
+ private def putFloats(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putFloats(pos, len, buffer.array, bufferPos)
+ }
+
+ private def putDoubles(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putDoubles(pos, len, buffer.array, bufferPos)
+ }
+
+ private def putByteArray(
+ columnVector: WritableColumnVector,
+ pos: Int,
+ bufferPos: Int,
+ len: Int): Unit = {
+ columnVector.putByteArray(pos, buffer.array, bufferPos, len)
+ }
+
+ private def decompressPrimitive(
+ columnVector: WritableColumnVector,
+ rowCnt: Int,
+ unitSize: Int,
+ putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = {
+ val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
+ nullsBuffer.rewind()
+ val nullCount = ByteBufferHelper.getInt(nullsBuffer)
+ var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else rowCnt
+ var valueIndex = 0
+ var seenNulls = 0
+ var bufferPos = buffer.position()
+ while (valueIndex < rowCnt) {
+ if (valueIndex != nextNullIndex) {
+ val len = nextNullIndex - valueIndex
+ assert(len * unitSize.toLong < Int.MaxValue)
+ putFunction(columnVector, valueIndex, bufferPos, len)
+ bufferPos += len * unitSize
+ valueIndex += len
+ } else {
+ seenNulls += 1
+ nextNullIndex =
+ if (seenNulls < nullCount) {
+ ByteBufferHelper.getInt(nullsBuffer)
+ } else {
+ rowCnt
+ }
+ columnVector.putNull(valueIndex)
+ valueIndex += 1
+ }
+ }
+ }
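+
+ // For example, with rowCnt = 5 and nulls recorded at positions 1 and 3, the loop above bulk
+ // copies the value runs at positions 0, 2 and 4 and calls putNull for positions 1 and 3.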
+
+ private def decompressString(
+ columnVector: WritableColumnVector,
+ rowCnt: Int,
+ putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = {
+ val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
+ nullsBuffer.rewind()
+ val nullCount = ByteBufferHelper.getInt(nullsBuffer)
+ var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else rowCnt
+ var valueIndex = 0
+ var seenNulls = 0
+ while (valueIndex < rowCnt) {
+ if (valueIndex != nextNullIndex) {
+ val len = nextNullIndex - valueIndex
+ for (index <- valueIndex until nextNullIndex) {
+ val length = buffer.getInt()
+ val cursor = buffer.position()
+ buffer.position(cursor + length)
+ putFunction(columnVector, index, buffer.arrayOffset() + cursor, length)
+ }
+ valueIndex += len
+ } else {
+ seenNulls += 1
+ nextNullIndex =
+ if (seenNulls < nullCount) {
+ ByteBufferHelper.getInt(nullsBuffer)
+ } else {
+ rowCnt
+ }
+ columnVector.putNull(valueIndex)
+ valueIndex += 1
+ }
+ }
+ }
+
+ private def decompressDecimal(
+ columnVector: WritableColumnVector,
+ rowCnt: Int,
+ precision: Int): Unit = {
+ if (precision <= Decimal.MAX_INT_DIGITS) decompressPrimitive(columnVector, rowCnt, 4, putInts)
+ else if (precision <= Decimal.MAX_LONG_DIGITS) {
+ decompressPrimitive(columnVector, rowCnt, 8, putLongs)
+ } else {
+ decompressString(columnVector, rowCnt, putByteArray)
+ }
+ }
+
+ override def decompress(columnVector: WritableColumnVector, rowCnt: Int): Unit = {
+ columnType.dataType match {
+ case _: PhysicalBooleanType =>
+ val unitSize = 1
+ decompressPrimitive(columnVector, rowCnt, unitSize, putBooleans)
+ case _: PhysicalByteType =>
+ val unitSize = 1
+ decompressPrimitive(columnVector, rowCnt, unitSize, putBytes)
+ case _: PhysicalShortType =>
+ val unitSize = 2
+ decompressPrimitive(columnVector, rowCnt, unitSize, putShorts)
+ case _: PhysicalIntegerType =>
+ val unitSize = 4
+ decompressPrimitive(columnVector, rowCnt, unitSize, putInts)
+ case _: PhysicalLongType =>
+ val unitSize = 8
+ decompressPrimitive(columnVector, rowCnt, unitSize, putLongs)
+ case _: PhysicalFloatType =>
+ val unitSize = 4
+ decompressPrimitive(columnVector, rowCnt, unitSize, putFloats)
+ case _: PhysicalDoubleType =>
+ val unitSize = 8
+ decompressPrimitive(columnVector, rowCnt, unitSize, putDoubles)
+ case _: PhysicalStringType =>
+ decompressString(columnVector, rowCnt, putByteArray)
+ case d: PhysicalDecimalType =>
+ decompressDecimal(columnVector, rowCnt, d.precision)
+ }
+ }
+ }
+}
+
+case object CelebornDictionaryEncoding extends CelebornCompressionScheme {
+ override val typeId = 1
+
+ // 32K unique values allowed
+ var MAX_DICT_SIZE: Short = Short.MaxValue
+
+ override def decoder[T <: PhysicalDataType](
+ buffer: ByteBuffer,
+ columnType: NativeCelebornColumnType[T]): Decoder[T] = {
+ new this.CelebornDecoder(buffer, columnType)
+ }
+
+ override def encoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T])
+ : Encoder[T] = {
+ new this.CelebornEncoder[T](columnType)
+ }
+
+ override def supports(columnType: CelebornColumnType[_]): Boolean = columnType match {
+ case CELEBORN_INT | CELEBORN_LONG | CELEBORN_FLOAT | CELEBORN_DOUBLE | CELEBORN_STRING => true
+ case _ => false
+ }
+
+ class CelebornEncoder[T <: PhysicalDataType](columnType: NativeCelebornColumnType[T])
+ extends Encoder[T] {
+ // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
+ // overflows.
+ private var _uncompressedSize = 0
+
+ // If the number of distinct elements is too large, we discard the use of dictionary encoding
+ // and set the overflow flag to true.
+ var overflow = false
+
+ // Total number of elements.
+ private var count = 0
+
+ def cleanBatch(): Unit = {
+ count = 0
+ _uncompressedSize = 0
+ }
+
+ // The reverse mapping of `dictionary`, i.e. mapping the encoded short integer back to the value itself.
+ private val values = new mutable.ArrayBuffer[T#InternalType](1024)
+
+ // The dictionary that maps a value to the encoded short integer.
+ private val dictionary = new java.util.HashMap[Any, Short](1024)
+
+ // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int`
+ // to store dictionary element count.
+ private var dictionarySize = 4
+
+ override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!overflow) {
+ val value = columnType.getField(row, ordinal)
+ val actualSize = columnType.actualSize(row, ordinal)
+ count += 1
+ _uncompressedSize += actualSize
+ if (!dictionary.containsKey(value)) {
+ if (dictionary.size < MAX_DICT_SIZE) {
+ val clone = columnType.clone(value)
+ values += clone
+ dictionarySize += actualSize
+ dictionary.put(clone, dictionary.size.toShort)
+ } else {
+ overflow = true
+ values.clear()
+ dictionary.clear()
+ }
+ }
+ }
+ }
+
+ override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
+ to.putInt(CelebornDictionaryEncoding.typeId)
+ .putInt(dictionary.size)
+
+ var i = 0
+ while (i < values.length) {
+ columnType.append(values(i), to)
+ i += 1
+ }
+
+ while (from.hasRemaining) {
+ to.putShort(dictionary.get(columnType.extract(from)))
+ }
+
+ to.rewind()
+ to
+ }
+
+ override def uncompressedSize: Int = _uncompressedSize
+
+ // Each dictionary-encoded value is written as a Short, hence 2 bytes per element.
+ override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2
+ }
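+
+ // The compressed layout written by this encoder is [type ID: 4 bytes][dictionary size: 4 bytes]
+ // [dictionary values in insertion order][one 2-byte dictionary index per non-null input value].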
+
+ class CelebornDecoder[T <: PhysicalDataType](
+ buffer: ByteBuffer,
+ columnType: NativeCelebornColumnType[T])
+ extends Decoder[T] {
+ private val elementNum: Int = ByteBufferHelper.getInt(buffer)
+ private val dictionary: Array[Any] = new Array[Any](elementNum)
+ private var intDictionary: Array[Int] = _
+ private var longDictionary: Array[Long] = _
+ private var floatDictionary: Array[Float] = _
+ private var doubleDictionary: Array[Double] = _
+ private var stringDictionary: Array[String] = _
+
+ columnType.dataType match {
+ case _: PhysicalIntegerType =>
+ intDictionary = new Array[Int](elementNum)
+ for (i <- 0 until elementNum) {
+ val v = columnType.extract(buffer).asInstanceOf[Int]
+ intDictionary(i) = v
+ dictionary(i) = v
+ }
+ case _: PhysicalLongType =>
+ longDictionary = new Array[Long](elementNum)
+ for (i <- 0 until elementNum) {
+ val v = columnType.extract(buffer).asInstanceOf[Long]
+ longDictionary(i) = v
+ dictionary(i) = v
+ }
+ case _: PhysicalFloatType =>
+ floatDictionary = new Array[Float](elementNum)
+ for (i <- 0 until elementNum) {
+ val v = columnType.extract(buffer).asInstanceOf[Float]
+ floatDictionary(i) = v
+ dictionary(i) = v
+ }
+ case _: PhysicalDoubleType =>
+ doubleDictionary = new Array[Double](elementNum)
+ for (i <- 0 until elementNum) {
+ val v = columnType.extract(buffer).asInstanceOf[Double]
+ doubleDictionary(i) = v
+ dictionary(i) = v
+ }
+ case _: PhysicalStringType =>
+ stringDictionary = new Array[String](elementNum)
+ for (i <- 0 until elementNum) {
+ val v = columnType.extract(buffer).asInstanceOf[UTF8String]
+ stringDictionary(i) = v.toString
+ dictionary(i) = v
+ }
+ }
+
+ override def next(row: InternalRow, ordinal: Int): Unit = {
+ columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType])
+ }
+
+ override def hasNext: Boolean = buffer.hasRemaining
+
+ override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = {
+ val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
+ nullsBuffer.rewind()
+ val nullCount = ByteBufferHelper.getInt(nullsBuffer)
+ var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1
+ var pos = 0
+ var seenNulls = 0
+ val dictionaryIds = columnVector.reserveDictionaryIds(capacity)
+ columnType.dataType match {
+ case _: PhysicalIntegerType =>
+ columnVector.setDictionary(new CelebornColumnDictionary(intDictionary))
+ case _: PhysicalLongType =>
+ columnVector.setDictionary(new CelebornColumnDictionary(longDictionary))
+ case _: PhysicalFloatType =>
+ columnVector.setDictionary(new CelebornColumnDictionary(floatDictionary))
+ case _: PhysicalDoubleType =>
+ columnVector.setDictionary(new CelebornColumnDictionary(doubleDictionary))
+ case _: PhysicalStringType =>
+ columnVector.setDictionary(new CelebornColumnDictionary(stringDictionary))
+ case _ => throw new IllegalStateException("Unsupported type in DictionaryEncoding.")
+ }
+ while (pos < capacity) {
+ if (pos != nextNullIndex) {
+ dictionaryIds.putInt(pos, buffer.getShort())
+ } else {
+ seenNulls += 1
+ if (seenNulls < nullCount) {
+ nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
+ }
+ columnVector.putNull(pos)
+ }
+ pos += 1
+ }
+ }
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnAccessor.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnAccessor.scala
new file mode 100644
index 00000000000..c2a721b58cc
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnAccessor.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+
+trait CelebornNullableColumnAccessor extends CelebornColumnAccessor {
+ private var nullsBuffer: ByteBuffer = _
+ private var nullCount: Int = _
+ private var seenNulls: Int = 0
+
+ private var nextNullIndex: Int = _
+ private var pos: Int = 0
+
+ abstract override protected def initialize(): Unit = {
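+    // The underlying buffer is laid out as [null count][null ordinals...][non-null values];
+    // read the null header from a duplicate, then skip past it for the concrete accessor.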
+ nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
+ nullCount = ByteBufferHelper.getInt(nullsBuffer)
+ nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1
+ pos = 0
+
+ underlyingBuffer.position(underlyingBuffer.position() + 4 + nullCount * 4)
+ super.initialize()
+ }
+
+ abstract override def extractTo(row: InternalRow, ordinal: Int): Unit = {
+ if (pos == nextNullIndex) {
+ seenNulls += 1
+
+ if (seenNulls < nullCount) {
+ nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
+ }
+
+ row.setNullAt(ordinal)
+ } else {
+ super.extractTo(row, ordinal)
+ }
+
+ pos += 1
+ }
+
+ abstract override def extractToColumnVector(
+ columnVector: WritableColumnVector,
+ ordinal: Int): Unit = {
+ if (pos == nextNullIndex) {
+ seenNulls += 1
+
+ if (seenNulls < nullCount) {
+ nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
+ }
+
+ columnVector.putNull(ordinal)
+ } else {
+ super.extractToColumnVector(columnVector, ordinal)
+ }
+
+ pos += 1
+ }
+
+ abstract override def hasNext: Boolean = seenNulls < nullCount || super.hasNext
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnBuilder.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnBuilder.scala
new file mode 100644
index 00000000000..320df89f0a2
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornNullableColumnBuilder.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+trait CelebornNullableColumnBuilder extends CelebornColumnBuilder {
+ protected var nulls: ByteBuffer = _
+ protected var nullCount: Int = _
+ var pos: Int = _
+
+ abstract override def initialize(
+ rowCnt: Int,
+ columnName: String,
+ encodingEnabled: Boolean): Unit = {
+
+ nulls = ByteBuffer.allocate(1024)
+ nulls.order(ByteOrder.nativeOrder())
+ pos = 0
+ nullCount = 0
+ super.initialize(rowCnt, columnName, encodingEnabled)
+ }
+
+ abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
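+    // Nulls are encoded positionally: record the current row position instead of a value and
+    // let the wrapped builder append only non-null values.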
+ if (row.isNullAt(ordinal)) {
+ nulls = CelebornColumnBuilder.ensureFreeSpace(nulls, 4)
+ nulls.putInt(pos)
+ nullCount += 1
+ } else {
+ super.appendFrom(row, ordinal)
+ }
+ pos += 1
+ }
+
+ abstract override def build(): ByteBuffer = {
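+    // Final layout: [null count: int][null positions: int * nullCount][non-null column bytes].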
+ val nonNulls = super.build()
+ val nullDataLen = nulls.position()
+
+ nulls.limit(nullDataLen)
+ nulls.rewind()
+
+ val buffer = ByteBuffer
+ .allocate(4 + nullDataLen + nonNulls.remaining())
+ .order(ByteOrder.nativeOrder())
+ .putInt(nullCount)
+ .put(nulls)
+ .put(nonNulls)
+
+ buffer.rewind()
+ buffer
+ }
+
+ protected def buildNonNulls(): ByteBuffer = {
+ nulls.limit(nulls.position()).rewind()
+ super.build()
+ }
+
+ override def getTotalSize: Long = {
+ 4 + columnStats.sizeInBytes
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java b/client-spark/spark-4-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java
new file mode 100644
index 00000000000..92aee552fba
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.util.UUID;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.TaskContext;
+import org.apache.spark.serializer.KryoSerializer;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.sql.execution.UnsafeRowSerializer;
+import org.apache.spark.sql.execution.columnar.CelebornBatchBuilder;
+import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializer;
+import org.apache.spark.sql.types.IntegerType$;
+import org.apache.spark.sql.types.StringType$;
+import org.apache.spark.sql.types.StructType;
+import org.junit.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import org.apache.celeborn.client.DummyShuffleClient;
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class ColumnarHashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase {
+
+ private final StructType schema =
+ new StructType().add("key", IntegerType$.MODULE$).add("value", StringType$.MODULE$);
+
+ @Test
+ public void createColumnarShuffleWriter() throws Exception {
+ Mockito.doReturn(new HashPartitioner(numPartitions)).when(dependency).partitioner();
+ final CelebornConf conf = new CelebornConf();
+ final File tempFile = new File(tempDir, UUID.randomUUID().toString());
+ final DummyShuffleClient client = new DummyShuffleClient(conf, tempFile);
+ client.initReducePartitionMap(shuffleId, numPartitions, 1);
+
+    // Create ColumnarHashBasedShuffleWriter with a handle whose dependency has a null schema.
+ Mockito.doReturn(new KryoSerializer(sparkConf)).when(dependency).serializer();
+ ShuffleWriter writer =
+ createShuffleWriterWithoutSchema(
+ new CelebornShuffleHandle<>(
+ "appId", "host", 0, this.userIdentifier, 0, false, 10, this.dependency),
+ taskContext,
+ conf,
+ client,
+ metrics.shuffleWriteMetrics());
+ assertTrue(writer instanceof ColumnarHashBasedShuffleWriter);
+    assertFalse(((ColumnarHashBasedShuffleWriter<?, ?, ?>) writer).isColumnarShuffle());
+
+    // Create ColumnarHashBasedShuffleWriter with a handle whose dependency has a non-null schema.
+ Mockito.doReturn(new UnsafeRowSerializer(2, null)).when(dependency).serializer();
+ writer =
+ createShuffleWriter(
+ new CelebornShuffleHandle<>(
+ "appId", "host", 0, this.userIdentifier, 0, false, 10, this.dependency),
+ taskContext,
+ conf,
+ client,
+ metrics.shuffleWriteMetrics());
+    assertTrue(((ColumnarHashBasedShuffleWriter<?, ?, ?>) writer).isColumnarShuffle());
+ }
+
+ @Override
+ protected SerializerInstance newSerializerInstance(Serializer serializer) {
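+    // Only UnsafeRowSerializer with a columnar-supported schema goes through the columnar batch
+    // serializer; anything else falls back to the task's own serializer.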
+ if (serializer instanceof UnsafeRowSerializer
+ && CelebornBatchBuilder.supportsColumnarType(schema)) {
+ CelebornConf conf = new CelebornConf();
+ return new CelebornColumnarBatchSerializer(schema, conf.columnarShuffleOffHeapEnabled(), null)
+ .newInstance();
+ } else {
+ return serializer.newInstance();
+ }
+ }
+
+ @Override
+ protected ShuffleWriter createShuffleWriter(
+ CelebornShuffleHandle handle,
+ TaskContext context,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics) {
+    try (MockedStatic<CustomShuffleDependencyUtils> utils =
+ Mockito.mockStatic(CustomShuffleDependencyUtils.class)) {
+ utils
+ .when(() -> CustomShuffleDependencyUtils.getSchema(handle.dependency()))
+ .thenReturn(schema);
+ return SparkUtils.createColumnarHashBasedShuffleWriter(
+ SparkUtils.celebornShuffleId(client, handle, context, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(1, 30, 60));
+ }
+ }
+
+ private ShuffleWriter createShuffleWriterWithoutSchema(
+ CelebornShuffleHandle handle,
+ TaskContext context,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics) {
+ return SparkUtils.createColumnarHashBasedShuffleWriter(
+ SparkUtils.celebornShuffleId(client, handle, context, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(1, 30, 60));
+ }
+}
diff --git a/client-spark/spark-4-columnar-shuffle/src/test/resources/log4j2-test.xml b/client-spark/spark-4-columnar-shuffle/src/test/resources/log4j2-test.xml
new file mode 100644
index 00000000000..9adcdccfd0e
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/test/resources/log4j2-test.xml
@@ -0,0 +1,41 @@
diff --git a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
new file mode 100644
index 00000000000..d0f4462be3e
--- /dev/null
+++ b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn
+
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
+import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
+import org.apache.spark.sql.execution.UnsafeRowSerializer
+import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.junit.Test
+import org.mockito.{MockedStatic, Mockito}
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.identity.UserIdentifier
+
+class CelebornColumnarShuffleReaderSuite {
+
+ @Test
+ def createColumnarShuffleReader(): Unit = {
+ val handle = new CelebornShuffleHandle[Int, String, String](
+ "appId",
+ "host",
+ 0,
+ new UserIdentifier("mock", "mock"),
+ 0,
+ false,
+ 10,
+ null)
+
+ var shuffleClient: MockedStatic[ShuffleClient] = null
+ try {
+ val taskContext = Mockito.mock(classOf[TaskContext])
+ Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+ Mockito.when(taskContext.attemptNumber).thenReturn(0)
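+      // Statically mock ShuffleClient so constructing the reader does not create a real client.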
+ shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
+ val shuffleReader = SparkUtils.createColumnarShuffleReader(
+ handle,
+ 0,
+ 10,
+ 0,
+ 10,
+ taskContext,
+ new CelebornConf(),
+ null,
+ new ExecutorShuffleIdTracker())
+ assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]])
+ } finally {
+ if (shuffleClient != null) {
+ shuffleClient.close()
+ }
+ }
+ }
+
+ @Test
+ def columnarShuffleReaderNewSerializerInstance(): Unit = {
+ var shuffleClient: MockedStatic[ShuffleClient] = null
+ try {
+ val taskContext = Mockito.mock(classOf[TaskContext])
+ Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
+ Mockito.when(taskContext.attemptNumber).thenReturn(0)
+ shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
+ val shuffleReader = SparkUtils.createColumnarShuffleReader(
+ new CelebornShuffleHandle[Int, String, String](
+ "appId",
+ "host",
+ 0,
+ new UserIdentifier("mock", "mock"),
+ 0,
+ false,
+ 10,
+ null),
+ 0,
+ 10,
+ 0,
+ 10,
+ taskContext,
+ new CelebornConf(),
+ null,
+ new ExecutorShuffleIdTracker())
+ val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]])
+ Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
+ Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
+ new SparkConf(false)))
+
+      // CelebornColumnarShuffleReader creates a new serializer instance for a dependency with a null schema.
+ var serializerInstance = shuffleReader.newSerializerInstance(shuffleDependency)
+ assert(serializerInstance.getClass == classOf[KryoSerializerInstance])
+
+      // CelebornColumnarShuffleReader creates a new serializer instance for a dependency with a non-null schema.
+ val dependencyUtils = Mockito.mockStatic(classOf[CustomShuffleDependencyUtils])
+ try {
+ dependencyUtils.when(() =>
+ CustomShuffleDependencyUtils.getSchema(shuffleDependency)).thenReturn(
+ new StructType().add(
+ "key",
+ IntegerType).add("value", StringType))
+ Mockito.when(shuffleDependency.serializer).thenReturn(new UnsafeRowSerializer(2, null))
+ serializerInstance = shuffleReader.newSerializerInstance(shuffleDependency)
+ assert(serializerInstance.getClass == classOf[CelebornColumnarBatchSerializerInstance])
+ } finally {
+        // Deregister the static mock of CustomShuffleDependencyUtils so it does not leak into other tests.
+ dependencyUtils.close()
+ }
+ } finally {
+ if (shuffleClient != null) {
+ shuffleClient.close()
+ }
+ }
+ }
+}
diff --git a/client-spark/spark-4-shaded/pom.xml b/client-spark/spark-4-shaded/pom.xml
new file mode 100644
index 00000000000..52ce16618b4
--- /dev/null
+++ b/client-spark/spark-4-shaded/pom.xml
@@ -0,0 +1,143 @@
+
+
+
+ 4.0.0
+
+ org.apache.celeborn
+ celeborn-parent_${scala.binary.version}
+ ${project.version}
+ ../../pom.xml
+
+
+ celeborn-client-spark-4-shaded_${scala.binary.version}
+ jar
+ Celeborn Shaded Client for Spark 4
+
+
+
+ org.apache.celeborn
+ celeborn-client-spark-3-4_${scala.binary.version}
+ ${project.version}
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+
+
+
+ com.google.protobuf
+ ${shading.prefix}.com.google.protobuf
+
+
+ com.google.common
+ ${shading.prefix}.com.google.common
+
+
+ com.google.thirdparty
+ ${shading.prefix}.com.google.thirdparty
+
+
+ io.netty
+ ${shading.prefix}.io.netty
+
+
+ org.apache.commons
+ ${shading.prefix}.org.apache.commons
+
+
+ org.roaringbitmap
+ ${shading.prefix}.org.roaringbitmap
+
+
+
+
+ org.apache.celeborn:*
+ com.google.protobuf:protobuf-java
+ com.google.guava:guava
+ com.google.guava:failureaccess
+ io.netty:*
+ org.apache.commons:commons-lang3
+ org.roaringbitmap:RoaringBitmap
+
+
+
+
+ *:*
+
+ **/*.proto
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+ **/log4j.properties
+ META-INF/LICENSE.txt
+ META-INF/NOTICE.txt
+ LICENSE.txt
+ NOTICE.txt
+
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-antrun-plugin
+ ${maven.plugin.antrun.version}
+
+
+ rename-native-library
+
+ run
+
+ package
+
diff --git a/client-spark/spark-4-shaded/src/main/resources/META-INF/LICENSE b/client-spark/spark-4-shaded/src/main/resources/META-INF/LICENSE
new file mode 100644
index 00000000000..f96af536a1b
--- /dev/null
+++ b/client-spark/spark-4-shaded/src/main/resources/META-INF/LICENSE
@@ -0,0 +1,248 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+------------------------------------------------------------------------------------
+This project bundles the following dependencies under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt):
+
+
+Apache License 2.0
+--------------------------------------
+
+com.google.guava:failureaccess
+com.google.guava:guava
+io.netty:netty-all
+io.netty:netty-buffer
+io.netty:netty-codec
+io.netty:netty-codec-dns
+io.netty:netty-codec-haproxy
+io.netty:netty-codec-http
+io.netty:netty-codec-http2
+io.netty:netty-codec-memcache
+io.netty:netty-codec-mqtt
+io.netty:netty-codec-redis
+io.netty:netty-codec-smtp
+io.netty:netty-codec-socks
+io.netty:netty-codec-stomp
+io.netty:netty-codec-xml
+io.netty:netty-common
+io.netty:netty-handler
+io.netty:netty-handler-proxy
+io.netty:netty-resolver
+io.netty:netty-resolver-dns
+io.netty:netty-transport
+io.netty:netty-transport-classes-epoll
+io.netty:netty-transport-classes-kqueue
+io.netty:netty-transport-native-epoll
+io.netty:netty-transport-native-kqueue
+io.netty:netty-transport-native-unix-common
+io.netty:netty-transport-rxtx
+io.netty:netty-transport-sctp
+io.netty:netty-transport-udt
+org.apache.commons:commons-lang3
+org.roaringbitmap:RoaringBitmap
+
+
+BSD 3-clause
+------------
+See licenses/LICENSE-protobuf.txt for details.
+com.google.protobuf:protobuf-java
\ No newline at end of file
diff --git a/client-spark/spark-4-shaded/src/main/resources/META-INF/NOTICE b/client-spark/spark-4-shaded/src/main/resources/META-INF/NOTICE
new file mode 100644
index 00000000000..c48952d00d9
--- /dev/null
+++ b/client-spark/spark-4-shaded/src/main/resources/META-INF/NOTICE
@@ -0,0 +1,46 @@
+
+Apache Celeborn
+Copyright 2022-2024 The Apache Software Foundation.
+
+This product includes software developed at
+The Apache Software Foundation (https://www.apache.org/).
+
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation
+
+Apache Kyuubi
+Copyright 2021-2023 The Apache Software Foundation
+
+Apache Iceberg
+Copyright 2017-2022 The Apache Software Foundation
+
+Apache Parquet MR
+Copyright 2014-2024 The Apache Software Foundation
+
+This project includes code from Kite, developed at Cloudera, Inc. with
+the following copyright notice:
+
+| Copyright 2013 Cloudera Inc.
+|
+| Licensed under the Apache License, Version 2.0 (the "License");
+| you may not use this file except in compliance with the License.
+| You may obtain a copy of the License at
+|
+| http://www.apache.org/licenses/LICENSE-2.0
+|
+| Unless required by applicable law or agreed to in writing, software
+| distributed under the License is distributed on an "AS IS" BASIS,
+| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+| See the License for the specific language governing permissions and
+| limitations under the License.
+
+=============================================================================
+= NOTICE file corresponding to section 4d of the Apache License Version 2.0 =
+=============================================================================
+
+Apache Spark
+Copyright 2014 and onwards The Apache Software Foundation
+
+Apache Commons Lang
+Copyright 2001-2021 The Apache Software Foundation
+
diff --git a/client-spark/spark-4-shaded/src/main/resources/META-INF/licenses/LICENSE-protobuf.txt b/client-spark/spark-4-shaded/src/main/resources/META-INF/licenses/LICENSE-protobuf.txt
new file mode 100644
index 00000000000..b4350ec83c7
--- /dev/null
+++ b/client-spark/spark-4-shaded/src/main/resources/META-INF/licenses/LICENSE-protobuf.txt
@@ -0,0 +1,42 @@
+This license applies to all parts of Protocol Buffers except the following:
+
+ - Atomicops support for generic gcc, located in
+ src/google/protobuf/stubs/atomicops_internals_generic_gcc.h.
+ This file is copyrighted by Red Hat Inc.
+
+ - Atomicops support for AIX/POWER, located in
+ src/google/protobuf/stubs/atomicops_internals_aix.h.
+ This file is copyrighted by Bloomberg Finance LP.
+
+Copyright 2014, Google Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+Code generated by the Protocol Buffer compiler is owned by the owner
+of the input file used when generating it. This code is not
+standalone and requires a support library to be linked with it. This
+support library is itself covered by the above license.
\ No newline at end of file
diff --git a/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala b/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala
index 7391d0529ae..684aba222df 100644
--- a/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala
@@ -487,7 +487,7 @@ abstract class AbstractSource(conf: CelebornConf, role: String)
sum
}
- override def getMetrics(): String = {
+ override def getMetrics: String = {
var leftMetricsNum = metricsCapacity
val sb = new mutable.StringBuilder
leftMetricsNum = fillInnerMetricsSnapshot(getAndClearTimerMetrics(), leftMetricsNum, sb)
diff --git a/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala b/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala
index 1644e4e87b5..0f776f8df09 100644
--- a/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/metrics/source/CelebornSourceSuite.scala
@@ -33,7 +33,7 @@ class CelebornSourceSuite extends CelebornFunSuite {
for (i <- 1 to 100) {
mockSource.updateHistogram(histogram, 10)
}
- val res = mockSource.getMetrics()
+ val res = mockSource.getMetrics
assert(res.contains("metrics_abc_Count"))
}
@@ -70,7 +70,7 @@ class CelebornSourceSuite extends CelebornFunSuite {
mockSource.stopTimer("Timer1", "key1")
mockSource.stopTimer("Timer2", "key2", user3)
- val res = mockSource.getMetrics()
+ val res = mockSource.getMetrics
var extraLabelsStr = extraLabels
if (extraLabels.nonEmpty) {
extraLabelsStr = extraLabels + ","
diff --git a/dev/dependencies.sh b/dev/dependencies.sh
index 625b12edf2f..9dd8930ce4c 100755
--- a/dev/dependencies.sh
+++ b/dev/dependencies.sh
@@ -169,7 +169,7 @@ case "$MODULE" in
SBT_PROJECT="celeborn-client-spark-2"
;;
"spark-3"*) # Match all versions starting with "spark-3"
- MVN_MODULES="client-spark/spark-3"
+ MVN_MODULES="client-spark/spark-3-4"
SBT_PROJECT="celeborn-client-spark-3"
;;
"flink-1.14")
diff --git a/dev/reformat b/dev/reformat
index 17f834ee21e..85c4cde0869 100755
--- a/dev/reformat
+++ b/dev/reformat
@@ -32,6 +32,8 @@ else
${PROJECT_DIR}/build/mvn spotless:apply -Pflink-1.20
${PROJECT_DIR}/build/mvn spotless:apply -Pspark-2.4
${PROJECT_DIR}/build/mvn spotless:apply -Pspark-3.3
+ ${PROJECT_DIR}/build/mvn spotless:apply -Pspark-3.5
+ ${PROJECT_DIR}/build/mvn spotless:apply -Pspark-4.0
${PROJECT_DIR}/build/mvn spotless:apply -Paws
${PROJECT_DIR}/build/mvn spotless:apply -Pmr
${PROJECT_DIR}/build/mvn spotless:apply -Ptez
diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala
index d6d8c5f0fb2..fe3cc6eb791 100644
--- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala
+++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/http/api/v1/WorkerResource.scala
@@ -151,8 +151,8 @@ class WorkerResource extends ApiRequestContext {
new HandleResponse().success(success).message(finalMsg)
}
- private def toWorkerEventType(enum: EventTypeEnum): WorkerEventType = {
- enum match {
+ private def toWorkerEventType(enumObj: EventTypeEnum): WorkerEventType = {
+ enumObj match {
case EventTypeEnum.NONE => WorkerEventType.None
case EventTypeEnum.IMMEDIATELY => WorkerEventType.Immediately
case EventTypeEnum.DECOMMISSION => WorkerEventType.Decommission
diff --git a/pom.xml b/pom.xml
index bdbd1f0c441..b8f2cfb5510 100644
--- a/pom.xml
+++ b/pom.xml
@@ -991,7 +991,6 @@
${maven.plugin.scala.version}
- -Ywarn-unused-import
-unchecked
-deprecation
-feature
@@ -1373,7 +1372,7 @@
spark-3.0
client-spark/common
- client-spark/spark-3
+ client-spark/spark-3-4
client-spark/spark-3-columnar-common
client-spark/spark-3-columnar-shuffle
client-spark/spark-3-shaded
@@ -1393,7 +1392,7 @@
spark-3.1
client-spark/common
- client-spark/spark-3
+ client-spark/spark-3-4
client-spark/spark-3-columnar-common
client-spark/spark-3-columnar-shuffle
client-spark/spark-3-shaded
@@ -1413,7 +1412,7 @@
spark-3.2
client-spark/common
- client-spark/spark-3
+ client-spark/spark-3-4
client-spark/spark-3-columnar-common
client-spark/spark-3-columnar-shuffle
client-spark/spark-3-shaded
@@ -1432,7 +1431,7 @@
spark-3.3
client-spark/common
- client-spark/spark-3
+ client-spark/spark-3-4
client-spark/spark-3-columnar-common
client-spark/spark-3-columnar-shuffle
client-spark/spark-3-shaded
@@ -1451,7 +1450,7 @@
spark-3.4
client-spark/common
- client-spark/spark-3
+ client-spark/spark-3-4
client-spark/spark-3-columnar-common
client-spark/spark-3-columnar-shuffle
client-spark/spark-3-shaded
@@ -1470,7 +1469,7 @@
spark-3.5
client-spark/common
- client-spark/spark-3
+ client-spark/spark-3-4
client-spark/spark-3-columnar-common
client-spark/spark-3.5-columnar-shuffle
client-spark/spark-3-shaded
@@ -1485,6 +1484,25 @@
+
+ spark-4.0
+
+ client-spark/common
+ client-spark/spark-3-4
+ client-spark/spark-3-columnar-common
+ client-spark/spark-4-columnar-shuffle
+ client-spark/spark-4-shaded
+ tests/spark-it
+
+
+ 1.8.0
+ 2.13.11
+ 2.13
+ 4.0.0-preview2
+ 1.5.5-6
+
+
+
jdk-8
diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala
index b51e29d061d..45aa86c5c4b 100644
--- a/project/CelebornBuild.scala
+++ b/project/CelebornBuild.scala
@@ -226,7 +226,8 @@ object CelebornCommonSettings {
val SCALA_2_12_17 = "2.12.17"
val SCALA_2_12_18 = "2.12.18"
val scala213 = "2.13.5"
- val ALL_SCALA_VERSIONS = Seq(SCALA_2_11_12, SCALA_2_12_10, SCALA_2_12_15, SCALA_2_12_17, SCALA_2_12_18, scala213)
+ val scala213_11 = "2.13.11"
+ val ALL_SCALA_VERSIONS = Seq(SCALA_2_11_12, SCALA_2_12_10, SCALA_2_12_15, SCALA_2_12_17, SCALA_2_12_18, scala213, scala213_11)
val DEFAULT_SCALA_VERSION = SCALA_2_12_18
@@ -421,6 +422,7 @@ object Utils {
case Some("spark-3.3") => Some(Spark33)
case Some("spark-3.4") => Some(Spark34)
case Some("spark-3.5") => Some(Spark35)
+ case Some("spark-4.0") => Some(Spark40)
case _ => None
}
@@ -725,7 +727,7 @@ object Spark24 extends SparkClientProjects {
object Spark30 extends SparkClientProjects {
- val sparkClientProjectPath = "client-spark/spark-3"
+ val sparkClientProjectPath = "client-spark/spark-3-4"
val sparkClientProjectName = "celeborn-client-spark-3"
val sparkClientShadedProjectPath = "client-spark/spark-3-shaded"
val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded"
@@ -739,7 +741,7 @@ object Spark30 extends SparkClientProjects {
object Spark31 extends SparkClientProjects {
- val sparkClientProjectPath = "client-spark/spark-3"
+ val sparkClientProjectPath = "client-spark/spark-3-4"
val sparkClientProjectName = "celeborn-client-spark-3"
val sparkClientShadedProjectPath = "client-spark/spark-3-shaded"
val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded"
@@ -753,7 +755,7 @@ object Spark31 extends SparkClientProjects {
object Spark32 extends SparkClientProjects {
- val sparkClientProjectPath = "client-spark/spark-3"
+ val sparkClientProjectPath = "client-spark/spark-3-4"
val sparkClientProjectName = "celeborn-client-spark-3"
val sparkClientShadedProjectPath = "client-spark/spark-3-shaded"
val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded"
@@ -767,7 +769,7 @@ object Spark32 extends SparkClientProjects {
object Spark33 extends SparkClientProjects {
- val sparkClientProjectPath = "client-spark/spark-3"
+ val sparkClientProjectPath = "client-spark/spark-3-4"
val sparkClientProjectName = "celeborn-client-spark-3"
val sparkClientShadedProjectPath = "client-spark/spark-3-shaded"
val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded"
@@ -784,7 +786,7 @@ object Spark33 extends SparkClientProjects {
object Spark34 extends SparkClientProjects {
- val sparkClientProjectPath = "client-spark/spark-3"
+ val sparkClientProjectPath = "client-spark/spark-3-4"
val sparkClientProjectName = "celeborn-client-spark-3"
val sparkClientShadedProjectPath = "client-spark/spark-3-shaded"
val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded"
@@ -798,7 +800,7 @@ object Spark34 extends SparkClientProjects {
object Spark35 extends SparkClientProjects {
- val sparkClientProjectPath = "client-spark/spark-3"
+ val sparkClientProjectPath = "client-spark/spark-3-4"
val sparkClientProjectName = "celeborn-client-spark-3"
val sparkClientShadedProjectPath = "client-spark/spark-3-shaded"
val sparkClientShadedProjectName = "celeborn-client-spark-3-shaded"
@@ -812,6 +814,23 @@ object Spark35 extends SparkClientProjects {
override val sparkColumnarShuffleVersion: String = "3.5"
}
+object Spark40 extends SparkClientProjects {
+
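+  // Spark 4.0 reuses the shared client sources in client-spark/spark-3-4, built against
+  // Scala 2.13, and packages them through the dedicated spark-4-shaded module.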
+ val sparkClientProjectPath = "client-spark/spark-3-4"
+ val sparkClientProjectName = "celeborn-client-spark-4"
+ val sparkClientShadedProjectPath = "client-spark/spark-4-shaded"
+ val sparkClientShadedProjectName = "celeborn-client-spark-4-shaded"
+
+ val lz4JavaVersion = "1.8.0"
+ val sparkProjectScalaVersion = "2.13.11"
+
+ val sparkVersion = "4.0.0-preview2"
+ val zstdJniVersion = "1.5.5-6"
+ val scalaBinaryVersion = "2.13"
+
+ override val sparkColumnarShuffleVersion: String = "4"
+}
+
trait SparkClientProjects {
val sparkClientProjectPath: String
diff --git a/tests/spark-it/pom.xml b/tests/spark-it/pom.xml
index 842d7a8ed98..d66aed131ad 100644
--- a/tests/spark-it/pom.xml
+++ b/tests/spark-it/pom.xml
@@ -115,13 +115,13 @@
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test-jar
test
@@ -133,13 +133,13 @@
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test-jar
test
@@ -151,13 +151,13 @@
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test-jar
test
@@ -169,13 +169,13 @@
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test-jar
test
@@ -187,13 +187,13 @@
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test-jar
test
@@ -205,13 +205,31 @@
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test
org.apache.celeborn
- celeborn-client-spark-3_${scala.binary.version}
+ celeborn-client-spark-3-4_${scala.binary.version}
+ ${project.version}
+ test-jar
+ test
+
+
+
+
+ spark-4.0
+
+
+ org.apache.celeborn
+ celeborn-client-spark-3-4_${scala.binary.version}
+ ${project.version}
+ test
+
+
+ org.apache.celeborn
+ celeborn-client-spark-3-4_${scala.binary.version}
${project.version}
test-jar
test
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index f460e768e3d..30a12535b59 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -690,7 +690,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
}
}
- case None =>
+ case _ =>
if (replicaReason == StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED.getValue) {
pushMergedDataCallback.onSuccess(
StatusCode.PUSH_DATA_SUCCESS_REPLICA_CONGESTED)
@@ -791,7 +791,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
} else {
pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
}
- case None =>
+ case _ =>
pushMergedDataCallback.onSuccess(StatusCode.SUCCESS)
}
case Failure(e) => pushMergedDataCallback.onFailure(e)