diff --git a/ast/ast.go b/ast/ast.go index caf16acc..5708a22a 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -32,9 +32,20 @@ // (PosVar, NodeVar, NodeSliceVar, and BoolVar are derived by its struct definition.) package ast -// This file must contain only AST definitions. -// We use the following go:generate directive for generating pos.go. Thus, all AST definitions must have pos and end lines. -//go:generate go run ../tools/gen-ast-pos/main.go -infile ast.go -outfile pos.go +// NOTE: ast.go and ast_*.go are used for automatic generation, so these files are conventional. + +// NOTE: This file defines AST nodes and they are used for automatic generation, +// so this file is conventional. +// +// Conventions: +// +// - Each node interface (except for Node) should have isXXX method (XXX must be a name of the interface itself). +// - `isXXX` methods should be defined after the interface definition +// and the receiver should be the non-pointer node struct type. +// - Each node struct should have pos and end comments. +// - Each node struct should have template lines in its doc comment. + +//go:generate go run ../tools/gen-ast-pos/main.go -astfile ast.go -constfile ast_const.go -outfile pos.go import ( "github.com/cloudspannerecosystem/memefish/token" @@ -562,7 +573,7 @@ type BadNode struct { // BadStatement is a BadNode for Statement. // -// {{.BadNode | sql}} +// {{.BadNode | sql}} type BadStatement struct { // pos = BadNode.pos // end = BadNode.end @@ -572,7 +583,7 @@ type BadStatement struct { // BadQueryExpr is a BadNode for QueryExpr. // -// {{.BadNode | sql}} +// {{.BadNode | sql}} type BadQueryExpr struct { // pos = BadNode.pos // end = BadNode.end @@ -582,7 +593,7 @@ type BadQueryExpr struct { // BadExpr is a BadNode for Expr. // -// {{.BadNode | sql}} +// {{.BadNode | sql}} type BadExpr struct { // pos = BadNode.pos // end = BadNode.end @@ -592,7 +603,7 @@ type BadExpr struct { // BadType is a BadNode for Type. // -// {{.BadNode | sql}} +// {{.BadNode | sql}} type BadType struct { // pos = BadNode.pos // end = BadNode.end @@ -602,7 +613,7 @@ type BadType struct { // BadDDL is a BadNode for DDL. // -// {{.BadNode | sql}} +// {{.BadNode | sql}} type BadDDL struct { // pos = BadNode.pos // end = BadNode.end @@ -612,7 +623,7 @@ type BadDDL struct { // BadDML is a BadNode for DML. // -// {{.BadNode | sql}} +// {{.BadNode | sql}} type BadDML struct { // pos = BadNode.pos // end = BadNode.end @@ -768,6 +779,8 @@ type AsTypeName struct { } // FromQuery is FROM query expression node. +// +// FROM {{.From | sql}} type FromQuery struct { // pos = From.pos // end = From.end @@ -1997,8 +2010,8 @@ type BracedConstructorFieldValue interface { isBracedConstructorFieldValue() } -func (*BracedConstructor) isBracedConstructorFieldValue() {} -func (*BracedConstructorFieldValueExpr) isBracedConstructorFieldValue() {} +func (BracedConstructor) isBracedConstructorFieldValue() {} +func (BracedConstructorFieldValueExpr) isBracedConstructorFieldValue() {} // NewConstructor represents NEW operator which creates a protocol buffer using a parenthesized list of arguments. // @@ -3023,7 +3036,6 @@ type CreateIndex struct { // ON {{.TableName | sql}}({{.ColumnName | sql}}) // {{if .Where}}WHERE {{.Where | sql}}{{end}} // {{.Options | sql}} - type CreateVectorIndex struct { // pos = Create // end = Options.end diff --git a/ast/const.go b/ast/ast_const.go similarity index 93% rename from ast/const.go rename to ast/ast_const.go index d72748d1..eb629325 100644 --- a/ast/const.go +++ b/ast/ast_const.go @@ -1,5 +1,13 @@ package ast +// NOTE: This file defines constants used in AST nodes and they are used for automatic generation, +// so this file is conventional. +// +// Convention: +// +// - Each const types should be defined as a string type. +// - Each value is defined as a string literal. + // AllOrDistinct represents ALL or DISTINCT in SELECT or set operations, etc. // If it is optional, it may be an empty string, so handle it according to the context. type AllOrDistinct string diff --git a/ast/ast_test.go b/ast/ast_test.go deleted file mode 100644 index cc8ee6d3..00000000 --- a/ast/ast_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package ast - -import ( - "testing" -) - -func TestStatement(t *testing.T) { - Statement(&QueryStatement{}).isStatement() - Statement(&CreateDatabase{}).isStatement() - Statement(&AlterDatabase{}).isStatement() - Statement(&CreateTable{}).isStatement() - Statement(&AlterTable{}).isStatement() - Statement(&DropTable{}).isStatement() - Statement(&CreateIndex{}).isStatement() - Statement(&AlterIndex{}).isStatement() - Statement(&DropIndex{}).isStatement() - Statement(&CreateView{}).isStatement() - Statement(&DropView{}).isStatement() - Statement(&CreateChangeStream{}).isStatement() - Statement(&AlterChangeStream{}).isStatement() - Statement(&DropChangeStream{}).isStatement() - Statement(&CreateRole{}).isStatement() - Statement(&DropRole{}).isStatement() - Statement(&Grant{}).isStatement() - Statement(&Revoke{}).isStatement() - Statement(&CreateSequence{}).isStatement() - Statement(&AlterSequence{}).isStatement() - Statement(&DropSequence{}).isStatement() - Statement(&CreateVectorIndex{}).isStatement() - Statement(&DropVectorIndex{}).isStatement() - Statement(&Insert{}).isStatement() - Statement(&Delete{}).isStatement() - Statement(&Update{}).isStatement() -} - -func TestQueryExpr(t *testing.T) { - QueryExpr(&Select{}).isQueryExpr() - QueryExpr(&SubQuery{}).isQueryExpr() - QueryExpr(&CompoundQuery{}).isQueryExpr() -} - -func TestSelectItem(t *testing.T) { - SelectItem(&Star{}).isSelectItem() - SelectItem(&DotStar{}).isSelectItem() - SelectItem(&Alias{}).isSelectItem() - SelectItem(&ExprSelectItem{}).isSelectItem() -} - -func TestSelectAs(t *testing.T) { - SelectAs(&AsStruct{}).isSelectAs() - SelectAs(&AsValue{}).isSelectAs() - SelectAs(&AsTypeName{}).isSelectAs() -} - -func TestTableExpr(t *testing.T) { - TableExpr(&Unnest{}).isTableExpr() - TableExpr(&TableName{}).isTableExpr() - TableExpr(&SubQueryTableExpr{}).isTableExpr() - TableExpr(&ParenTableExpr{}).isTableExpr() - TableExpr(&Join{}).isTableExpr() -} - -func TestJoinCondition(t *testing.T) { - JoinCondition(&On{}).isJoinCondition() - JoinCondition(&Using{}).isJoinCondition() -} - -func TestExpr(t *testing.T) { - Expr(&BinaryExpr{}).isExpr() - Expr(&UnaryExpr{}).isExpr() - Expr(&InExpr{}).isExpr() - Expr(&IsNullExpr{}).isExpr() - Expr(&IsBoolExpr{}).isExpr() - Expr(&BetweenExpr{}).isExpr() - Expr(&SelectorExpr{}).isExpr() - Expr(&IndexExpr{}).isExpr() - Expr(&CallExpr{}).isExpr() - Expr(&CountStarExpr{}).isExpr() - Expr(&CastExpr{}).isExpr() - Expr(&ExtractExpr{}).isExpr() - Expr(&CaseExpr{}).isExpr() - Expr(&ParenExpr{}).isExpr() - Expr(&ScalarSubQuery{}).isExpr() - Expr(&ArraySubQuery{}).isExpr() - Expr(&ExistsSubQuery{}).isExpr() - Expr(&Param{}).isExpr() - Expr(&Ident{}).isExpr() - Expr(&Path{}).isExpr() - Expr(&ArrayLiteral{}).isExpr() - Expr(&TypedStructLiteral{}).isExpr() - Expr(&NullLiteral{}).isExpr() - Expr(&BoolLiteral{}).isExpr() - Expr(&IntLiteral{}).isExpr() - Expr(&FloatLiteral{}).isExpr() - Expr(&StringLiteral{}).isExpr() - Expr(&BytesLiteral{}).isExpr() - Expr(&DateLiteral{}).isExpr() - Expr(&TimestampLiteral{}).isExpr() - Expr(&NumericLiteral{}).isExpr() -} - -func TestArg(t *testing.T) { - Arg(&IntervalArg{}).isArg() - Arg(&ExprArg{}).isArg() - Arg(&SequenceArg{}).isArg() -} - -func TestInCondition(t *testing.T) { - InCondition(&UnnestInCondition{}).isInCondition() - InCondition(&SubQueryInCondition{}).isInCondition() - InCondition(&ValuesInCondition{}).isInCondition() -} - -func TestType(t *testing.T) { - Type(&SimpleType{}).isType() - Type(&ArrayType{}).isType() - Type(&StructType{}).isType() -} - -func TestIntValue(t *testing.T) { - IntValue(&Param{}).isIntValue() - IntValue(&IntLiteral{}).isIntValue() - IntValue(&CastIntValue{}).isIntValue() -} - -func TestNumValue(t *testing.T) { - NumValue(&Param{}).isNumValue() - NumValue(&IntLiteral{}).isNumValue() - NumValue(&FloatLiteral{}).isNumValue() - NumValue(&CastNumValue{}).isNumValue() -} - -func TestStringValue(t *testing.T) { - StringValue(&Param{}).isStringValue() - StringValue(&StringLiteral{}).isStringValue() -} - -func TestDDL(t *testing.T) { - DDL(&CreateDatabase{}).isDDL() - DDL(&AlterDatabase{}).isDDL() - DDL(&CreateTable{}).isDDL() - DDL(&AlterTable{}).isDDL() - DDL(&DropTable{}).isDDL() - DDL(&CreateIndex{}).isDDL() - DDL(&AlterIndex{}).isDDL() - DDL(&DropIndex{}).isDDL() - DDL(&CreateSearchIndex{}).isDDL() - DDL(&DropSearchIndex{}).isDDL() - DDL(&AlterSearchIndex{}).isDDL() - DDL(&CreateView{}).isDDL() - DDL(&DropView{}).isDDL() - DDL(&CreateChangeStream{}).isDDL() - DDL(&AlterChangeStream{}).isDDL() - DDL(&DropChangeStream{}).isDDL() - DDL(&CreateRole{}).isDDL() - DDL(&DropRole{}).isDDL() - DDL(&Grant{}).isDDL() - DDL(&Revoke{}).isDDL() - DDL(&CreateSequence{}).isDDL() - DDL(&AlterSequence{}).isDDL() - DDL(&DropSequence{}).isDDL() - DDL(&CreateVectorIndex{}).isDDL() - DDL(&DropVectorIndex{}).isDDL() -} - -func TestConstraint(t *testing.T) { - Constraint(&ForeignKey{}).isConstraint() - Constraint(&Check{}).isConstraint() -} - -func TestTableAlteration(t *testing.T) { - TableAlteration(&AddColumn{}).isTableAlteration() - TableAlteration(&AddTableConstraint{}).isTableAlteration() - TableAlteration(&DropColumn{}).isTableAlteration() - TableAlteration(&DropConstraint{}).isTableAlteration() - TableAlteration(&SetOnDelete{}).isTableAlteration() - TableAlteration(&AlterColumn{}).isTableAlteration() -} - -func TestPrivilege(t *testing.T) { - Privilege(&PrivilegeOnTable{}).isPrivilege() - Privilege(&SelectPrivilegeOnView{}).isPrivilege() - Privilege(&ExecutePrivilegeOnTableFunction{}).isPrivilege() - Privilege(&RolePrivilege{}).isPrivilege() -} - -func TestTablePrivilege(t *testing.T) { - TablePrivilege(&SelectPrivilege{}).isTablePrivilege() - TablePrivilege(&InsertPrivilege{}).isTablePrivilege() - TablePrivilege(&UpdatePrivilege{}).isTablePrivilege() - TablePrivilege(&DeletePrivilege{}).isTablePrivilege() -} - -func TestSchemaType(t *testing.T) { - SchemaType(&ScalarSchemaType{}).isSchemaType() - SchemaType(&SizedSchemaType{}).isSchemaType() - SchemaType(&ArraySchemaType{}).isSchemaType() -} - -func TestIndexAlteration(t *testing.T) { - IndexAlteration(&AddStoredColumn{}).isIndexAlteration() - IndexAlteration(&DropStoredColumn{}).isIndexAlteration() -} - -func TestDML(t *testing.T) { - DML(&Insert{}).isDML() - DML(&Delete{}).isDML() - DML(&Update{}).isDML() -} - -func TestInsertInput(t *testing.T) { - InsertInput(&ValuesInput{}).isInsertInput() - InsertInput(&SubQueryInput{}).isInsertInput() -} diff --git a/tools/astcatalog/main.go b/tools/astcatalog/main.go new file mode 100644 index 00000000..bdaac2af --- /dev/null +++ b/tools/astcatalog/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "flag" + "fmt" + "log" + + "github.com/MakeNowJust/heredoc/v2" + "github.com/cloudspannerecosystem/memefish/tools/util/astcatalog" + "github.com/k0kubun/pp/v3" +) + +var usage = heredoc.Doc(` + Usage of tools/astcatalog + + An utility to show the AST catalog. + + Example: + + $ go run ./tools/astcatalog/main.go -astfile ast/ast.go -constfile ast/ast_const.go + Print the AST catalog of ast/ast.go and ast/ast_const.go. + + Flags: +`) + +var ( + astfile = flag.String("astfile", "ast/ast.go", "path to ast/ast.go") + constfile = flag.String("constfile", "ast/ast_const.go", "path to ast/ast_const.go") +) + +func main() { + flag.Usage = func() { + fmt.Print(usage) + flag.PrintDefaults() + } + + flag.Parse() + + catalog, err := astcatalog.Load(*astfile, *constfile) + if err != nil { + log.Fatalf("failed to load: %v", err) + } + pprinter := pp.New() + pprinter.SetOmitEmpty(true) + _, _ = pprinter.Println(catalog) +} diff --git a/tools/gen-ast-pos/main.go b/tools/gen-ast-pos/main.go index 2bcee65d..b146e41d 100644 --- a/tools/gen-ast-pos/main.go +++ b/tools/gen-ast-pos/main.go @@ -6,11 +6,11 @@ import ( "fmt" "log" "os" - "regexp" - "strings" + "sort" "unicode" "github.com/MakeNowJust/heredoc/v2" + "github.com/cloudspannerecosystem/memefish/tools/util/astcatalog" "github.com/cloudspannerecosystem/memefish/tools/util/poslang" ) @@ -22,7 +22,7 @@ var ( Example: - $ go run ./tools/gen-ast-pos/main.go -infile ast/ast.go + $ go run ./tools/gen-ast-pos/main.go -astfile ast/ast.go -constfile ast/ast_const.go Print the generated ast/pos.go to stdout. Flags: @@ -39,21 +39,11 @@ var ( ) var ( - infile = flag.String("infile", "", "input filename") - outfile = flag.String("outfile", "", "output filename (if it is not specified, the result is printed to stdout.)") + astfile = flag.String("astfile", "ast/ast.go", "path to ast/ast.go") + constfile = flag.String("constfile", "ast/ast_const.go", "path to ast/ast_const.go") + outfile = flag.String("outfile", "", "output filename (if it is not specified, the result is printed to stdout.)") ) -var ( - reNameLine = regexp.MustCompile(`^\s*type\s+(\w+)\s+struct\s*\{`) - rePosLine = regexp.MustCompile(`^\s*//\s*pos\s*=\s*(.*)`) - reEndLine = regexp.MustCompile(`^\s*//\s*end\s*=\s*(.*)`) -) - -type astNode struct { - name string - posExpr, endExpr poslang.PosExpr -} - func main() { flag.Usage = func() { fmt.Print(usage) @@ -62,59 +52,42 @@ func main() { flag.Parse() - source, err := os.ReadFile(*infile) + catalog, err := astcatalog.Load(*astfile, *constfile) if err != nil { log.Fatal(err) } - var nodes []*astNode - - for _, line := range strings.Split(string(source), "\n") { - if m := reNameLine.FindStringSubmatch(line); m != nil { - name := m[1] - nodes = append(nodes, &astNode{name: name}) - continue - } - - if m := rePosLine.FindStringSubmatch(line); m != nil { - e := m[1] - expr, err := poslang.Parse(e) - if err != nil { - log.Printf("Error at node %s, pos = %s", nodes[len(nodes)-1].name, e) - log.Fatal(err) - } - nodes[len(nodes)-1].posExpr = expr - continue - } - - if m := reEndLine.FindStringSubmatch(line); m != nil { - e := m[1] - expr, err := poslang.Parse(e) - if err != nil { - log.Printf("Error at node %s, end = %s", nodes[len(nodes)-1].name, e) - log.Fatal(err) - } - nodes[len(nodes)-1].endExpr = expr - continue - } + structs := make([]*astcatalog.NodeStructDef, 0, len(catalog.Structs)) + for _, structDef := range catalog.Structs { + structs = append(structs, structDef) } + sort.Slice(structs, func(i, j int) bool { + return structs[i].SourcePos < structs[j].SourcePos + }) var buffer bytes.Buffer buffer.WriteString(prologue) - for _, node := range nodes { - x := string(unicode.ToLower(rune(node.name[0]))) - if node.posExpr == nil || node.endExpr == nil { - log.Fatalf("pos/end is not defined: node %s", node.name) + for _, structDef := range structs { + x := string(unicode.ToLower(rune(structDef.Name[0]))) + + posExpr, err := poslang.Parse(structDef.Pos) + if err != nil { + log.Fatalf("error on parsing pos: %v", err) + } + + endExpr, err := poslang.Parse(structDef.End) + if err != nil { + log.Fatalf("error on parsing pos: %v", err) } fmt.Fprintln(&buffer) - fmt.Fprintf(&buffer, "func (%s *%s) Pos() token.Pos {\n", x, node.name) - fmt.Fprintf(&buffer, "\treturn %s\n", node.posExpr.PosExprToGo(x)) + fmt.Fprintf(&buffer, "func (%s *%s) Pos() token.Pos {\n", x, structDef.Name) + fmt.Fprintf(&buffer, "\treturn %s\n", posExpr.PosExprToGo(x)) fmt.Fprintf(&buffer, "}\n") fmt.Fprintln(&buffer) - fmt.Fprintf(&buffer, "func (%s *%s) End() token.Pos {\n", x, node.name) - fmt.Fprintf(&buffer, "\treturn %s\n", node.endExpr.PosExprToGo(x)) + fmt.Fprintf(&buffer, "func (%s *%s) End() token.Pos {\n", x, structDef.Name) + fmt.Fprintf(&buffer, "\treturn %s\n", endExpr.PosExprToGo(x)) fmt.Fprintf(&buffer, "}\n") } diff --git a/tools/util/astcatalog/catalog.go b/tools/util/astcatalog/catalog.go new file mode 100644 index 00000000..f6d7faa5 --- /dev/null +++ b/tools/util/astcatalog/catalog.go @@ -0,0 +1,94 @@ +package astcatalog + +import ( + "go/token" +) + +// Catalog is a catalog of AST types. +type Catalog struct { + Structs map[NodeStructType]*NodeStructDef + Interfaces map[NodeInterfaceType]*NodeInterfaceDef + Consts map[ConstType]*ConstDef +} + +// NodeStructDef is a definition of node structs in ast/ast.go. +type NodeStructDef struct { + SourcePos token.Pos + Name string + Doc string + Tmpl string + Pos, End string + Fields []*FieldDef + Implements []NodeInterfaceType +} + +// FieldDef is a field definition of node structs in ast/ast.go. +type FieldDef struct { + Name string + Type Type + Comment string +} + +// NodeInterfaceDef is a definition of node interfaces in ast/ast.go. +type NodeInterfaceDef struct { + SourcePos token.Pos + Name string + Implemented []NodeStructType +} + +// ConstDef is a definition of const types in ast/ast_const.go +type ConstDef struct { + SourcePos token.Pos + Name string + Values []*ConstValueDef +} + +// ConstValueDef is a value definition of const types in ast/ast_const.go. +type ConstValueDef struct { + Name string + Value string +} + +// Type represents types used in Catalog. +type Type interface { + isType() +} + +func (SliceType) isType() {} +func (PointerType) isType() {} +func (NodeStructType) isType() {} +func (NodeInterfaceType) isType() {} +func (PrimitiveType) isType() {} +func (ConstType) isType() {} + +// SliceType is a slice type. +type SliceType struct { + Type Type +} + +// PointerType is a pointer type. +type PointerType struct { + Type Type +} + +// NodeStructType is a type name of node structs defined in ast/ast.go. +type NodeStructType string + +// NodeInterfaceType is a type name of node interfaces defined in ast/ast.go. +type NodeInterfaceType string + +// ConstType is a type name of const types defined in ast/ast_const.go. +type ConstType string + +// PrimitiveType is a type name which is neither a node pointer, a node interface, nor a const types. +type PrimitiveType string + +// PrimitiveType values. +const ( + BoolType PrimitiveType = "bool" + ByteType PrimitiveType = "byte" + IntType PrimitiveType = "int" + StringType PrimitiveType = "string" + TokenPosType PrimitiveType = "token.Pos" + TokenTokenType PrimitiveType = "token.Token" +) diff --git a/tools/util/astcatalog/load.go b/tools/util/astcatalog/load.go new file mode 100644 index 00000000..16f96d8e --- /dev/null +++ b/tools/util/astcatalog/load.go @@ -0,0 +1,299 @@ +package astcatalog + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "regexp" + "strings" +) + +// Load loads a catalog from the given AST files. +func Load(astFilename, astConstFilename string) (*Catalog, error) { + fset := token.NewFileSet() + + consts, err := loadConsts(fset, astConstFilename) + if err != nil { + return nil, err + } + + structs, interfaces, err := loadStructs(fset, astFilename, consts) + if err != nil { + return nil, err + } + + catalog := &Catalog{ + Structs: structs, + Interfaces: interfaces, + Consts: consts, + } + return catalog, nil +} + +func loadConsts(fset *token.FileSet, filename string) (map[ConstType]*ConstDef, error) { + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("failed to parse '%s': %w", filename, err) + } + + consts := make(map[ConstType]*ConstDef) + + for _, decl := range f.Decls { + switch d := decl.(type) { + case *ast.GenDecl: + for _, spec := range d.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + name := ConstType(s.Name.Name) + if _, ok := consts[name]; ok { + return nil, fmt.Errorf("duplicated const: %s", name) + } + + consts[name] = &ConstDef{ + SourcePos: d.Pos(), + Name: s.Name.Name, + } + case *ast.ValueSpec: + if s.Type == nil { + return nil, fmt.Errorf("unexpected value spec: %#v", s) + } + ty, err := loadType(s.Type, nil, consts) + if err != nil { + return nil, err + } + name, ok := ty.(ConstType) + if !ok { + return nil, fmt.Errorf("unexpected type: %#v", s.Type) + } + + constDef, ok := consts[name] + if !ok { + return nil, fmt.Errorf("unknown const: %s", name) + } + + if len(s.Values) != 1 { + return nil, fmt.Errorf("unexpected values: %#v", s.Values) + } + lit, ok := s.Values[0].(*ast.BasicLit) + if !(ok && lit.Kind == token.STRING) { + return nil, fmt.Errorf("unexpected value: %#v", s.Values[0]) + } + v := strings.Trim(lit.Value, "\"") + + for _, name := range s.Names { + constDef.Values = append(constDef.Values, &ConstValueDef{ + Name: name.Name, + Value: v, + }) + } + } + } + } + } + + return consts, nil +} + +// Regular expressions to extract pos/end and template comments. +var ( + rePosLine = regexp.MustCompile(`(?m)^\s*pos\s*=\s*(.*)`) + reEndLine = regexp.MustCompile(`(?m)^\s*end\s*=\s*(.*)`) + reTmplLines = regexp.MustCompile(`(?m)((?:^\t.*\n)+)+`) +) + +func loadStructs(fset *token.FileSet, filename string, consts map[ConstType]*ConstDef) (map[NodeStructType]*NodeStructDef, map[NodeInterfaceType]*NodeInterfaceDef, error) { + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse '%s': %w", filename, err) + } + + structs := make(map[NodeStructType]*NodeStructDef) + interfaces := make(map[NodeInterfaceType]*NodeInterfaceDef) + + commentMap := ast.NewCommentMap(fset, f, f.Comments) + + for _, decl := range f.Decls { + switch d := decl.(type) { + case *ast.GenDecl: + for _, spec := range d.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + switch t := s.Type.(type) { + case *ast.StructType: + name := NodeStructType(s.Name.Name) + if _, ok := structs[name]; !ok { + structs[name] = &NodeStructDef{ + Name: s.Name.Name, + } + } + + structDef := structs[name] + structDef.SourcePos = s.Pos() + structDef.Doc = d.Doc.Text() + + if m := reTmplLines.FindAllStringSubmatch(structDef.Doc, -1); m != nil { + structDef.Tmpl = m[len(m)-1][0] + } else { + return nil, nil, fmt.Errorf("no template found: %s", name) + } + + comments := commentMap.Filter(t).Comments() + if len(comments) == 0 { + return nil, nil, fmt.Errorf("no pos/end comment found: %s", name) + } + + // We assume the first comment group in the struct should a pos/end comment. + posComment := comments[0].Text() + if m := rePosLine.FindStringSubmatch(posComment); m != nil { + structDef.Pos = m[1] + } else { + return nil, nil, fmt.Errorf("no pos comment found: %s", name) + } + if m := reEndLine.FindStringSubmatch(posComment); m != nil { + structDef.End = m[1] + } else { + return nil, nil, fmt.Errorf("no end coment found: %s", name) + } + + for _, f := range t.Fields.List { + ty, err := loadType(f.Type, interfaces, consts) + if err != nil { + return nil, nil, err + } + + comment := "" + comments := commentMap.Filter(f).Comments() + for _, c := range comments { + if f.End() < c.Pos() { + comment = c.Text() + break + } + } + for _, name := range f.Names { + structDef.Fields = append(structDef.Fields, &FieldDef{ + Name: name.Name, + Type: ty, + Comment: comment, + }) + } + } + case *ast.InterfaceType: + name := NodeInterfaceType(s.Name.Name) + if _, ok := interfaces[name]; ok { + return nil, nil, fmt.Errorf("duplicated interface: %s", name) + } + + // Node interface is special, so we skip it. + if name == "Node" { + continue + } + + interfaces[name] = &NodeInterfaceDef{ + SourcePos: s.Pos(), + Name: string(name), + } + default: + return nil, nil, fmt.Errorf("unexpected spec: %#v", t) + } + } + } + + case *ast.FuncDecl: + if d.Recv == nil || len(d.Recv.List) != 1 { + return nil, nil, fmt.Errorf("unexpected func decl: %#v", d) + } + + recv, err := loadType(d.Recv.List[0].Type, interfaces, consts) + if err != nil { + return nil, nil, err + } + + structName, ok := recv.(NodeStructType) + if !ok { + return nil, nil, fmt.Errorf("unexpected receiver type: %#v", recv) + } + + funcName := d.Name.Name + cutName, found := strings.CutPrefix(funcName, "is") + if !found { + return nil, nil, fmt.Errorf("unexpected func name: %s", funcName) + } + + interfaceName := NodeInterfaceType(cutName) + interfaceDef, ok := interfaces[interfaceName] + if !ok { + return nil, nil, fmt.Errorf("unknown interface: %s", interfaceName) + } + + interfaceDef.Implemented = append(interfaceDef.Implemented, structName) + + if _, ok := structs[structName]; !ok { + structs[structName] = &NodeStructDef{ + Name: string(structName), + } + } + + structDef := structs[structName] + structDef.Implements = append(structDef.Implements, interfaceName) + } + } + + return structs, interfaces, nil +} + +func loadType(t ast.Expr, interfaces map[NodeInterfaceType]*NodeInterfaceDef, consts map[ConstType]*ConstDef) (Type, error) { + switch t := t.(type) { + case *ast.Ident: + switch t.Name { + case "bool": + return BoolType, nil + case "byte": + return ByteType, nil + case "int": + return IntType, nil + case "string": + return StringType, nil + default: + if _, ok := consts[ConstType(t.Name)]; ok { + return ConstType(t.Name), nil + } + if _, ok := interfaces[NodeInterfaceType(t.Name)]; ok { + return NodeInterfaceType(t.Name), nil + } + return NodeStructType(t.Name), nil + } + case *ast.SelectorExpr: + if x, ok := t.X.(*ast.Ident); !(ok && x.Name == "token") { + return nil, fmt.Errorf("unexpected selector expr: %#v", t) + } + switch t.Sel.Name { + case "Pos": + return TokenPosType, nil + case "Token": + return TokenTokenType, nil + default: + return nil, fmt.Errorf("unexpected selector name: %#v", t) + } + case *ast.StarExpr: + ty, err := loadType(t.X, interfaces, consts) + if err != nil { + return nil, err + } + + return PointerType{Type: ty}, nil + case *ast.ArrayType: + if t.Len != nil { + return nil, fmt.Errorf("unexpected array type: %#v", t) + } + + ty, err := loadType(t.Elt, interfaces, consts) + if err != nil { + return nil, err + } + + return SliceType{Type: ty}, nil + default: + return nil, fmt.Errorf("unexpected type: %#v", t) + } +}