feat: rewrite all logic to single loop ast.Decl (wip)

main
sudo pacman -Syu 2022-11-26 23:26:58 +07:00
parent 1889ddb517
commit ec9ec20189
No known key found for this signature in database
GPG Key ID: D6CB5C6C567C47B0
1 changed files with 105 additions and 58 deletions

View File

@ -9,7 +9,6 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"strings" "strings"
"sync" "sync"
@ -126,35 +125,23 @@ func (ft *Formatter) formatFile(path string) error {
return fmt.Errorf("os: failed to read file: [%s] %w", path, err) return fmt.Errorf("os: failed to read file: [%s] %w", path, err)
} }
// Parse ast // Get module name of path
pathASTFile, err := ft.wrapParseAST(path, pathBytes)
if err != nil {
return err
}
// Parse imports
importsAST, err := ft.parseImportsAndCombine(pathASTFile)
if err != nil {
return err
}
// Return if empty imports
if len(importsAST) == 0 {
return nil
}
ft.log("formatFile: importsAST: %+v\n", importsAST)
moduleName, err := ft.moduleName(path) moduleName, err := ft.moduleName(path)
if err != nil { if err != nil {
return err return err
} }
ft.log("formatFile: moduleName: %+v\n", moduleName) ft.log("formatFile: moduleName: %+v\n", moduleName)
groupImports, err := ft.groupImports(importsAST, moduleName) // Parse ast
pathASTFile, err := ft.parseAST(path, pathBytes)
if err != nil { if err != nil {
return err return err
} }
ft.log("formatFile: groupImports: %+v\n", groupImports) ft.log("formatFile: pathASTFile: %+v\n", pathASTFile)
if err := ft.formatASTFile(pathASTFile, moduleName); err != nil {
return err
}
ft.muFormattedPaths.Lock() ft.muFormattedPaths.Lock()
ft.formattedPaths[path] = struct{}{} ft.formattedPaths[path] = struct{}{}
@ -163,7 +150,7 @@ func (ft *Formatter) formatFile(path string) error {
return nil return nil
} }
func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, error) { func (ft *Formatter) parseAST(path string, pathBytes []byte) (*ast.File, error) {
fset := token.NewFileSet() fset := token.NewFileSet()
parserMode := parser.Mode(0) parserMode := parser.Mode(0)
@ -185,27 +172,28 @@ func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, err
// Copy from goimports-reviser // Copy from goimports-reviser
// If exist multi import, combine them into one // If exist multi import, combine them into one
// This func will edit pathASTFile directly to combine // This func will edit pathASTFile directly to combine
func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]*ast.ImportSpec, error) { func (ft *Formatter) formatASTFile(
result := make(map[string]*ast.ImportSpec) pathASTFile *ast.File,
moduleName string,
) error {
// Extract imports // Extract imports
importSpecs := make([]ast.Spec, 0, len(pathASTFile.Imports)) importSpecs := make([]ast.Spec, 0, len(pathASTFile.Imports))
for _, importSpec := range pathASTFile.Imports { for _, importSpec := range pathASTFile.Imports {
ft.log("parseImportsAndCombine: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value) ft.log("parseImportsAndCombine: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value)
importSpecs = append(importSpecs, importSpec) importSpecs = append(importSpecs, importSpec)
var importNameAndPath string
if importSpec.Name != nil {
// Handle alias import
// xyz "github.com/abc/xyz/123"
importNameAndPath = importSpec.Name.String() + " " + importSpec.Path.Value
} else {
// Handle normal import
// "github.com/abc/xyz"
importNameAndPath = importSpec.Path.Value
} }
result[importNameAndPath] = importSpec groupedImportSpecs, err := ft.groupImportSpecs(importSpecs, moduleName)
if err != nil {
return err
}
formattedImportSpecs, err := ft.formatImportSpecs(
importSpecs,
groupedImportSpecs,
)
if err != nil {
return err
} }
// Combine multi import decl // Combine multi import decl
@ -236,7 +224,7 @@ func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]*
// First import decl take all // First import decl take all
isExistFirstImportDecl = true isExistFirstImportDecl = true
genDecl.Specs = importSpecs genDecl.Specs = formattedImportSpecs
decls = append(decls, genDecl) decls = append(decls, genDecl)
} }
ft.log("parseImportsAndCombine: decls: %+v\n", decls) ft.log("parseImportsAndCombine: decls: %+v\n", decls)
@ -245,44 +233,49 @@ func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]*
pathASTFile.Decls = decls pathASTFile.Decls = decls
} }
return result, nil return nil
} }
// Copy from goimports-reviser // Copy from goimports-reviser
// Group imports to std, third-party, company if exist, local // Group imports to std, third-party, company if exist, local
func (ft *Formatter) groupImports( func (ft *Formatter) groupImportSpecs(
importsAST map[string]*ast.ImportSpec, importSpecs []ast.Spec,
moduleName string, moduleName string,
) (map[string][]string, error) { ) (map[string][]*ast.ImportSpec, error) {
result := make(map[string][]string) result := make(map[string][]*ast.ImportSpec)
result[stdImport] = make([]string, 0, 8) result[stdImport] = make([]*ast.ImportSpec, 0, 8)
result[thirdPartyImport] = make([]string, 0, 8) result[thirdPartyImport] = make([]*ast.ImportSpec, 0, 8)
if ft.companyPrefix != "" { if ft.companyPrefix != "" {
result[companyImport] = make([]string, 0, 8) result[companyImport] = make([]*ast.ImportSpec, 0, 8)
}
result[localImport] = make([]*ast.ImportSpec, 0, 8)
for _, importSpec := range importSpecs {
importSpec, ok := importSpec.(*ast.ImportSpec)
if !ok {
continue
} }
result[localImport] = make([]string, 0, 8)
for importNameAndPath, importAST := range importsAST {
// "github.com/abc/xyz" -> github.com/abc/xyz // "github.com/abc/xyz" -> github.com/abc/xyz
importPath := strings.Trim(importAST.Path.Value, "\"") importPath := strings.Trim(importSpec.Path.Value, "\"")
if _, ok := ft.stdPackages[importPath]; ok { if _, ok := ft.stdPackages[importPath]; ok {
result[stdImport] = append(result[stdImport], importNameAndPath) result[stdImport] = append(result[stdImport], importSpec)
continue continue
} }
if strings.HasPrefix(importPath, moduleName) { if strings.HasPrefix(importPath, moduleName) {
result[localImport] = append(result[localImport], importNameAndPath) result[localImport] = append(result[localImport], importSpec)
continue continue
} }
if ft.companyPrefix != "" && if ft.companyPrefix != "" &&
strings.HasPrefix(importPath, ft.companyPrefix) { strings.HasPrefix(importPath, ft.companyPrefix) {
result[companyImport] = append(result[companyImport], importNameAndPath) result[companyImport] = append(result[companyImport], importSpec)
continue continue
} }
result[thirdPartyImport] = append(result[thirdPartyImport], importNameAndPath) result[thirdPartyImport] = append(result[thirdPartyImport], importSpec)
} }
ft.log("groupImports: std: %+v\n", result[stdImport]) ft.log("groupImports: std: %+v\n", result[stdImport])
@ -292,13 +285,67 @@ func (ft *Formatter) groupImports(
} }
ft.log("groupImports: local: %+v\n", result[localImport]) ft.log("groupImports: local: %+v\n", result[localImport])
// TODO: not sure if this match gofumpt output, but at lease it is sorted return result, nil
sort.Strings(result[stdImport]) }
sort.Strings(result[thirdPartyImport])
if ft.companyPrefix != "" { func (ft *Formatter) formatImportSpecs(
sort.Strings(result[companyImport]) importSpecs []ast.Spec,
groupedImportSpecs map[string][]*ast.ImportSpec,
) ([]ast.Spec, error) {
result := make([]ast.Spec, 0, len(importSpecs))
if stdImportSpecs, ok := groupedImportSpecs[stdImport]; ok && len(stdImportSpecs) != 0 {
for _, importSpec := range stdImportSpecs {
result = append(result, importSpec)
}
}
if thirdPartImportSpecs, ok := groupedImportSpecs[thirdPartyImport]; ok && len(thirdPartImportSpecs) != 0 {
if len(result) != 0 {
result = append(result, &ast.ImportSpec{
Path: &ast.BasicLit{
Value: "",
Kind: token.STRING,
},
})
}
for _, importSpec := range thirdPartImportSpecs {
result = append(result, importSpec)
}
}
if ft.companyPrefix != "" {
if companyImportSpecs, ok := groupedImportSpecs[companyImport]; ok && len(companyImportSpecs) != 0 {
if len(result) != 0 {
result = append(result, &ast.ImportSpec{
Path: &ast.BasicLit{
Value: "",
Kind: token.STRING,
},
})
}
for _, importSpec := range companyImportSpecs {
result = append(result, importSpec)
}
}
}
if localImportSpecs, ok := groupedImportSpecs[localImport]; ok && len(localImportSpecs) != 0 {
if len(result) != 0 {
result = append(result, &ast.ImportSpec{
Path: &ast.BasicLit{
Value: "",
Kind: token.STRING,
},
})
}
for _, importSpec := range localImportSpecs {
result = append(result, importSpec)
}
} }
sort.Strings(result[localImport])
return result, nil return result, nil
} }