Skip to content

Commit

Permalink
refactor package
Browse files Browse the repository at this point in the history
  • Loading branch information
gnehil committed Dec 31, 2024
1 parent 7241494 commit fb28319
Show file tree
Hide file tree
Showing 14 changed files with 585 additions and 5 deletions.
8 changes: 4 additions & 4 deletions spark-doris-connector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@

<properties>
<revision>24.0.0-SNAPSHOT</revision>
<!-- <spark.version>2.4.8</spark.version> -->
<!-- <spark.major.version>2.4</spark.major.version> -->
<!-- <scala.version>2.11.12</scala.version> -->
<!-- <scala.major.version>2.11</scala.major.version> -->
<spark.version>2.4.8</spark.version>
<spark.major.version>2.4</spark.major.version>
<scala.version>2.11.12</scala.version>
<scala.major.version>2.11</scala.major.version>
<libthrift.version>0.16.0</libthrift.version>
<arrow.version>15.0.2</arrow.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand Down
14 changes: 14 additions & 0 deletions spark-doris-connector/spark-doris-connector-it/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.major.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.major.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
Expand All @@ -81,6 +83,12 @@
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<properties>
<spark.version>2.4.8</spark.version>
<spark.major.version>2.4</spark.major.version>
<scala.version>2.12.18</scala.version>
<scala.major.version>2.12</scala.major.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.doris</groupId>
Expand All @@ -92,6 +100,12 @@
</profile>
<profile>
<id>spark-3-it</id>
<properties>
<spark.version>3.1.0</spark.version>
<spark.major.version>3.1</spark.major.version>
<scala.version>2.12.18</scala.version>
<scala.major.version>2.12</scala.major.version>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.doris</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.util.Retry
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.types.StructType

import java.time.Duration
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.write

import org.apache.commons.lang3.StringUtils
import org.apache.doris.spark.client.write.{CopyIntoProcessor, DorisCommitter, DorisWriter, StreamLoadProcessor}
import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.util.Retry
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.DataWriter
import org.apache.spark.sql.types.StructType

import java.time.Duration
import scala.collection.mutable
import scala.util.{Failure, Success}

/**
 * Spark [[DataWriter]] that sends the rows of a single write task to Doris.
 *
 * The load implementation is chosen from `DorisOptions.LOAD_MODE`
 * (`stream_load` or `copy_into`). Rows are passed to the underlying
 * [[DorisWriter]] one at a time; once `DorisOptions.DORIS_SINK_BATCH_SIZE`
 * rows have been loaded the current load is stopped and a new one is started
 * by the next row. When retries are enabled, the rows of the current batch are
 * kept in memory so that they can be replayed after a failed load.
 *
 * @param config      connector configuration
 * @param schema      schema of the rows being written
 * @param partitionId partition id, reported back in the commit message
 * @param taskId      task id, reported back in the commit message
 * @param epochId     streaming epoch id, reported back in the commit message;
 *                    defaults to -1 when the caller does not supply one
 */
class DorisDataWriter(config: DorisConfig, schema: StructType, partitionId: Int, taskId: Long, epochId: Long = -1) extends DataWriter[InternalRow] with Logging {

  private val (writer: DorisWriter[InternalRow], committer: DorisCommitter) =
    config.getValue(DorisOptions.LOAD_MODE) match {
      case "stream_load" => (new StreamLoadProcessor(config, schema), new StreamLoadProcessor(config, schema))
      case "copy_into" => (new CopyIntoProcessor(config, schema), new CopyIntoProcessor(config, schema))
      case mode => throw new IllegalArgumentException("Unsupported load mode: " + mode)
    }

  private val batchSize = config.getValue(DorisOptions.DORIS_SINK_BATCH_SIZE)

  private val batchIntervalMs = config.getValue(DorisOptions.DORIS_SINK_BATCH_INTERVAL_MS)

  private val retries = config.getValue(DorisOptions.DORIS_SINK_MAX_RETRIES)

  private val twoPhaseCommitEnabled = config.getValue(DorisOptions.DORIS_SINK_ENABLE_2PC)

  // Number of rows handed to the writer for the batch currently in flight.
  private var currentBatchCount = 0

  // Transaction ids of the batches finished so far; shipped in the commit message.
  private val committedMessages = mutable.Buffer[String]()

  // Rows of the in-flight batch, kept only when retries are enabled so they can be replayed.
  private lazy val recordBuffer = mutable.Buffer[InternalRow]()

  /**
   * Loads one row, first finishing the current batch if it has reached the
   * configured batch size.
   *
   * @param record row to write
   */
  override def write(record: InternalRow): Unit = {
    if (currentBatchCount >= batchSize) {
      finishBatch()
      currentBatchCount = 0
      if (retries != 0) {
        recordBuffer.clear()
      }
    }
    loadWithRetries(record)
  }

  /**
   * Finishes the last batch and returns every transaction id collected by this
   * writer wrapped in a [[DorisWriterCommitMessage]].
   */
  override def commit(): WriterCommitMessage = {
    finishBatch()
    DorisWriterCommitMessage(partitionId, taskId, epochId, committedMessages.toArray)
  }

  /**
   * Aborts every transaction recorded so far and releases the writer. The
   * writer is closed even when one of the aborts fails.
   */
  override def abort(): Unit = {
    try {
      committedMessages.foreach(msg => committer.abort(msg))
    } finally {
      close()
    }
  }

  /** Closes the underlying writer. */
  override def close(): Unit = {
    if (writer != null) {
      writer.close()
    }
  }

  /**
   * Stops the in-flight load and, when two phase commit is enabled, records the
   * transaction id it returned.
   *
   * The id is only recorded (and only required) when two phase commit is
   * enabled, so that batches finished from `write` and from `commit` follow the
   * same rule and a null id is never put into the commit message.
   *
   * @throws Exception if two phase commit is enabled and no transaction id was returned
   */
  @throws[Exception]
  private def finishBatch(): Unit = {
    val txnId = writer.stop()
    if (twoPhaseCommitEnabled) {
      if (StringUtils.isBlank(txnId)) {
        throw new Exception("Doris writer returned no transaction id although two phase commit is enabled")
      }
      committedMessages += txnId
    }
  }

  /**
   * Loads a row through `Retry.exec`. After a failed attempt the rows already
   * buffered for the current batch are replayed before the row is loaded again.
   *
   * @param record row to load
   * @throws Exception wrapping the last failure when `Retry.exec` reports a failure
   */
  @throws[Exception]
  private def loadWithRetries(record: InternalRow): Unit = {
    var isRetrying = false
    Retry.exec[Unit, Exception](retries, Duration.ofMillis(batchIntervalMs.toLong), log) {
      if (isRetrying) {
        // A plain while loop: the buffer is empty when the very first row of a
        // batch fails, and there is nothing to replay in that case.
        while (currentBatchCount < recordBuffer.size) {
          writer.load(recordBuffer(currentBatchCount))
          currentBatchCount += 1
        }
        isRetrying = false
      }
      writer.load(record)
      currentBatchCount += 1
    } {
      isRetrying = true
      currentBatchCount = 0
    } match {
      // Buffer a copy: the caller may reuse the same InternalRow instance for
      // the next record, which would otherwise overwrite the buffered rows.
      case Success(_) => if (retries > 0) recordBuffer += record.copy()
      case Failure(exception) => throw new Exception("Failed to load record into Doris", exception)
    }

  }

}

/**
 * Commit message produced by a Doris data writer.
 *
 * @param partitionId    id of the partition the writer handled
 * @param taskId         id of the task the writer ran in
 * @param epochId        epoch id the writer was created with
 * @param commitMessages values collected by the writer while writing
 */
case class DorisWriterCommitMessage(
    partitionId: Int,
    taskId: Long,
    epochId: Long,
    commitMessages: Array[String]
) extends WriterCommitMessage
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.write

import org.apache.doris.spark.config.DorisConfig
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory}
import org.apache.spark.sql.types.StructType

/**
 * Factory handing out [[DorisDataWriter]] instances; it serves as both the
 * batch and the streaming writer factory.
 *
 * @param config connector configuration passed to every writer
 * @param schema schema passed to every writer
 */
class DorisDataWriterFactory(config: DorisConfig, schema: StructType) extends DataWriterFactory with StreamingDataWriterFactory {

  /** Batch variant: no epoch id is supplied, so the writer uses its default. */
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] =
    new DorisDataWriter(config, schema, partitionId, taskId)

  /** Streaming variant: the epoch id is forwarded to the writer. */
  override def createWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
    new DorisDataWriter(config, schema, partitionId, taskId, epochId)

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.write

import org.apache.doris.spark.client.write.{CopyIntoProcessor, DorisCommitter, StreamLoadProcessor}
import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory

/**
 * Batch and streaming write implementation for Doris.
 *
 * It creates the writer factories and, once the writers have reported their
 * commit messages, commits or aborts every transaction id carried by those
 * messages through the committer selected by `DorisOptions.LOAD_MODE`.
 *
 * @param config connector configuration
 * @param schema schema of the rows being written
 */
class DorisWrite(config: DorisConfig, schema: StructType) extends BatchWrite with StreamingWrite {

  private val LOG = LoggerFactory.getLogger(classOf[DorisWrite])

  private val committer: DorisCommitter = config.getValue(DorisOptions.LOAD_MODE) match {
    case "stream_load" => new StreamLoadProcessor(config, schema)
    case "copy_into" => new CopyIntoProcessor(config, schema)
    case mode => throw new IllegalArgumentException("Unsupported load mode: " + mode)
  }

  // Highest epoch whose commit has been processed; guarded by committedEpochLock.
  private var lastCommittedEpoch: Option[Long] = None

  private val committedEpochLock = new AnyRef

  override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = {
    new DorisDataWriterFactory(config, schema)
  }

  /** Batch write: commits every transaction id found in the given messages. */
  override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
    transactionIds(writerCommitMessages).foreach(committer.commit)
  }

  /** Batch write: aborts every transaction id found in the given messages. */
  override def abort(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
    val messageCount = Option(writerCommitMessages).map(_.length).getOrElse(0)
    LOG.info("writerCommitMessages size: " + messageCount)
    transactionIds(writerCommitMessages).foreach(committer.abort)
  }

  override def useCommitCoordinator(): Boolean = true

  override def createStreamingWriterFactory(physicalWriteInfo: PhysicalWriteInfo): StreamingDataWriterFactory = {
    new DorisDataWriterFactory(config, schema)
  }

  /**
   * Streaming write: commits the transactions of an epoch unless that epoch, or
   * a later one, has already been committed by this instance.
   */
  override def commit(epochId: Long, writerCommitMessages: Array[WriterCommitMessage]): Unit = {
    committedEpochLock.synchronized {
      if (isUncommittedEpoch(epochId)) {
        transactionIds(writerCommitMessages).foreach(committer.commit)
        lastCommittedEpoch = Some(epochId)
      }
    }
  }

  /**
   * Streaming write: aborts the transactions of an epoch unless that epoch, or
   * a later one, has already been committed by this instance.
   */
  override def abort(epochId: Long, writerCommitMessages: Array[WriterCommitMessage]): Unit = {
    committedEpochLock.synchronized {
      if (isUncommittedEpoch(epochId)) {
        transactionIds(writerCommitMessages).foreach(committer.abort)
      }
    }
  }

  /** True when no epoch has been committed yet or `epochId` is newer than the last committed one. */
  private def isUncommittedEpoch(epochId: Long): Boolean =
    lastCommittedEpoch.forall(epochId > _)

  /**
   * Collects the transaction ids of all commit messages. A null array and null
   * entries (writers that did not report a message) contribute nothing.
   *
   * @throws IllegalArgumentException if a message is not a [[DorisWriterCommitMessage]]
   */
  private def transactionIds(writerCommitMessages: Array[WriterCommitMessage]): Seq[String] = {
    val messages = Option(writerCommitMessages).map(_.toSeq).getOrElse(Seq.empty[WriterCommitMessage])
    messages.flatMap {
      case message: DorisWriterCommitMessage => message.commitMessages.toSeq
      case null => Seq.empty[String]
      case other =>
        throw new IllegalArgumentException("Unexpected commit message type: " + other.getClass.getName)
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.write

import org.apache.doris.spark.config.DorisConfig
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}
import org.apache.spark.sql.types.StructType

/**
 * [[WriteBuilder]] for Doris; both the batch and the streaming write are
 * backed by a new [[DorisWrite]] built from the same config and schema.
 *
 * @param config connector configuration
 * @param schema schema of the rows being written
 */
class DorisWriteBuilder(config: DorisConfig, schema: StructType) extends WriteBuilder {

  /** Returns a fresh [[DorisWrite]] to be used as the batch write. */
  override def buildForBatch(): BatchWrite = new DorisWrite(config, schema)

  /** Returns a fresh [[DorisWrite]] to be used as the streaming write. */
  override def buildForStreaming(): StreamingWrite = new DorisWrite(config, schema)

}
Loading

0 comments on commit fb28319

Please sign in to comment.