diff --git a/pkg/spanner/client.go b/pkg/spanner/client.go index 90a68d9..31e0609 100644 --- a/pkg/spanner/client.go +++ b/pkg/spanner/client.go @@ -311,6 +311,13 @@ func (c *Client) ExecuteMigrations(ctx context.Context, migrations Migrations, l } } case statementKindDML: + if _, err := c.ApplyDML(ctx, m.Statements); err != nil { + return &Error{ + Code: ErrorCodeExecuteMigrations, + err: err, + } + } + case statementKindPartitionedDML: if _, err := c.ApplyPartitionedDML(ctx, m.Statements); err != nil { return &Error{ Code: ErrorCodeExecuteMigrations, diff --git a/pkg/spanner/migration.go b/pkg/spanner/migration.go index c47cc95..daa709c 100644 --- a/pkg/spanner/migration.go +++ b/pkg/spanner/migration.go @@ -42,12 +42,19 @@ var ( MigrationNameRegex = regexp.MustCompile(`[a-zA-Z0-9_\-]+`) - dmlRegex = regexp.MustCompile("^(UPDATE|DELETE)[\t\n\f\r ].*") + dmlAnyRegex = regexp.MustCompile("(?i)^(UPDATE|DELETE|INSERT)[\t\n\f\r ].*") + + // Instead of a regex to determine if a DML statement is a PartitionedDML we check if it's + // not a PartitionedDML by checking for INSERT statements. + // An improvement would be to use spanners algo to distinguish between DML types. + notPartitionedDmlRegex = regexp.MustCompile(`(?i)^INSERT`) + ) const ( statementKindDDL statementKind = "DDL" statementKindDML statementKind = "DML" + statementKindPartitionedDML statementKind = "PartitionedDML" ) type ( @@ -143,27 +150,55 @@ func inspectStatementsKind(statements []string) (statementKind, error) { kindMap := map[statementKind]uint64{ statementKindDDL: 0, statementKindDML: 0, + statementKindPartitionedDML: 0, } for _, s := range statements { - if isDML(s) { - kindMap[statementKindDML]++ - } else { - kindMap[statementKindDDL]++ - } + kindMap[getStatementKind(s)]++ } - if kindMap[statementKindDML] > 0 { - if kindMap[statementKindDDL] > 0 { - return "", errors.New("Cannot specify DDL and DML at same migration file.") - } + if distinctKind(kindMap, statementKindDDL) { + return statementKindDDL, nil + } + if distinctKind(kindMap, statementKindDML) { return statementKindDML, nil } - return statementKindDDL, nil + if distinctKind(kindMap, statementKindPartitionedDML) { + return statementKindPartitionedDML, nil + } + + return "", errors.New("Cannot specify DDL and DML in the same migration file") +} + +func distinctKind(kindMap map[statementKind]uint64, kind statementKind) bool { + target := kindMap[kind] + + var total uint64 + for k:= range kindMap { + total = total + kindMap[k] + } + + return target == total +} + +func getStatementKind(statement string) statementKind { + if isPartitionedDMLOnly(statement){ + return statementKindPartitionedDML + } + + if isDMLAny(statement){ + return statementKindDML + } + + return statementKindDDL +} + +func isPartitionedDMLOnly(statement string) bool { + return isDMLAny(statement) && !notPartitionedDmlRegex.Match([]byte(statement)) } -func isDML(statement string) bool { - return dmlRegex.Match([]byte(statement)) +func isDMLAny(statement string) bool { + return dmlAnyRegex.Match([]byte(statement)) } diff --git a/pkg/spanner/migration_test.go b/pkg/spanner/migration_test.go index 4bc875e..0c5adba 100644 --- a/pkg/spanner/migration_test.go +++ b/pkg/spanner/migration_test.go @@ -17,17 +17,21 @@ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package spanner_test +package spanner import ( "path/filepath" "testing" +) - "github.com/cloudspannerecosystem/wrench/pkg/spanner" +const ( + TestStmtDDL = "ALTER TABLE Singers ADD COLUMN Foo STRING(MAX)" + TestStmtPartitionedDML = "UPDATE Singers SET FirstName = \"Bar\" WHERE SingerID = \"1\"" + TestStmtDML = "INSERT INTO Singers(FirstName) VALUES(\"Bar\")" ) func TestLoadMigrations(t *testing.T) { - ms, err := spanner.LoadMigrations(filepath.Join("testdata", "migrations")) + ms, err := LoadMigrations(filepath.Join("testdata", "migrations")) if err != nil { t.Fatal(err) } @@ -63,3 +67,103 @@ func TestLoadMigrations(t *testing.T) { } } } + +func Test_getStatementKind(t *testing.T) { + tests := []struct { + name string + statement string + want statementKind + }{ + { + "ALTER statement is DDL", + TestStmtDDL, + statementKindDDL, + }, + { + "UPDATE statement is PartitionedDML", + TestStmtPartitionedDML, + statementKindPartitionedDML, + }, + { + "INSERT statement is DML", + TestStmtDML, + statementKindDML, + }, + { + "lowercase insert statement is DML", + TestStmtDML, + statementKindDML, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getStatementKind(tt.statement); got != tt.want { + t.Errorf("getStatementKind() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_inspectStatementsKind(t *testing.T) { + tests := []struct { + name string + statements []string + want statementKind + wantErr bool + }{ + { + "Only DDL returns DDL", + []string{TestStmtDDL, TestStmtDDL}, + statementKindDDL, + false, + }, + { + "Only PartitionedDML returns PartitionedDML", + []string{TestStmtPartitionedDML, TestStmtPartitionedDML}, + statementKindPartitionedDML, + false, + }, + { + "Only DML returns DML", + []string{TestStmtDML, TestStmtDML}, + statementKindDDL, + false, + }, + { + "DML and DDL returns error", + []string{TestStmtDDL, TestStmtDML}, + "", + true, + }, + { + "DML and PartitionedDML returns error", + []string{TestStmtDML, TestStmtPartitionedDML}, + "", + true, + }, + { + "DDL and PartitionedDML returns error", + []string{TestStmtDDL, TestStmtPartitionedDML}, + "", + true, + }, + { + "no statements defaults to DDL as before", + []string{}, + statementKindDDL, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := inspectStatementsKind(tt.statements) + if (err != nil) != tt.wantErr { + t.Errorf("inspectStatementsKind() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("inspectStatementsKind() got = %v, want %v", got, tt.want) + } + }) + } +}