From 10486147cd3ce90ba7abb9df3c29e265d5206463 Mon Sep 17 00:00:00 2001 From: apstndb <803393+apstndb@users.noreply.github.com> Date: Wed, 18 Dec 2024 20:13:00 +0900 Subject: [PATCH] Refactor SQL() methods (#173) * Refactor SQL() methods * Refine UnaryExpr * Modify a comment --- ast/sql.go | 505 ++++++++++++----------------------------------------- 1 file changed, 115 insertions(+), 390 deletions(-) diff --git a/ast/sql.go b/ast/sql.go index d05dc8dd..19fc07e5 100644 --- a/ast/sql.go +++ b/ast/sql.go @@ -34,7 +34,7 @@ func sqlOpt[T interface { // strOpt outputs: // // when pred == true: s -// else : empty string +// else : "" // // This function corresponds to {{if pred}}s{{end}} in ast.go func strOpt(pred bool, s string) string { @@ -44,6 +44,19 @@ func strOpt(pred bool, s string) string { return "" } +// strIfElse outputs: +// +// when pred == true: ifStr +// else : elseStr +// +// This function corresponds to {{if pred}}ifStr{{else}}elseStr{{end}} in ast.go +func strIfElse(pred bool, ifStr string, elseStr string) string { + if pred { + return ifStr + } + return elseStr +} + // sqlJoin outputs joined string of SQL() of all elems by sep. // This function corresponds to sqlJoin in ast.go func sqlJoin[T Node](elems []T, sep string) string { @@ -57,6 +70,11 @@ func sqlJoin[T Node](elems []T, sep string) string { return b.String() } +// formatBoolUpper formats bool value as uppercase. +func formatBoolUpper(b bool) string { + return strings.ToUpper(strconv.FormatBool(b)) +} + type prec int const ( @@ -148,12 +166,7 @@ func (q *Query) SQL() string { } func (h *Hint) SQL() string { - sql := "@{" + h.Records[0].SQL() - for _, r := range h.Records[1:] { - sql += ", " + r.SQL() - } - sql += "}" - return sql + return "@{" + sqlJoin(h.Records, ", ") + "}" } func (h *HintRecord) SQL() string { @@ -161,11 +174,7 @@ func (h *HintRecord) SQL() string { } func (w *With) SQL() string { - sql := "WITH " + w.CTEs[0].SQL() - for _, c := range w.CTEs[1:] { - sql += ", " + c.SQL() - } - return sql + return "WITH " + sqlJoin(w.CTEs, ", ") } func (c *CTE) SQL() string { @@ -232,11 +241,7 @@ func (w *Where) SQL() string { } func (g *GroupBy) SQL() string { - sql := "GROUP BY " + g.Exprs[0].SQL() - for _, e := range g.Exprs[1:] { - sql += ", " + e.SQL() - } - return sql + return "GROUP BY " + sqlJoin(g.Exprs, ", ") } func (h *Having) SQL() string { @@ -244,22 +249,13 @@ func (h *Having) SQL() string { } func (o *OrderBy) SQL() string { - sql := "ORDER BY " + o.Items[0].SQL() - for _, item := range o.Items[1:] { - sql += ", " + item.SQL() - } - return sql + return "ORDER BY " + sqlJoin(o.Items, ", ") } func (o *OrderByItem) SQL() string { - sql := o.Expr.SQL() - if o.Collate != nil { - sql += " " + o.Collate.SQL() - } - if o.Dir != "" { - sql += " " + string(o.Dir) - } - return sql + return o.Expr.SQL() + + sqlOpt(" ", o.Collate, "") + + strOpt(o.Dir != "", " "+string(o.Dir)) } func (c *Collate) SQL() string { @@ -267,11 +263,8 @@ func (c *Collate) SQL() string { } func (l *Limit) SQL() string { - sql := "LIMIT " + l.Count.SQL() - if l.Offset != nil { - sql += " " + l.Offset.SQL() - } - return sql + return "LIMIT " + l.Count.SQL() + + sqlOpt(" ", l.Offset, "") } func (o *Offset) SQL() string { @@ -307,25 +300,14 @@ func (u *Unnest) SQL() string { } func (w *WithOffset) SQL() string { - sql := "WITH OFFSET" - if w.As != nil { - sql += " " + w.As.SQL() - } - return sql + return "WITH OFFSET" + sqlOpt(" ", w.As, "") } func (t *TableName) SQL() string { - sql := t.Table.SQL() - if t.Hint != nil { - sql += " " + t.Hint.SQL() - } - if t.As != nil { - sql += " " + t.As.SQL() - } - if t.Sample != nil { - sql += " " + t.Sample.SQL() - } - return sql + return t.Table.SQL() + + sqlOpt(" ", t.Hint, "") + + sqlOpt(" ", t.As, "") + + sqlOpt(" ", t.Sample, "") } func (e *PathTableExpr) SQL() string { @@ -337,38 +319,23 @@ func (e *PathTableExpr) SQL() string { } func (s *SubQueryTableExpr) SQL() string { - sql := "(" + s.Query.SQL() + ")" - if s.As != nil { - sql += " " + s.As.SQL() - } - if s.Sample != nil { - sql += " " + s.Sample.SQL() - } - return sql + return "(" + s.Query.SQL() + ")" + + sqlOpt(" ", s.As, "") + + sqlOpt(" ", s.Sample, "") } func (p *ParenTableExpr) SQL() string { - sql := "(" + p.Source.SQL() + ")" - if p.Sample != nil { - sql += " " + p.Sample.SQL() - } - return sql + return "(" + p.Source.SQL() + ")" + + sqlOpt(" ", p.Sample, "") } func (j *Join) SQL() string { - sql := j.Left.SQL() - if j.Op != CommaJoin { - sql += " " - } - sql += string(j.Op) + " " - if j.Hint != nil { - sql += j.Hint.SQL() + " " - } - sql += j.Right.SQL() - if j.Cond != nil { - sql += " " + j.Cond.SQL() - } - return sql + return j.Left.SQL() + + strOpt(j.Op != CommaJoin, " ") + + string(j.Op) + " " + + sqlOpt("", j.Hint, " ") + + j.Right.SQL() + + sqlOpt(" ", j.Cond, "") } func (o *On) SQL() string { @@ -376,12 +343,7 @@ func (o *On) SQL() string { } func (u *Using) SQL() string { - sql := "USING (" + u.Idents[0].SQL() - for _, id := range u.Idents[1:] { - sql += ", " + id.SQL() - } - sql += ")" - return sql + return "USING (" + sqlJoin(u.Idents, ", ") + ")" } func (t *TableSample) SQL() string { @@ -400,29 +362,22 @@ func (t *TableSampleSize) SQL() string { func (b *BinaryExpr) SQL() string { p := exprPrec(b) - sql := paren(p, b.Left) - sql += " " + string(b.Op) + " " - sql += paren(p, b.Right) - return sql + + return paren(p, b.Left) + + " " + string(b.Op) + " " + + paren(p, b.Right) } func (u *UnaryExpr) SQL() string { p := exprPrec(u) - if u.Op == OpNot { - return "NOT " + paren(p, u.Expr) - } - return string(u.Op) + paren(p, u.Expr) + return string(u.Op) + strOpt(u.Op == OpNot, " ") + paren(p, u.Expr) } func (i *InExpr) SQL() string { p := exprPrec(i) - sql := paren(p, i.Left) - if i.Not { - sql += " NOT" - } - sql += " IN " - sql += i.Right.SQL() - return sql + return paren(p, i.Left) + + strOpt(i.Not, " NOT") + + " IN " + i.Right.SQL() } func (u *UnnestInCondition) SQL() string { @@ -434,47 +389,25 @@ func (s *SubQueryInCondition) SQL() string { } func (v *ValuesInCondition) SQL() string { - sql := "(" + v.Exprs[0].SQL() - for _, e := range v.Exprs[1:] { - sql += ", " + e.SQL() - } - sql += ")" - return sql + return "(" + sqlJoin(v.Exprs, ", ") + ")" } func (i *IsNullExpr) SQL() string { p := exprPrec(i) - sql := paren(p, i.Left) - sql += " IS " - if i.Not { - sql += "NOT " - } - sql += "NULL" - return sql + return paren(p, i.Left) + + " IS " + strOpt(i.Not, "NOT ") + "NULL" } func (i *IsBoolExpr) SQL() string { p := exprPrec(i) - sql := paren(p, i.Left) - sql += " IS " - if i.Not { - sql += "NOT " - } - if i.Right { - sql += "TRUE" - } else { - sql += "FALSE" - } - return sql + return paren(p, i.Left) + " IS " + strOpt(i.Not, "NOT ") + formatBoolUpper(i.Right) } func (b *BetweenExpr) SQL() string { p := exprPrec(b) - sql := paren(p, b.Left) - if b.Not { - sql += " NOT" - } - return sql + " BETWEEN " + paren(p, b.RightStart) + " AND " + paren(p, b.RightEnd) + return paren(p, b.Left) + + strOpt(b.Not, " NOT") + + " BETWEEN " + paren(p, b.RightStart) + " AND " + paren(p, b.RightEnd) } func (s *SelectorExpr) SQL() string { @@ -535,11 +468,7 @@ func (s *ExprArg) SQL() string { } func (i *IntervalArg) SQL() string { - sql := "INTERVAL " + i.Expr.SQL() - if i.Unit != nil { - sql += " " + i.Unit.SQL() - } - return sql + return "INTERVAL " + i.Expr.SQL() + sqlOpt(" ", i.Unit, "") } func (s *SequenceArg) SQL() string { @@ -559,12 +488,8 @@ func (*CountStarExpr) SQL() string { } func (e *ExtractExpr) SQL() string { - sql := "EXTRACT(" + e.Part.SQL() + " FROM " + e.Expr.SQL() - if e.AtTimeZone != nil { - sql += " " + e.AtTimeZone.SQL() - } - sql += ")" - return sql + return "EXTRACT(" + e.Part.SQL() + " FROM " + e.Expr.SQL() + + sqlOpt(" ", e.AtTimeZone, "") + ")" } func (a *AtTimeZone) SQL() string { @@ -588,18 +513,10 @@ func (c *CastExpr) SQL() string { } func (c *CaseExpr) SQL() string { - sql := "CASE " - if c.Expr != nil { - sql += c.Expr.SQL() + " " - } - for _, w := range c.Whens { - sql += w.SQL() + " " - } - if c.Else != nil { - sql += c.Else.SQL() + " " - } - sql += "END" - return sql + return "CASE " + sqlOpt("", c.Expr, " ") + + sqlJoin(c.Whens, " ") + " " + + sqlOpt("", c.Else, " ") + + "END" } func (c *CaseWhen) SQL() string { @@ -627,12 +544,9 @@ func (a *ArraySubQuery) SQL() string { } func (e *ExistsSubQuery) SQL() string { - sql := "EXISTS" - if e.Hint != nil { - sql += " " + e.Hint.SQL() + " " - } - sql += "(" + e.Query.SQL() + ")" - return sql + return "EXISTS" + + sqlOpt(" ", e.Hint, " ") + + "(" + e.Query.SQL() + ")" } func (p *Param) SQL() string { @@ -670,11 +584,7 @@ func (*NullLiteral) SQL() string { } func (b *BoolLiteral) SQL() string { - if b.Value { - return "TRUE" - } else { - return "FALSE" - } + return formatBoolUpper(b.Value) } func (i *IntLiteral) SQL() string { @@ -753,35 +663,15 @@ func (a *ArrayType) SQL() string { } func (s *StructType) SQL() string { - sql := "STRUCT<" - for i, f := range s.Fields { - if i != 0 { - sql += ", " - } - sql += f.SQL() - } - sql += ">" - return sql + return "STRUCT<" + sqlJoin(s.Fields, ", ") + ">" } func (f *StructField) SQL() string { - var sql string - if f.Ident != nil { - sql += f.Ident.SQL() + " " - } - sql += f.Type.SQL() - return sql + return sqlOpt("", f.Ident, " ") + f.Type.SQL() } func (n *NamedType) SQL() string { - var sql string - for i, elem := range n.Path { - if i > 0 { - sql += "." - } - sql += elem.SQL() - } - return sql + return sqlJoin(n.Path, ".") } // ================================================================================ @@ -870,12 +760,8 @@ func (c *CreateTable) SQL() string { func (s *Synonym) SQL() string { return "SYNONYM (" + s.Name.SQL() + ")" } func (c *CreateSequence) SQL() string { - sql := "CREATE SEQUENCE " - if c.IfNotExists { - sql += "IF NOT EXISTS " - } - sql += c.Name.SQL() + " " + c.Options.SQL() - return sql + return "CREATE SEQUENCE " + strOpt(c.IfNotExists, "IF NOT EXISTS ") + + c.Name.SQL() + " " + c.Options.SQL() } func (c *AlterSequence) SQL() string { @@ -883,12 +769,8 @@ func (c *AlterSequence) SQL() string { } func (c *CreateView) SQL() string { - sql := "CREATE" - if c.OrReplace { - sql += " OR REPLACE" - } - sql += " VIEW " + c.Name.SQL() + " SQL SECURITY " + string(c.SecurityType) + " AS " + c.Query.SQL() - return sql + return "CREATE" + strOpt(c.OrReplace, " OR REPLACE") + " VIEW " + c.Name.SQL() + + " SQL SECURITY " + string(c.SecurityType) + " AS " + c.Query.SQL() } func (d *DropView) SQL() string { return "DROP VIEW " + d.Name.SQL() } @@ -903,36 +785,14 @@ func (c *ColumnDef) SQL() string { } func (c *TableConstraint) SQL() string { - var sql string - if c.Name != nil { - sql += "CONSTRAINT " + c.Name.SQL() + " " - } - sql += c.Constraint.SQL() - return sql + return sqlOpt("CONSTRAINT ", c.Name, " ") + c.Constraint.SQL() } func (f *ForeignKey) SQL() string { - var sql string - sql += "FOREIGN KEY (" - for i, k := range f.Columns { - if i != 0 { - sql += ", " - } - sql += k.SQL() - } - sql += ") " - sql += "REFERENCES " + f.ReferenceTable.SQL() + " (" - for i, k := range f.ReferenceColumns { - if i != 0 { - sql += ", " - } - sql += k.SQL() - } - sql += ")" - if f.OnDelete != "" { - sql += " " + string(f.OnDelete) - } - return sql + return "FOREIGN KEY (" + sqlJoin(f.Columns, ", ") + ") " + + "REFERENCES " + f.ReferenceTable.SQL() + " (" + + sqlJoin(f.ReferenceColumns, ", ") + ")" + + strOpt(f.OnDelete != "", " "+string(f.OnDelete)) } func (c *Check) SQL() string { @@ -948,19 +808,12 @@ func (g *GeneratedColumnExpr) SQL() string { } func (i *IndexKey) SQL() string { - sql := i.Name.SQL() - if i.Dir != "" { - sql += " " + string(i.Dir) - } - return sql + return i.Name.SQL() + strOpt(i.Dir != "", " "+string(i.Dir)) } func (c *Cluster) SQL() string { - sql := ", INTERLEAVE IN PARENT " + c.TableName.SQL() - if c.OnDelete != "" { - sql += " " + string(c.OnDelete) - } - return sql + return ", INTERLEAVE IN PARENT " + c.TableName.SQL() + + strOpt(c.OnDelete != "", " "+string(c.OnDelete)) } func (c *CreateRowDeletionPolicy) SQL() string { @@ -982,11 +835,7 @@ func (s *DropSynonym) SQL() string { return "DROP SYNONYM " + s.Name.SQL() } func (t *RenameTo) SQL() string { return "RENAME TO " + t.Name.SQL() + sqlOpt(", ", t.AddSynonym, "") } func (a *AddColumn) SQL() string { - sql := "ADD COLUMN " - if a.IfNotExists { - sql += "IF NOT EXISTS " - } - return sql + a.Column.SQL() + return "ADD COLUMN " + strOpt(a.IfNotExists, "IF NOT EXISTS ") + a.Column.SQL() } func (a *AddTableConstraint) SQL() string { @@ -1034,11 +883,7 @@ func (a *AlterColumnSetDefault) SQL() string { return "SET " + a.DefaultExpr.SQL func (a *AlterColumnDropDefault) SQL() string { return "DROP DEFAULT" } func (d *DropTable) SQL() string { - sql := "DROP TABLE " - if d.IfExists { - sql += "IF EXISTS " - } - return sql + d.Name.SQL() + return "DROP TABLE " + strOpt(d.IfExists, "IF EXISTS ") + d.Name.SQL() } func (r *RenameTable) SQL() string { return "RENAME TABLE " + sqlJoin(r.Tos, ", ") } @@ -1059,17 +904,11 @@ func (c *CreateIndex) SQL() string { } func (c *CreateVectorIndex) SQL() string { - sql := "CREATE VECTOR INDEX " - if c.IfNotExists { - sql += "IF NOT EXISTS " - } - sql += c.Name.SQL() - sql += " ON " + c.TableName.SQL() + " (" + c.ColumnName.SQL() + ") " - if c.Where != nil { - sql += c.Where.SQL() + " " - } - sql += c.Options.SQL() - return sql + return "CREATE VECTOR INDEX " + + strOpt(c.IfNotExists, "IF NOT EXISTS ") + + c.Name.SQL() + " ON " + c.TableName.SQL() + " (" + c.ColumnName.SQL() + ") " + + sqlOpt("", c.Where, " ") + + c.Options.SQL() } func (c *CreateChangeStream) SQL() string { @@ -1083,6 +922,7 @@ func (c *ChangeStreamForAll) SQL() string { } func (c *ChangeStreamForTables) SQL() string { + // TODO: Refactor after ChangeStreamForTable implements Node. sql := "FOR " for i, table := range c.Tables { if i > 0 { @@ -1110,18 +950,7 @@ func (a ChangeStreamSetOptions) SQL() string { } func (c *ChangeStreamForTable) SQL() string { - sql := c.TableName.SQL() - if len(c.Columns) > 0 { - sql += "(" - for i, id := range c.Columns { - if i > 0 { - sql += ", " - } - sql += id.SQL() - } - sql += ")" - } - return sql + return c.TableName.SQL() + strOpt(len(c.Columns) > 0, "("+sqlJoin(c.Columns, ", ")+")") } func (d *DropChangeStream) SQL() string { @@ -1129,15 +958,7 @@ func (d *DropChangeStream) SQL() string { } func (s *Storing) SQL() string { - sql := "STORING (" - for i, c := range s.Columns { - if i != 0 { - sql += ", " - } - sql += c.SQL() - } - sql += ")" - return sql + return "STORING (" + sqlJoin(s.Columns, ", ") + ")" } func (i *InterleaveIn) SQL() string { @@ -1157,27 +978,15 @@ func (a *DropStoredColumn) SQL() string { } func (d *DropIndex) SQL() string { - sql := "DROP INDEX " - if d.IfExists { - sql += "IF EXISTS " - } - return sql + d.Name.SQL() + return "DROP INDEX " + strOpt(d.IfExists, "IF EXISTS ") + d.Name.SQL() } func (d *DropVectorIndex) SQL() string { - sql := "DROP VECTOR INDEX " - if d.IfExists { - sql += "IF EXISTS " - } - return sql + d.Name.SQL() + return "DROP VECTOR INDEX " + strOpt(d.IfExists, "IF EXISTS ") + d.Name.SQL() } func (d *DropSequence) SQL() string { - sql := "DROP SEQUENCE " - if d.IfExists { - sql += "IF EXISTS " - } - return sql + d.Name.SQL() + return "DROP SEQUENCE " + strOpt(d.IfExists, "IF EXISTS ") + d.Name.SQL() } func (c *CreateRole) SQL() string { @@ -1189,81 +998,30 @@ func (d *DropRole) SQL() string { } func (g *Grant) SQL() string { - sql := "GRANT " - sql += g.Privilege.SQL() - sql += " TO ROLE " + g.Roles[0].SQL() - for _, id := range g.Roles[1:] { - sql += ", " + id.SQL() - } - return sql + return "GRANT " + g.Privilege.SQL() + " TO ROLE " + sqlJoin(g.Roles, ", ") } func (r *Revoke) SQL() string { - sql := "REVOKE " - sql += r.Privilege.SQL() - sql += " FROM ROLE " + r.Roles[0].SQL() - for _, id := range r.Roles[1:] { - sql += ", " + id.SQL() - } - return sql + return "REVOKE " + r.Privilege.SQL() + " FROM ROLE " + sqlJoin(r.Roles, ", ") } func (p *PrivilegeOnTable) SQL() string { - sql := p.Privileges[0].SQL() - for _, p := range p.Privileges[1:] { - sql += ", " + p.SQL() - } - sql += " ON TABLE " - sql += p.Names[0].SQL() - for _, id := range p.Names[1:] { - sql += ", " + id.SQL() - } - return sql + return sqlJoin(p.Privileges, ", ") + " ON TABLE " + sqlJoin(p.Names, ", ") } func (s *SelectPrivilege) SQL() string { - sql := "SELECT" - if len(s.Columns) > 0 { - sql += "(" - for i, c := range s.Columns { - if i > 0 { - sql += ", " - } - sql += c.SQL() - } - sql += ")" - } - return sql + return "SELECT" + + strOpt(len(s.Columns) > 0, "("+sqlJoin(s.Columns, ", ")+")") } func (i *InsertPrivilege) SQL() string { - sql := "INSERT" - if len(i.Columns) > 0 { - sql += "(" - for j, c := range i.Columns { - if j > 0 { - sql += ", " - } - sql += c.SQL() - } - sql += ")" - } - return sql + return "INSERT" + + strOpt(len(i.Columns) > 0, "("+sqlJoin(i.Columns, ", ")+")") } func (u *UpdatePrivilege) SQL() string { - sql := "UPDATE" - if len(u.Columns) > 0 { - sql += "(" - for i, c := range u.Columns { - if i > 0 { - sql += ", " - } - sql += c.SQL() - } - sql += ")" - } - return sql + return "UPDATE" + + strOpt(len(u.Columns) > 0, "("+sqlJoin(u.Columns, ", ")+")") } func (d *DeletePrivilege) SQL() string { @@ -1275,27 +1033,15 @@ func (p *SelectPrivilegeOnChangeStream) SQL() string { } func (s *SelectPrivilegeOnView) SQL() string { - sql := "SELECT ON VIEW " + s.Names[0].SQL() - for _, v := range s.Names[1:] { - sql += ", " + v.SQL() - } - return sql + return "SELECT ON VIEW " + sqlJoin(s.Names, ", ") } func (e *ExecutePrivilegeOnTableFunction) SQL() string { - sql := "EXECUTE ON TABLE FUNCTION " + e.Names[0].SQL() - for _, f := range e.Names[1:] { - sql += ", " + f.SQL() - } - return sql + return "EXECUTE ON TABLE FUNCTION " + sqlJoin(e.Names, ", ") } func (r *RolePrivilege) SQL() string { - sql := "ROLE " + r.Names[0].SQL() - for _, id := range r.Names[1:] { - sql += ", " + id.SQL() - } - return sql + return "ROLE " + sqlJoin(r.Names, ", ") } func (s *AlterStatistics) SQL() string { @@ -1343,14 +1089,8 @@ func (s *ScalarSchemaType) SQL() string { } func (s *SizedSchemaType) SQL() string { - sql := string(s.Name) + "(" - if s.Max { - sql += "MAX" - } else { - sql += s.Size.SQL() - } - sql += ")" - return sql + return string(s.Name) + + "(" + strIfElse(s.Max, "MAX", sqlOpt("", s.Size, "")) + ")" } func (a *ArraySchemaType) SQL() string { @@ -1407,26 +1147,11 @@ func (i *Insert) SQL() string { } func (v *ValuesInput) SQL() string { - sql := "VALUES " - for i, r := range v.Rows { - if i != 0 { - sql += ", " - } - sql += r.SQL() - } - return sql + return "VALUES " + sqlJoin(v.Rows, ", ") } func (v *ValuesRow) SQL() string { - sql := "(" - for i, v := range v.Exprs { - if i != 0 { - sql += ", " - } - sql += v.SQL() - } - sql += ")" - return sql + return "(" + sqlJoin(v.Exprs, ", ") + ")" } func (d *DefaultExpr) SQL() string {