support partial limit push down
gnehil committed Jan 10, 2025
1 parent 20b1228 commit 4d9c36d
Showing 12 changed files with 91 additions and 25 deletions.
Changed file 1 of 12
@@ -32,6 +32,7 @@ public class DorisReaderPartition implements Serializable {
private final String opaquedQueryPlan;
private final String[] readColumns;
private final String[] filters;
private final Integer limit;
private final DorisConfig config;

public DorisReaderPartition(String database, String table, Backend backend, Long[] tablets, String opaquedQueryPlan, String[] readColumns, String[] filters, DorisConfig config) {
@@ -42,6 +43,19 @@ public DorisReaderPartition(String database, String table, Backend backend, Long
this.opaquedQueryPlan = opaquedQueryPlan;
this.readColumns = readColumns;
this.filters = filters;
this.limit = -1;
this.config = config;
}

public DorisReaderPartition(String database, String table, Backend backend, Long[] tablets, String opaquedQueryPlan, String[] readColumns, String[] filters, Integer limit, DorisConfig config) {
this.database = database;
this.table = table;
this.backend = backend;
this.tablets = tablets;
this.opaquedQueryPlan = opaquedQueryPlan;
this.readColumns = readColumns;
this.filters = filters;
this.limit = limit;
this.config = config;
}

@@ -78,6 +92,10 @@ public String[] getFilters() {
return filters;
}

public Integer getLimit() {
return limit;
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
@@ -89,11 +107,12 @@ public boolean equals(Object o) {
&& Objects.equals(opaquedQueryPlan, that.opaquedQueryPlan)
&& Objects.deepEquals(readColumns, that.readColumns)
&& Objects.deepEquals(filters, that.filters)
&& Objects.equals(limit, that.limit)
&& Objects.equals(config, that.config);
}

@Override
public int hashCode() {
return Objects.hash(database, table, backend, Arrays.hashCode(tablets), opaquedQueryPlan, Arrays.hashCode(readColumns), Arrays.hashCode(filters), config);
return Objects.hash(database, table, backend, Arrays.hashCode(tablets), opaquedQueryPlan, Arrays.hashCode(readColumns), Arrays.hashCode(filters), limit, config);
}
}
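
A quick sketch of how the two constructors above differ (Scala calling the Java class; imports omitted, and the backend address, tablet id, plan, and config below are placeholders rather than values from this commit). The original 8-argument constructor now pins limit to -1, meaning "no cap", while the new 9-argument constructor carries the pushed-down limit:

// Placeholder inputs for illustration only.
val backend = new Backend("127.0.0.1:9060")
val plan    = "<opaqued query plan returned by the FE>"
val cfg: DorisConfig = ???  // an already-built connector config

// Old signature: limit is fixed to -1, i.e. read every matching row.
val unbounded = new DorisReaderPartition("db", "tbl", backend,
  Array(java.lang.Long.valueOf(10010L)), plan, Array("c1", "c2"), Array.empty[String], cfg)

// New signature: this partition returns at most 100 rows.
val capped = new DorisReaderPartition("db", "tbl", backend,
  Array(java.lang.Long.valueOf(10010L)), plan, Array("c1", "c2"), Array.empty[String],
  Integer.valueOf(100), cfg)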
Changed file 2 of 12
@@ -71,6 +71,8 @@ public abstract class AbstractThriftReader extends DorisReader {

private final Thread asyncThread;

private int readCount = 0;

protected AbstractThriftReader(DorisReaderPartition partition) throws Exception {
super(partition);
this.frontend = new DorisFrontendClient(config);
@@ -132,6 +134,9 @@ private void runAsync() throws DorisException, InterruptedException {

@Override
public boolean hasNext() throws DorisException {
if (partition.getLimit() > 0 && readCount >= partition.getLimit()) {
return false;
}
boolean hasNext = false;
if (isAsync && asyncThread != null && asyncThread.isAlive()) {
if (rowBatch == null || !rowBatch.hasNext()) {
@@ -186,6 +191,9 @@ public Object next() throws DorisException {
if (!hasNext()) {
throw new RuntimeException("No more elements");
}
if (partition.getLimit() > 0) {
readCount++;
}
return rowBatch.next().toArray();
}

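The two small additions above are what enforce the limit on the reader side: hasNext() reports exhaustion once readCount reaches the partition's limit, and next() only counts rows when a limit is actually set. A minimal Scala sketch of the same pattern, detached from the Thrift plumbing (the class and names below are illustrative, not part of the connector):

class BoundedRows(rows: Iterator[Array[AnyRef]], limit: Int) {
  private var readCount = 0
  // A non-positive limit (the -1 default) disables the cap entirely.
  def hasNext: Boolean = (limit <= 0 || readCount < limit) && rows.hasNext
  def next(): Array[AnyRef] = {
    if (!hasNext) throw new NoSuchElementException("No more elements")
    if (limit > 0) readCount += 1
    rows.next()
  }
}
// e.g. draining new BoundedRows(Iterator.fill(5)(Array[AnyRef]("x")), limit = 3) yields 3 rows.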
Changed file 3 of 12
@@ -149,7 +149,8 @@ protected String generateQuerySql(DorisReaderPartition partition) throws OptionR
String fullTableName = config.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER);
String tablets = String.format("TABLET(%s)", StringUtils.join(partition.getTablets(), ","));
String predicates = partition.getFilters().length == 0 ? "" : " WHERE " + String.join(" AND ", partition.getFilters());
return String.format("SELECT %s FROM %s %s%s", columns, fullTableName, tablets, predicates);
String limit = partition.getLimit() > 0 ? " LIMIT " + partition.getLimit() : "";
return String.format("SELECT %s FROM %s %s%s%s", columns, fullTableName, tablets, predicates, limit);
}

protected Schema processDorisSchema(DorisReaderPartition partition, final Schema originSchema) throws Exception {
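With the extra %s above, a pushed-down limit lands directly in the per-partition query this reader sends to Doris. A rough Scala sketch of the resulting SQL shape (the table name, tablet ids, and filter are made-up examples):

val limitValue = 100  // pushed-down limit; -1 would mean "no LIMIT clause"
val columns    = "c1,c2"
val fullTable  = "`db`.`tbl`"
val tablets    = "TABLET(10010,10012)"
val predicates = " WHERE c1 > 10"
val limit      = if (limitValue > 0) " LIMIT " + limitValue else ""
val sql        = s"SELECT $columns FROM $fullTable $tablets$predicates$limit"
// sql == "SELECT c1,c2 FROM `db`.`tbl` TABLET(10010,10012) WHERE c1 > 10 LIMIT 100"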
Changed file 4 of 12
@@ -27,6 +27,8 @@
import org.apache.doris.spark.rest.models.QueryPlan;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.util.DorisDialects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
@@ -38,6 +40,8 @@

public class ReaderPartitionGenerator {

private static final Logger LOG = LoggerFactory.getLogger(ReaderPartitionGenerator.class);

/*
* for spark 2
*/
@@ -51,14 +55,14 @@ public static DorisReaderPartition[] generatePartitions(DorisConfig config) thro
}
String[] filters = config.contains(DorisOptions.DORIS_FILTER_QUERY) ?
config.getValue(DorisOptions.DORIS_FILTER_QUERY).split("\\.") : new String[0];
return generatePartitions(config, originReadCols, filters);
return generatePartitions(config, originReadCols, filters, -1);
}

/*
* for spark 3
*/
public static DorisReaderPartition[] generatePartitions(DorisConfig config,
String[] fields, String[] filters) throws Exception {
String[] fields, String[] filters, Integer limit) throws Exception {
DorisFrontendClient frontend = new DorisFrontendClient(config);
String fullTableName = config.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER);
String[] tableParts = fullTableName.split("\\.");
@@ -69,13 +73,15 @@ public static DorisReaderPartition[] generatePartitions(DorisConfig config,
originReadCols = frontend.getTableAllColumns(db, table);
}
String[] finalReadColumns = getFinalReadColumns(config, frontend, db, table, originReadCols);
String sql = "SELECT " + String.join(",", finalReadColumns) + " FROM `" + db + "`.`" + table + "`" +
(filters.length == 0 ? "" : " WHERE " + String.join(" AND ", filters));
String finalReadColumnString = String.join(",", finalReadColumns);
String finalWhereClauseString = filters.length == 0 ? "" : " WHERE " + String.join(" AND ", filters);
String sql = "SELECT " + finalReadColumnString + " FROM `" + db + "`.`" + table + "`" + finalWhereClauseString;
LOG.info("get query plan for table " + db + "." + table + ", sql: " + sql);
QueryPlan queryPlan = frontend.getQueryPlan(db, table, sql);
Map<String, List<Long>> beToTablets = mappingBeToTablets(queryPlan);
int maxTabletSize = config.getValue(DorisOptions.DORIS_TABLET_SIZE);
return distributeTabletsToPartitions(db, table, beToTablets, queryPlan.getOpaqued_query_plan(), maxTabletSize,
finalReadColumns, filters, config);
finalReadColumns, filters, config, limit);
}

@VisibleForTesting
@@ -106,7 +112,7 @@ private static DorisReaderPartition[] distributeTabletsToPartitions(String datab
Map<String, List<Long>> beToTablets,
String opaquedQueryPlan, int maxTabletSize,
String[] readColumns, String[] predicates,
DorisConfig config) {
DorisConfig config, Integer limit) {
List<DorisReaderPartition> partitions = new ArrayList<>();
beToTablets.forEach((backendStr, tabletIds) -> {
List<Long> distinctTablets = new ArrayList<>(new HashSet<>(tabletIds));
@@ -115,7 +121,7 @@ private static DorisReaderPartition[] distributeTabletsToPartitions(String datab
Long[] tablets = distinctTablets.subList(offset, Math.min(offset + maxTabletSize, distinctTablets.size())).toArray(new Long[0]);
offset += maxTabletSize;
partitions.add(new DorisReaderPartition(database, table, new Backend(backendStr), tablets,
opaquedQueryPlan, readColumns, predicates, config));
opaquedQueryPlan, readColumns, predicates, limit, config));
}
});
return partitions.toArray(new DorisReaderPartition[0]);
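Note that the same limit value is attached to every partition built above, which is what makes this a partial push down: each backend stops after limit rows for its own tablets, and Spark still applies the global LIMIT on top of the union. A back-of-the-envelope sketch with assumed numbers:

val limit      = 100 // pushed-down limit carried by every partition
val partitions = 3   // partitions produced for this scan
val worstCaseRowsFromDoris = limit * partitions // up to 300 rows may still be read
// Spark's own LIMIT operator above the scan trims the final result back to 100.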
Changed file 5 of 12
@@ -35,7 +35,7 @@ abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extend
override def toBatch: Batch = this

override def planInputPartitions(): Array[InputPartition] = {
ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters()).map(toInputPartition)
ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters(), getLimit).map(toInputPartition)
}


@@ -44,10 +44,12 @@ abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extend
}

private def toInputPartition(rp: DorisReaderPartition): DorisInputPartition =
DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters)
DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters, rp.getLimit)

protected def compiledFilters(): Array[String]

protected def getLimit: Int = -1

}

case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String]) extends InputPartition
case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String], limit: Int = -1) extends InputPartition
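
The defaults above keep older code paths unchanged: scans that do not override getLimit report -1, and DorisInputPartition's new limit parameter also defaults to -1, so existing construction sites stay source-compatible. A small sketch (imports omitted; backend and plan are placeholders):

val backend = new Backend("127.0.0.1:9060")
val plan    = "<opaqued query plan returned by the FE>"

// No limit pushed: the trailing parameter defaults to -1.
val p1 = DorisInputPartition("db", "tbl", backend, Array(10010L), plan, Array("c1"), Array.empty[String])

// Limit pushed by the scan: the extra field carries it to the reader.
val p2 = DorisInputPartition("db", "tbl", backend, Array(10010L), plan, Array("c1"), Array.empty[String], 100)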
Changed file 6 of 12
@@ -34,7 +34,7 @@ class DorisPartitionReader(inputPartition: InputPartition, schema: StructType, m
private implicit def toReaderPartition(inputPart: DorisInputPartition): DorisReaderPartition = {
val tablets = inputPart.tablets.map(java.lang.Long.valueOf)
new DorisReaderPartition(inputPart.database, inputPart.table, inputPart.backend, tablets,
inputPart.opaquedQueryPlan, inputPart.readCols, inputPart.predicates, config)
inputPart.opaquedQueryPlan, inputPart.readCols, inputPart.predicates, inputPart.limit, config)
}

private lazy val reader: DorisReader = {
Changed file 7 of 12
@@ -20,17 +20,20 @@ package org.apache.doris.spark.read
import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.read.expression.V2ExpressionBuilder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
import org.apache.spark.sql.types.StructType

class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
with SupportsPushDownV2Filters {
with SupportsPushDownV2Filters
with SupportsPushDownLimit {

private var pushDownPredicates: Array[Predicate] = Array[Predicate]()

private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
private var limitSize: Int = -1

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
@@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca

override def pushedPredicates(): Array[Predicate] = pushDownPredicates

override def pushLimit(i: Int): Boolean = {
limitSize = i
true
}

}
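
For reference, a hedged end-to-end sketch of how the push down is triggered: on Spark 3 releases that ship SupportsPushDownLimit, a plain .limit(n) on a Doris DataFrame lets Spark call pushLimit(n) on this builder before build(), so the value flows into DorisScanV2, into every DorisInputPartition, and finally into the per-partition readers and SQL shown earlier. The format name and option keys below follow the connector's usual configuration and are assumptions, not part of this diff.

import org.apache.spark.sql.SparkSession

object LimitPushDownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("doris-limit-pushdown").getOrCreate()

    // Assumed option keys; check the connector docs for the exact names in your version.
    val df = spark.read
      .format("doris")
      .option("doris.table.identifier", "db.tbl")
      .option("doris.fenodes", "fe_host:8030")
      .option("user", "root")
      .option("password", "")
      .load()

    // Spark plans the limit, calls pushLimit(10) on the scan builder, and each
    // partition reads at most 10 rows; Spark still applies the final LIMIT 10.
    df.limit(10).show()

    spark.stop()
  }
}

The same builder and scan changes are repeated below for the other copies of DorisScanBuilder and DorisScanV2 in the repository.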
Changed file 8 of 12
@@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.StructType

class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
override protected def compiledFilters(): Array[String] = {
val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
}

override protected def getLimit: Int = limit
}
Changed file 9 of 12
@@ -20,17 +20,20 @@ package org.apache.doris.spark.read
import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.read.expression.V2ExpressionBuilder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
import org.apache.spark.sql.types.StructType

class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
with SupportsPushDownV2Filters {
with SupportsPushDownV2Filters
with SupportsPushDownLimit {

private var pushDownPredicates: Array[Predicate] = Array[Predicate]()

private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
private var limitSize: Int = -1

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
@@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca

override def pushedPredicates(): Array[Predicate] = pushDownPredicates

override def pushLimit(i: Int): Boolean = {
limitSize = i
true
}

}
Changed file 10 of 12
@@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.StructType

class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
override protected def compiledFilters(): Array[String] = {
val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
}

override protected def getLimit: Int = limit
}
Changed file 11 of 12
@@ -20,17 +20,20 @@ package org.apache.doris.spark.read
import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.read.expression.V2ExpressionBuilder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
import org.apache.spark.sql.types.StructType

class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
with SupportsPushDownV2Filters {
with SupportsPushDownV2Filters
with SupportsPushDownLimit {

private var pushDownPredicates: Array[Predicate] = Array[Predicate]()

private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
private var limitSize: Int = -1

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
@@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca

override def pushedPredicates(): Array[Predicate] = pushDownPredicates

override def pushLimit(i: Int): Boolean = {
limitSize = i
true
}

}
Changed file 12 of 12
@@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.StructType

class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
override protected def compiledFilters(): Array[String] = {
val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
}

override protected def getLimit: Int = limit
}
