-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathast.go
145 lines (124 loc) · 3.45 KB
/
ast.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package runtask
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"strings"
"unicode"
)
// overridePackageDecl returns code with its package clause replaced by
// "package <name>".
//
// The new clause is written first; every input line is then copied through
// except those whose whitespace-trimmed text starts with "package ". The match
// is purely textual, so a line inside a comment or raw string literal that
// starts with "package " is dropped as well. Every copied line is terminated
// with "\n", including the last one.
func overridePackageDecl(code, name string) string {
	scanner := bufio.NewScanner(strings.NewReader(code))
	// By default bufio.Scanner gives up on any line longer than
	// bufio.MaxScanTokenSize (64 KiB) and stops scanning, which would silently
	// truncate the source. No line can be longer than the whole input, so a
	// limit of len(code)+1 always suffices.
	scanner.Buffer(make([]byte, 0, 4096), len(code)+1)

	var builder strings.Builder
	builder.Grow(len("package ") + len(name) + len(code) + 2)
	builder.WriteString("package ")
	builder.WriteString(name)
	builder.WriteByte('\n')

	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(strings.TrimSpace(line), "package ") {
			// Drop the original package clause; the replacement is already written.
			continue
		}
		builder.WriteString(line)
		builder.WriteByte('\n')
	}
	return builder.String()
}
// parseSource parses src as a Go source file after forcing its package clause
// to name via overridePackageDecl. Comments are retained in the returned AST,
// and the file is registered in fset under an empty filename.
func parseSource(fset *token.FileSet, src, name string) (*ast.File, error) {
	const mode = parser.ParseComments
	rewritten := overridePackageDecl(src, name)
	return parser.ParseFile(fset, "", rewritten, mode)
}
// astToString renders a as formatted Go source text.
//
// The tree is first printed with go/printer and the result is then run through
// format.Source. Failures are reported on standard output rather than
// returned: a printer error yields "", while a formatting error prints and
// returns the unformatted printer output.
func astToString(fset *token.FileSet, a *ast.File) string {
	raw := new(bytes.Buffer)
	err := printer.Fprint(raw, fset, a)
	if err != nil {
		fmt.Printf("printer.Fprint error: %v\n", err)
		return ""
	}

	// Reformat the printed source.
	pretty, err := format.Source(raw.Bytes())
	if err == nil {
		return string(pretty)
	}

	// Formatting failed: show the error and fall back to the raw printer output.
	fmt.Printf("format.Source error: %v\n", err)
	unformatted := raw.String()
	fmt.Printf("%s\n", unformatted)
	return unformatted
}
// mergeASTs combines the declarations of files into a single new file in
// package "main".
//
//   - Import specs from every file are collected into one import declaration,
//     which is always the first declaration of the result. A spec repeating an
//     already-seen (name, path) pair is skipped, because importing the same
//     package twice under the same name in one file does not compile.
//   - All other non-function declarations are appended in the order they are
//     encountered.
//   - Functions and methods are keyed by name (methods as "Recv__Name"). When
//     a key occurs more than once the last declaration wins, so later files
//     override earlier ones. Functions are emitted after the other
//     declarations, in the order their keys were first seen, which keeps the
//     output deterministic.
//
// Nil entries in files are ignored.
func mergeASTs(files ...*ast.File) *ast.File {
	importDecl := &ast.GenDecl{
		Tok:   token.IMPORT,
		Specs: []ast.Spec{},
	}
	merged := &ast.File{
		Name:  ast.NewIdent("main"),
		Decls: []ast.Decl{importDecl},
	}

	seenImports := make(map[string]bool)
	funcs := make(map[string]*ast.FuncDecl)
	// funcOrder records keys in first-seen order; ranging over the map directly
	// would emit the functions in a different order on every run.
	var funcOrder []string

	for _, file := range files {
		if file == nil {
			continue
		}
		for _, decl := range file.Decls {
			switch d := decl.(type) {
			case *ast.GenDecl:
				if d.Tok != token.IMPORT {
					merged.Decls = append(merged.Decls, decl)
					continue
				}
				for _, spec := range d.Specs {
					imp, ok := spec.(*ast.ImportSpec)
					if !ok {
						continue
					}
					key := mergeImportKey(imp)
					if seenImports[key] {
						continue
					}
					seenImports[key] = true
					importDecl.Specs = append(importDecl.Specs, imp)
				}
			case *ast.FuncDecl:
				key := mergeFuncKey(d)
				if _, dup := funcs[key]; !dup {
					funcOrder = append(funcOrder, key)
				}
				funcs[key] = d // the last definition overrides earlier ones
			default:
				merged.Decls = append(merged.Decls, decl)
			}
		}
	}

	for _, key := range funcOrder {
		merged.Decls = append(merged.Decls, funcs[key])
	}
	return merged
}

// mergeImportKey identifies an import spec by its optional local name and its
// path literal, so identical imports can be detected across files.
func mergeImportKey(spec *ast.ImportSpec) string {
	var name, path string
	if spec.Name != nil {
		name = spec.Name.Name
	}
	if spec.Path != nil {
		path = spec.Path.Value
	}
	return name + " " + path
}

// mergeFuncKey returns the override key for fn: the bare function name, or
// "Recv__Name" for a method whose receiver type is an identifier or a pointer
// to one. Other receiver forms fall back to the bare name.
func mergeFuncKey(fn *ast.FuncDecl) string {
	name := fn.Name.Name
	if fn.Recv == nil || len(fn.Recv.List) == 0 {
		return name
	}
	recvType := fn.Recv.List[0].Type
	if star, ok := recvType.(*ast.StarExpr); ok {
		recvType = star.X
	}
	if ident, ok := recvType.(*ast.Ident); ok {
		return ident.Name + "__" + name
	}
	return name
}
// extractTasks collects the exported top-level functions of file as tasks.
//
// Methods and unexported functions are skipped. The task name is the
// lower-cased function name and keys all three returned maps:
//
//   - the first maps the task name to the original function name,
//   - the second maps it to the function's trimmed doc comment ("" if none),
//   - the third maps it to the function's parameter names in declaration
//     order (unnamed parameters contribute nothing; nil if there are none).
//
// Functions whose names differ only by case share a task name; the later
// declaration overwrites the earlier one.
func extractTasks(file *ast.File) (map[string]string, map[string]string, map[string][]string) {
	functions := make(map[string]string)
	comments := make(map[string]string)
	argNames := make(map[string][]string)

	for _, decl := range file.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || fn.Recv != nil {
			continue
		}
		funcName := fn.Name.Name
		// A name is exported when its first rune is upper case. IndexFunc
		// decodes runes, so multi-byte first letters are classified correctly
		// (inspecting only the first byte is not enough), and an empty name
		// yields -1 instead of panicking.
		if strings.IndexFunc(funcName, unicode.IsUpper) != 0 {
			continue
		}

		taskName := strings.ToLower(funcName)
		functions[taskName] = funcName

		// Doc comment, if any.
		var comment string
		if fn.Doc != nil {
			comment = strings.TrimSpace(fn.Doc.Text())
		}
		comments[taskName] = comment

		// Parameter names; a field such as "a, b int" contributes both names.
		var args []string
		if fn.Type.Params != nil {
			for _, param := range fn.Type.Params.List {
				for _, ident := range param.Names {
					args = append(args, ident.Name)
				}
			}
		}
		argNames[taskName] = args
	}
	return functions, comments, argNames
}