From c00b6a350305fd5396d91c989cf90a0190221746 Mon Sep 17 00:00:00 2001 From: Ralph Gasser Date: Fri, 1 Dec 2023 08:37:12 +0100 Subject: [PATCH] Adds early check to prevent NULL insert in NON-NULL column and added a test case that tests for it. --- .../dbms/queries/binding/GrpcQueryBinder.kt | 34 ++++++++++++------- .../cottontail/server/grpc/DMLServiceTest.kt | 27 +++++++++++++++ .../vitrivr/cottontail/test/GrpcTestUtils.kt | 14 ++++---- 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/queries/binding/GrpcQueryBinder.kt b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/queries/binding/GrpcQueryBinder.kt index c8771d365..d31b71fa8 100644 --- a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/queries/binding/GrpcQueryBinder.kt +++ b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/queries/binding/GrpcQueryBinder.kt @@ -139,8 +139,12 @@ object GrpcQueryBinder { } /* Create and return INSERT-clause. */ - val record = TupleBinding(-1L, columns, values, this@QueryContext.bindings) - return InsertLogicalOperatorNode(this@QueryContext.nextGroupId(), entityTx, mutableListOf(record)) + try { + val record = TupleBinding(-1L, columns, values, this@QueryContext.bindings) + return InsertLogicalOperatorNode(this@QueryContext.nextGroupId(), entityTx, mutableListOf(record)) + } catch (e: IllegalArgumentException) { + throw DatabaseException.ValidationException("Provided data could not be bound to INSERT due to validation error: ${e.message}") + } } /** @@ -163,17 +167,21 @@ object GrpcQueryBinder { } /* Parse records to BATCH INSERT. */ - val tuples: MutableList = insert.insertsList.map { ins -> - TupleBinding(-1L, columns, Array(ins.valuesCount) { i -> - val literal = ins.valuesList[i].toValue() - if (literal == null) { - this@QueryContext.bindings.bindNull(columns[i].type) - } else { - this@QueryContext.bindings.bind(literal) - } - }, this@QueryContext.bindings) - }.toMutableList() - return InsertLogicalOperatorNode(this@QueryContext.nextGroupId(), entityTx, tuples) + try { + val tuples: MutableList = insert.insertsList.map { ins -> + TupleBinding(-1L, columns, Array(ins.valuesCount) { i -> + val literal = ins.valuesList[i].toValue() + if (literal == null) { + this@QueryContext.bindings.bindNull(columns[i].type) + } else { + this@QueryContext.bindings.bind(literal) + } + }, this@QueryContext.bindings) + }.toMutableList() + return InsertLogicalOperatorNode(this@QueryContext.nextGroupId(), entityTx, tuples) + } catch (e: IllegalArgumentException) { + throw DatabaseException.ValidationException("Provided data could not be bound to BATCH INSERT due to validation error: ${e.message}") + } } /** diff --git a/cottontaildb-dbms/src/test/kotlin/org/vitrivr/cottontail/server/grpc/DMLServiceTest.kt b/cottontaildb-dbms/src/test/kotlin/org/vitrivr/cottontail/server/grpc/DMLServiceTest.kt index 6bc8ee3cb..94cb2adb2 100644 --- a/cottontaildb-dbms/src/test/kotlin/org/vitrivr/cottontail/server/grpc/DMLServiceTest.kt +++ b/cottontaildb-dbms/src/test/kotlin/org/vitrivr/cottontail/server/grpc/DMLServiceTest.kt @@ -5,6 +5,7 @@ import io.grpc.StatusRuntimeException import org.apache.commons.lang3.RandomStringUtils import org.apache.commons.math3.random.JDKRandomGenerator import org.junit.jupiter.api.* +import org.junit.platform.commons.function.Try.success import org.vitrivr.cottontail.client.language.basics.Direction import org.vitrivr.cottontail.client.language.basics.predicate.Compare import org.vitrivr.cottontail.client.language.ddl.CreateEntity @@ -316,4 +317,30 @@ class DMLServiceTest : AbstractClientTest() { this.client.commit(txId) Assertions.assertEquals(0, countElements(this.client, entityName)!!.toInt()) } + + /** + * Performs a large number of INSERTs in a single transaction and checks the count before and afterwards. + */ + @Test + fun testInvalidNullInsert() { + val entries = TEST_COLLECTION_SIZE * 5 + val batchSize = 1000 + val stringLength = 200 + + /* Start large insert. */ + val txId = this.client.begin() + try { + repeat(entries / batchSize) { i -> + val batch = BatchInsert(TEST_ENTITY_NAME).columns(INT_COLUMN_NAME, STRING_COLUMN_NAME, DOUBLE_COLUMN_NAME) + repeat(batchSize) { j -> + batch.any(i * j, null, 1.0) + } + this.client.insert(batch.txId(txId)) + } + } catch (e: StatusRuntimeException) { + success("Creating entity ${TEST_ENTITY_NAME.fqn} failed with exception " + e.message) + } finally { + this.client.rollback(txId) + } + } } \ No newline at end of file diff --git a/cottontaildb-dbms/src/testFixtures/kotlin/org/vitrivr/cottontail/test/GrpcTestUtils.kt b/cottontaildb-dbms/src/testFixtures/kotlin/org/vitrivr/cottontail/test/GrpcTestUtils.kt index 9f59e37b7..4c11de196 100644 --- a/cottontaildb-dbms/src/testFixtures/kotlin/org/vitrivr/cottontail/test/GrpcTestUtils.kt +++ b/cottontaildb-dbms/src/testFixtures/kotlin/org/vitrivr/cottontail/test/GrpcTestUtils.kt @@ -54,10 +54,10 @@ object GrpcTestUtils { */ fun createTestEntity(client: SimpleClient) { val create = CreateEntity(TestConstants.TEST_ENTITY_NAME) - .column(Name.ColumnName.parse(ID_COLUMN_NAME), Types.Long, autoIncrement = true) - .column(Name.ColumnName.parse(STRING_COLUMN_NAME), Types.String) - .column(Name.ColumnName.parse(INT_COLUMN_NAME), Types.Int) - .column(Name.ColumnName.parse(DOUBLE_COLUMN_NAME), Types.Double) + .column(Name.ColumnName.parse(ID_COLUMN_NAME), Types.Long, nullable = false, autoIncrement = true) + .column(Name.ColumnName.parse(STRING_COLUMN_NAME), Types.String, nullable = false) + .column(Name.ColumnName.parse(INT_COLUMN_NAME), Types.Int, nullable = true) + .column(Name.ColumnName.parse(DOUBLE_COLUMN_NAME), Types.Double, nullable = true) client.create(create) } @@ -69,9 +69,9 @@ object GrpcTestUtils { fun createTestVectorEntity(client: SimpleClient) { val create = CreateEntity(TestConstants.TEST_VECTOR_ENTITY_NAME) .column(Name.ColumnName.parse(ID_COLUMN_NAME), Types.Long, autoIncrement = true) - .column(Name.ColumnName.parse(STRING_COLUMN_NAME), Types.String) - .column(Name.ColumnName.parse(INT_COLUMN_NAME), Types.Int) - .column(Name.ColumnName.parse(TWOD_COLUMN_NAME), Types.FloatVector(2)) + .column(Name.ColumnName.parse(STRING_COLUMN_NAME), Types.String, nullable = false) + .column(Name.ColumnName.parse(INT_COLUMN_NAME), Types.Int, nullable = true) + .column(Name.ColumnName.parse(TWOD_COLUMN_NAME), Types.FloatVector(2), nullable = true) client.create(create) }