diff --git a/ast/ast.go b/ast/ast.go index 81b00ef7..b720b41f 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -32,9 +32,20 @@ // (PosVar, NodeVar, NodeSliceVar, and BoolVar are derived by its struct definition.) package ast -// This file must contain only AST definitions. -// We use the following go:generate directive for generating pos.go. Thus, all AST definitions must have pos and end lines. -//go:generate go run ../tools/gen-ast-pos/main.go -infile ast.go -outfile pos.go +// NOTE: ast.go and ast_*.go are used for automatic generation, so these files are conventional. + +// NOTE: This file defines AST nodes and they are used for automatic generation, +// so this file is conventional. +// +// Conventions: +// +// - Each node interface (except for Node) should have isXXX method (XXX must be a name of the interface itself). +// - `isXXX` methods should be defined after the interface definition +// and the receiver should be the non-pointer node struct type. +// - Each node struct should have pos and end comments. +// - Each node struct should have template lines in its doc comment. + +//go:generate go run ../tools/gen-ast-pos/main.go -astfile ast.go -constfile ast_const.go -outfile pos.go import ( "github.com/cloudspannerecosystem/memefish/token" diff --git a/ast/const.go b/ast/ast_const.go similarity index 93% rename from ast/const.go rename to ast/ast_const.go index d72748d1..eb629325 100644 --- a/ast/const.go +++ b/ast/ast_const.go @@ -1,5 +1,13 @@ package ast +// NOTE: This file defines constants used in AST nodes and they are used for automatic generation, +// so this file is conventional. +// +// Convention: +// +// - Each const types should be defined as a string type. +// - Each value is defined as a string literal. + // AllOrDistinct represents ALL or DISTINCT in SELECT or set operations, etc. // If it is optional, it may be an empty string, so handle it according to the context. type AllOrDistinct string diff --git a/ast/ast_test.go b/ast/ast_test.go deleted file mode 100644 index cc8ee6d3..00000000 --- a/ast/ast_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package ast - -import ( - "testing" -) - -func TestStatement(t *testing.T) { - Statement(&QueryStatement{}).isStatement() - Statement(&CreateDatabase{}).isStatement() - Statement(&AlterDatabase{}).isStatement() - Statement(&CreateTable{}).isStatement() - Statement(&AlterTable{}).isStatement() - Statement(&DropTable{}).isStatement() - Statement(&CreateIndex{}).isStatement() - Statement(&AlterIndex{}).isStatement() - Statement(&DropIndex{}).isStatement() - Statement(&CreateView{}).isStatement() - Statement(&DropView{}).isStatement() - Statement(&CreateChangeStream{}).isStatement() - Statement(&AlterChangeStream{}).isStatement() - Statement(&DropChangeStream{}).isStatement() - Statement(&CreateRole{}).isStatement() - Statement(&DropRole{}).isStatement() - Statement(&Grant{}).isStatement() - Statement(&Revoke{}).isStatement() - Statement(&CreateSequence{}).isStatement() - Statement(&AlterSequence{}).isStatement() - Statement(&DropSequence{}).isStatement() - Statement(&CreateVectorIndex{}).isStatement() - Statement(&DropVectorIndex{}).isStatement() - Statement(&Insert{}).isStatement() - Statement(&Delete{}).isStatement() - Statement(&Update{}).isStatement() -} - -func TestQueryExpr(t *testing.T) { - QueryExpr(&Select{}).isQueryExpr() - QueryExpr(&SubQuery{}).isQueryExpr() - QueryExpr(&CompoundQuery{}).isQueryExpr() -} - -func TestSelectItem(t *testing.T) { - SelectItem(&Star{}).isSelectItem() - SelectItem(&DotStar{}).isSelectItem() - SelectItem(&Alias{}).isSelectItem() - SelectItem(&ExprSelectItem{}).isSelectItem() -} - -func TestSelectAs(t *testing.T) { - SelectAs(&AsStruct{}).isSelectAs() - SelectAs(&AsValue{}).isSelectAs() - SelectAs(&AsTypeName{}).isSelectAs() -} - -func TestTableExpr(t *testing.T) { - TableExpr(&Unnest{}).isTableExpr() - TableExpr(&TableName{}).isTableExpr() - TableExpr(&SubQueryTableExpr{}).isTableExpr() - TableExpr(&ParenTableExpr{}).isTableExpr() - TableExpr(&Join{}).isTableExpr() -} - -func TestJoinCondition(t *testing.T) { - JoinCondition(&On{}).isJoinCondition() - JoinCondition(&Using{}).isJoinCondition() -} - -func TestExpr(t *testing.T) { - Expr(&BinaryExpr{}).isExpr() - Expr(&UnaryExpr{}).isExpr() - Expr(&InExpr{}).isExpr() - Expr(&IsNullExpr{}).isExpr() - Expr(&IsBoolExpr{}).isExpr() - Expr(&BetweenExpr{}).isExpr() - Expr(&SelectorExpr{}).isExpr() - Expr(&IndexExpr{}).isExpr() - Expr(&CallExpr{}).isExpr() - Expr(&CountStarExpr{}).isExpr() - Expr(&CastExpr{}).isExpr() - Expr(&ExtractExpr{}).isExpr() - Expr(&CaseExpr{}).isExpr() - Expr(&ParenExpr{}).isExpr() - Expr(&ScalarSubQuery{}).isExpr() - Expr(&ArraySubQuery{}).isExpr() - Expr(&ExistsSubQuery{}).isExpr() - Expr(&Param{}).isExpr() - Expr(&Ident{}).isExpr() - Expr(&Path{}).isExpr() - Expr(&ArrayLiteral{}).isExpr() - Expr(&TypedStructLiteral{}).isExpr() - Expr(&NullLiteral{}).isExpr() - Expr(&BoolLiteral{}).isExpr() - Expr(&IntLiteral{}).isExpr() - Expr(&FloatLiteral{}).isExpr() - Expr(&StringLiteral{}).isExpr() - Expr(&BytesLiteral{}).isExpr() - Expr(&DateLiteral{}).isExpr() - Expr(&TimestampLiteral{}).isExpr() - Expr(&NumericLiteral{}).isExpr() -} - -func TestArg(t *testing.T) { - Arg(&IntervalArg{}).isArg() - Arg(&ExprArg{}).isArg() - Arg(&SequenceArg{}).isArg() -} - -func TestInCondition(t *testing.T) { - InCondition(&UnnestInCondition{}).isInCondition() - InCondition(&SubQueryInCondition{}).isInCondition() - InCondition(&ValuesInCondition{}).isInCondition() -} - -func TestType(t *testing.T) { - Type(&SimpleType{}).isType() - Type(&ArrayType{}).isType() - Type(&StructType{}).isType() -} - -func TestIntValue(t *testing.T) { - IntValue(&Param{}).isIntValue() - IntValue(&IntLiteral{}).isIntValue() - IntValue(&CastIntValue{}).isIntValue() -} - -func TestNumValue(t *testing.T) { - NumValue(&Param{}).isNumValue() - NumValue(&IntLiteral{}).isNumValue() - NumValue(&FloatLiteral{}).isNumValue() - NumValue(&CastNumValue{}).isNumValue() -} - -func TestStringValue(t *testing.T) { - StringValue(&Param{}).isStringValue() - StringValue(&StringLiteral{}).isStringValue() -} - -func TestDDL(t *testing.T) { - DDL(&CreateDatabase{}).isDDL() - DDL(&AlterDatabase{}).isDDL() - DDL(&CreateTable{}).isDDL() - DDL(&AlterTable{}).isDDL() - DDL(&DropTable{}).isDDL() - DDL(&CreateIndex{}).isDDL() - DDL(&AlterIndex{}).isDDL() - DDL(&DropIndex{}).isDDL() - DDL(&CreateSearchIndex{}).isDDL() - DDL(&DropSearchIndex{}).isDDL() - DDL(&AlterSearchIndex{}).isDDL() - DDL(&CreateView{}).isDDL() - DDL(&DropView{}).isDDL() - DDL(&CreateChangeStream{}).isDDL() - DDL(&AlterChangeStream{}).isDDL() - DDL(&DropChangeStream{}).isDDL() - DDL(&CreateRole{}).isDDL() - DDL(&DropRole{}).isDDL() - DDL(&Grant{}).isDDL() - DDL(&Revoke{}).isDDL() - DDL(&CreateSequence{}).isDDL() - DDL(&AlterSequence{}).isDDL() - DDL(&DropSequence{}).isDDL() - DDL(&CreateVectorIndex{}).isDDL() - DDL(&DropVectorIndex{}).isDDL() -} - -func TestConstraint(t *testing.T) { - Constraint(&ForeignKey{}).isConstraint() - Constraint(&Check{}).isConstraint() -} - -func TestTableAlteration(t *testing.T) { - TableAlteration(&AddColumn{}).isTableAlteration() - TableAlteration(&AddTableConstraint{}).isTableAlteration() - TableAlteration(&DropColumn{}).isTableAlteration() - TableAlteration(&DropConstraint{}).isTableAlteration() - TableAlteration(&SetOnDelete{}).isTableAlteration() - TableAlteration(&AlterColumn{}).isTableAlteration() -} - -func TestPrivilege(t *testing.T) { - Privilege(&PrivilegeOnTable{}).isPrivilege() - Privilege(&SelectPrivilegeOnView{}).isPrivilege() - Privilege(&ExecutePrivilegeOnTableFunction{}).isPrivilege() - Privilege(&RolePrivilege{}).isPrivilege() -} - -func TestTablePrivilege(t *testing.T) { - TablePrivilege(&SelectPrivilege{}).isTablePrivilege() - TablePrivilege(&InsertPrivilege{}).isTablePrivilege() - TablePrivilege(&UpdatePrivilege{}).isTablePrivilege() - TablePrivilege(&DeletePrivilege{}).isTablePrivilege() -} - -func TestSchemaType(t *testing.T) { - SchemaType(&ScalarSchemaType{}).isSchemaType() - SchemaType(&SizedSchemaType{}).isSchemaType() - SchemaType(&ArraySchemaType{}).isSchemaType() -} - -func TestIndexAlteration(t *testing.T) { - IndexAlteration(&AddStoredColumn{}).isIndexAlteration() - IndexAlteration(&DropStoredColumn{}).isIndexAlteration() -} - -func TestDML(t *testing.T) { - DML(&Insert{}).isDML() - DML(&Delete{}).isDML() - DML(&Update{}).isDML() -} - -func TestInsertInput(t *testing.T) { - InsertInput(&ValuesInput{}).isInsertInput() - InsertInput(&SubQueryInput{}).isInsertInput() -} diff --git a/tools/astcatalog/main.go b/tools/astcatalog/main.go index 9b7d8c32..b8de13e7 100644 --- a/tools/astcatalog/main.go +++ b/tools/astcatalog/main.go @@ -9,13 +9,14 @@ import ( ) var ( - infile = flag.String("infile", "", "input filename") + astfile = flag.String("astfile", "ast/ast.go", "path to ast/ast.go") + constfile = flag.String("constfile", "ast/ast_const.go", "path to ast/ast_const.go") ) func main() { flag.Parse() - catalog, err := astcatalog.Load(*infile) + catalog, err := astcatalog.Load(*astfile, *constfile) if err != nil { log.Fatalf("failed to load: %v", err) } diff --git a/tools/gen-ast-pos/main.go b/tools/gen-ast-pos/main.go index ecaff8c4..7ccdd4d4 100644 --- a/tools/gen-ast-pos/main.go +++ b/tools/gen-ast-pos/main.go @@ -39,8 +39,9 @@ var ( ) var ( - infile = flag.String("infile", "", "input filename") - outfile = flag.String("outfile", "", "output filename (if it is not specified, the result is printed to stdout.)") + astfile = flag.String("astfile", "ast/ast.go", "path to ast/ast.go") + constfile = flag.String("constfile", "ast/ast_const.go", "path to ast/ast_const.go") + outfile = flag.String("outfile", "", "output filename (if it is not specified, the result is printed to stdout.)") ) func main() { @@ -51,41 +52,41 @@ func main() { flag.Parse() - catalog, err := astcatalog.Load(*infile) + catalog, err := astcatalog.Load(*astfile, *constfile) if err != nil { log.Fatal(err) } - nodes := make([]*astcatalog.NodeDef, 0, len(catalog)) - for _, node := range catalog { - nodes = append(nodes, node) + structs := make([]*astcatalog.NodeStructDef, 0, len(catalog.Structs)) + for _, structDef := range catalog.Structs { + structs = append(structs, structDef) } - sort.Slice(nodes, func(i, j int) bool { - return nodes[i].SourcePos < nodes[j].SourcePos + sort.Slice(structs, func(i, j int) bool { + return structs[i].SourcePos < structs[j].SourcePos }) var buffer bytes.Buffer buffer.WriteString(prologue) - for _, node := range nodes { - x := string(unicode.ToLower(rune(node.Name[0]))) + for _, structDef := range structs { + x := string(unicode.ToLower(rune(structDef.Name[0]))) - posExpr, err := poslang.Parse(node.Pos) + posExpr, err := poslang.Parse(structDef.Pos) if err != nil { log.Fatalf("error on parsing pos: %v", err) } - endExpr, err := poslang.Parse(node.End) + endExpr, err := poslang.Parse(structDef.End) if err != nil { log.Fatalf("error on parsing pos: %v", err) } fmt.Fprintln(&buffer) - fmt.Fprintf(&buffer, "func (%s *%s) Pos() token.Pos {\n", x, node.Name) + fmt.Fprintf(&buffer, "func (%s *%s) Pos() token.Pos {\n", x, structDef.Name) fmt.Fprintf(&buffer, "\treturn %s\n", posExpr.PosExprToGo(x)) fmt.Fprintf(&buffer, "}\n") fmt.Fprintln(&buffer) - fmt.Fprintf(&buffer, "func (%s *%s) End() token.Pos {\n", x, node.Name) + fmt.Fprintf(&buffer, "func (%s *%s) End() token.Pos {\n", x, structDef.Name) fmt.Fprintf(&buffer, "\treturn %s\n", endExpr.PosExprToGo(x)) fmt.Fprintf(&buffer, "}\n") } diff --git a/tools/util/astcatalog/catalog.go b/tools/util/astcatalog/catalog.go index 4ded6ab1..8c7ca55e 100644 --- a/tools/util/astcatalog/catalog.go +++ b/tools/util/astcatalog/catalog.go @@ -4,10 +4,15 @@ import ( "go/token" ) -type Catalog map[NodeStructType]*NodeDef +// Catalog is a catalog of AST types. +type Catalog struct { + Structs map[NodeStructType]*NodeStructDef + Interfaces map[NodeInterfaceType]*NodeInterfaceDef + Consts map[ConstType]*ConstDef +} -// NodeDef represents a node definition. -type NodeDef struct { +// NodeStructDef is a definition of node structs in ast/ast.go. +type NodeStructDef struct { SourcePos token.Pos Name string Doc string @@ -17,29 +22,68 @@ type NodeDef struct { Implements []NodeInterfaceType } +// FieldDef is a field definition of node structs in ast/ast.go. type FieldDef struct { Name string - Type FieldType + Type Type Comment string } -type FieldType interface{} +// NodeInterfaceDef is a definition of node interfaces in ast/ast.go. +type NodeInterfaceDef struct { + SourcePos token.Pos + Name string + Implemented []NodeStructType +} + +// ConstDef is a definition of const types in ast/ast_const.go +type ConstDef struct { + SourcePos token.Pos + Name string + Values []*ConstValueDef +} + +// ConstValueDef is a value definition of const types in ast/ast_const.go. +type ConstValueDef struct { + Name string + Value string +} + +// Type represents types used in Catalog. +type Type interface { + isType() +} +func (SliceType) isType() {} +func (PointerType) isType() {} +func (NodeStructType) isType() {} +func (NodeInterfaceType) isType() {} +func (PrimitiveType) isType() {} +func (ConstType) isType() {} + +// SliceType is a slice type. type SliceType struct { - Type FieldType + Type Type } +// PointerType is a pointer type. type PointerType struct { - Type FieldType + Type Type } +// NodeStructType is a type name of node structs defined in ast/ast.go. type NodeStructType string +// NodeInterfaceType is a type name of node interfaces defined in ast/ast.go. type NodeInterfaceType string -// PrimitiveType is a type name which is neither a node pointer nor a node interface. +// ConstType is a type name of const types defined in ast/ast_const.go. +type ConstType string + +// PrimitiveType is a type name which is neither a node pointer, a node interface, nor a const types. type PrimitiveType string +// PrimitiveType values. const ( BoolType PrimitiveType = "bool" IntType PrimitiveType = "int" diff --git a/tools/util/astcatalog/load.go b/tools/util/astcatalog/load.go index 8231a77d..43a56d3b 100644 --- a/tools/util/astcatalog/load.go +++ b/tools/util/astcatalog/load.go @@ -9,21 +9,107 @@ import ( "strings" ) +// Load loads a catalog from the given AST files. +func Load(astFilename, astConstFilename string) (*Catalog, error) { + fset := token.NewFileSet() + + consts, err := loadConsts(fset, astConstFilename) + if err != nil { + return nil, err + } + + structs, interfaces, err := loadStructs(fset, astFilename, consts) + if err != nil { + return nil, err + } + + catalog := &Catalog{ + Structs: structs, + Interfaces: interfaces, + Consts: consts, + } + return catalog, nil +} + +func loadConsts(fset *token.FileSet, filename string) (map[ConstType]*ConstDef, error) { + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("failed to parse '%s': %w", filename, err) + } + + consts := make(map[ConstType]*ConstDef) + + for _, decl := range f.Decls { + switch d := decl.(type) { + case *ast.GenDecl: + for _, spec := range d.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + name := ConstType(s.Name.Name) + if _, ok := consts[name]; ok { + return nil, fmt.Errorf("duplicated const: %s", name) + } + + consts[name] = &ConstDef{ + SourcePos: d.Pos(), + Name: s.Name.Name, + } + case *ast.ValueSpec: + if s.Type == nil { + return nil, fmt.Errorf("unexpected value spec: %#v", s) + } + ty, err := loadType(s.Type, nil, consts) + if err != nil { + return nil, err + } + name, ok := ty.(ConstType) + if !ok { + return nil, fmt.Errorf("unexpected type: %#v", s.Type) + } + + constDef, ok := consts[name] + if !ok { + return nil, fmt.Errorf("unknown const: %s", name) + } + + if len(s.Values) != 1 { + return nil, fmt.Errorf("unexpected values: %#v", s.Values) + } + lit, ok := s.Values[0].(*ast.BasicLit) + if !(ok && lit.Kind == token.STRING) { + return nil, fmt.Errorf("unexpected value: %#v", s.Values[0]) + } + v := strings.Trim(lit.Value, "\"") + + for _, name := range s.Names { + constDef.Values = append(constDef.Values, &ConstValueDef{ + Name: name.Name, + Value: v, + }) + } + } + } + } + } + + return consts, nil +} + +// Regular expressions to extract pos/end and template comments. var ( rePosLine = regexp.MustCompile(`(?m)^\s*pos\s*=\s*(.*)`) reEndLine = regexp.MustCompile(`(?m)^\s*end\s*=\s*(.*)`) reTmplLines = regexp.MustCompile(`(?m)((?:^\t.*\n)+)+`) ) -func Load(filename string) (Catalog, error) { - fset := token.NewFileSet() +func loadStructs(fset *token.FileSet, filename string, consts map[ConstType]*ConstDef) (map[NodeStructType]*NodeStructDef, map[NodeInterfaceType]*NodeInterfaceDef, error) { f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) if err != nil { - return nil, fmt.Errorf("failed to parse file: %w", err) + return nil, nil, fmt.Errorf("failed to parse '%s': %w", filename, err) } - catalog := make(Catalog) - interfaces := make(map[NodeInterfaceType]struct{}) + structs := make(map[NodeStructType]*NodeStructDef) + interfaces := make(map[NodeInterfaceType]*NodeInterfaceDef) commentMap := ast.NewCommentMap(fset, f, f.Comments) @@ -36,44 +122,44 @@ func Load(filename string) (Catalog, error) { switch t := s.Type.(type) { case *ast.StructType: name := NodeStructType(s.Name.Name) - if _, ok := catalog[name]; !ok { - catalog[name] = &NodeDef{ + if _, ok := structs[name]; !ok { + structs[name] = &NodeStructDef{ Name: s.Name.Name, } } - node := catalog[name] - node.Doc = d.Doc.Text() - node.SourcePos = s.Pos() + structDef := structs[name] + structDef.SourcePos = s.Pos() + structDef.Doc = d.Doc.Text() - if m := reTmplLines.FindAllStringSubmatch(node.Doc, -1); m != nil { - node.Tmpl = m[len(m)-1][0] + if m := reTmplLines.FindAllStringSubmatch(structDef.Doc, -1); m != nil { + structDef.Tmpl = m[len(m)-1][0] } else { - return nil, fmt.Errorf("no template found: %s", name) + return nil, nil, fmt.Errorf("no template found: %s", name) } comments := commentMap.Filter(t).Comments() if len(comments) == 0 { - return nil, fmt.Errorf("no pos/end comment found: %s", name) + return nil, nil, fmt.Errorf("no pos/end comment found: %s", name) } - // We assume the first comment group should a pos/end comment. + // We assume the first comment group in the struct should a pos/end comment. posComment := comments[0].Text() if m := rePosLine.FindStringSubmatch(posComment); m != nil { - node.Pos = m[1] + structDef.Pos = m[1] } else { - return nil, fmt.Errorf("no pos comment found: %s", name) + return nil, nil, fmt.Errorf("no pos comment found: %s", name) } if m := reEndLine.FindStringSubmatch(posComment); m != nil { - node.End = m[1] + structDef.End = m[1] } else { - return nil, fmt.Errorf("no end coment found: %s", name) + return nil, nil, fmt.Errorf("no end coment found: %s", name) } for _, f := range t.Fields.List { - ft, err := loadType(f.Type, interfaces) + ty, err := loadType(f.Type, interfaces, consts) if err != nil { - return nil, err + return nil, nil, err } comment := "" @@ -85,9 +171,9 @@ func Load(filename string) (Catalog, error) { } } for _, name := range f.Names { - node.Fields = append(node.Fields, &FieldDef{ + structDef.Fields = append(structDef.Fields, &FieldDef{ Name: name.Name, - Type: ft, + Type: ty, Comment: comment, }) } @@ -95,56 +181,68 @@ func Load(filename string) (Catalog, error) { case *ast.InterfaceType: name := NodeInterfaceType(s.Name.Name) if _, ok := interfaces[name]; ok { - return nil, fmt.Errorf("duplicated interface: %s", name) + return nil, nil, fmt.Errorf("duplicated interface: %s", name) + } + + // Node interface is special, so we skip it. + if name == "Node" { + continue } - interfaces[name] = struct{}{} + interfaces[name] = &NodeInterfaceDef{ + SourcePos: s.Pos(), + Name: string(name), + } default: - return nil, fmt.Errorf("unexpected spec: %#v", t) + return nil, nil, fmt.Errorf("unexpected spec: %#v", t) } } } + case *ast.FuncDecl: - if d.Recv == nil { - return nil, fmt.Errorf("unexpected func decl: %#v", d) + if d.Recv == nil || len(d.Recv.List) != 1 { + return nil, nil, fmt.Errorf("unexpected func decl: %#v", d) } - recv, err := loadType(d.Recv.List[0].Type, interfaces) + recv, err := loadType(d.Recv.List[0].Type, interfaces, consts) if err != nil { - return nil, err + return nil, nil, err } structName, ok := recv.(NodeStructType) if !ok { - return nil, fmt.Errorf("unexpected receiver type: %#v", recv) + return nil, nil, fmt.Errorf("unexpected receiver type: %#v", recv) } funcName := d.Name.Name cutName, found := strings.CutPrefix(funcName, "is") if !found { - return nil, fmt.Errorf("unexpected func name: %s", funcName) + return nil, nil, fmt.Errorf("unexpected func name: %s", funcName) } interfaceName := NodeInterfaceType(cutName) - if _, ok := interfaces[interfaceName]; !ok { - return nil, fmt.Errorf("unknown interface: %s", interfaceName) + interfaceDef, ok := interfaces[interfaceName] + if !ok { + return nil, nil, fmt.Errorf("unknown interface: %s", interfaceName) } - if _, ok := catalog[structName]; !ok { - catalog[structName] = &NodeDef{ + interfaceDef.Implemented = append(interfaceDef.Implemented, structName) + + if _, ok := structs[structName]; !ok { + structs[structName] = &NodeStructDef{ Name: string(structName), } } - node := catalog[structName] - node.Implements = append(node.Implements, interfaceName) + structDef := structs[structName] + structDef.Implements = append(structDef.Implements, interfaceName) } } - return catalog, nil + return structs, interfaces, nil } -func loadType(t ast.Expr, interfaces map[NodeInterfaceType]struct{}) (FieldType, error) { +func loadType(t ast.Expr, interfaces map[NodeInterfaceType]*NodeInterfaceDef, consts map[ConstType]*ConstDef) (Type, error) { switch t := t.(type) { case *ast.Ident: switch t.Name { @@ -155,6 +253,9 @@ func loadType(t ast.Expr, interfaces map[NodeInterfaceType]struct{}) (FieldType, case "string": return StringType, nil default: + if _, ok := consts[ConstType(t.Name)]; ok { + return ConstType(t.Name), nil + } if _, ok := interfaces[NodeInterfaceType(t.Name)]; ok { return NodeInterfaceType(t.Name), nil } @@ -173,23 +274,23 @@ func loadType(t ast.Expr, interfaces map[NodeInterfaceType]struct{}) (FieldType, return nil, fmt.Errorf("unexpected selector name: %#v", t) } case *ast.StarExpr: - ft, err := loadType(t.X, interfaces) + ty, err := loadType(t.X, interfaces, consts) if err != nil { return nil, err } - return PointerType{Type: ft}, nil + return PointerType{Type: ty}, nil case *ast.ArrayType: if t.Len != nil { return nil, fmt.Errorf("unexpected array type: %#v", t) } - ft, err := loadType(t.Elt, interfaces) + ty, err := loadType(t.Elt, interfaces, consts) if err != nil { return nil, err } - return SliceType{Type: ft}, nil + return SliceType{Type: ty}, nil default: return nil, fmt.Errorf("unexpected type: %#v", t) }