diff --git a/internal/imports/formatter.go b/internal/imports/formatter.go index a1a283e..ce90a7b 100644 --- a/internal/imports/formatter.go +++ b/internal/imports/formatter.go @@ -9,7 +9,6 @@ import ( "log" "os" "path/filepath" - "sort" "strings" "sync" @@ -126,35 +125,23 @@ func (ft *Formatter) formatFile(path string) error { return fmt.Errorf("os: failed to read file: [%s] %w", path, err) } - // Parse ast - pathASTFile, err := ft.wrapParseAST(path, pathBytes) - if err != nil { - return err - } - - // Parse imports - importsAST, err := ft.parseImportsAndCombine(pathASTFile) - if err != nil { - return err - } - - // Return if empty imports - if len(importsAST) == 0 { - return nil - } - ft.log("formatFile: importsAST: %+v\n", importsAST) - + // Get module name of path moduleName, err := ft.moduleName(path) if err != nil { return err } ft.log("formatFile: moduleName: %+v\n", moduleName) - groupImports, err := ft.groupImports(importsAST, moduleName) + // Parse ast + pathASTFile, err := ft.parseAST(path, pathBytes) if err != nil { return err } - ft.log("formatFile: groupImports: %+v\n", groupImports) + ft.log("formatFile: pathASTFile: %+v\n", pathASTFile) + + if err := ft.formatASTFile(pathASTFile, moduleName); err != nil { + return err + } ft.muFormattedPaths.Lock() ft.formattedPaths[path] = struct{}{} @@ -163,7 +150,7 @@ func (ft *Formatter) formatFile(path string) error { return nil } -func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, error) { +func (ft *Formatter) parseAST(path string, pathBytes []byte) (*ast.File, error) { fset := token.NewFileSet() parserMode := parser.Mode(0) @@ -185,27 +172,28 @@ func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, err // Copy from goimports-reviser // If exist multi import, combine them into one // This func will edit pathASTFile directly to combine -func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]*ast.ImportSpec, error) { - result := make(map[string]*ast.ImportSpec) - +func (ft *Formatter) formatASTFile( + pathASTFile *ast.File, + moduleName string, +) error { // Extract imports importSpecs := make([]ast.Spec, 0, len(pathASTFile.Imports)) for _, importSpec := range pathASTFile.Imports { ft.log("parseImportsAndCombine: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value) importSpecs = append(importSpecs, importSpec) + } - var importNameAndPath string - if importSpec.Name != nil { - // Handle alias import - // xyz "github.com/abc/xyz/123" - importNameAndPath = importSpec.Name.String() + " " + importSpec.Path.Value - } else { - // Handle normal import - // "github.com/abc/xyz" - importNameAndPath = importSpec.Path.Value - } + groupedImportSpecs, err := ft.groupImportSpecs(importSpecs, moduleName) + if err != nil { + return err + } - result[importNameAndPath] = importSpec + formattedImportSpecs, err := ft.formatImportSpecs( + importSpecs, + groupedImportSpecs, + ) + if err != nil { + return err } // Combine multi import decl @@ -236,7 +224,7 @@ func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]* // First import decl take all isExistFirstImportDecl = true - genDecl.Specs = importSpecs + genDecl.Specs = formattedImportSpecs decls = append(decls, genDecl) } ft.log("parseImportsAndCombine: decls: %+v\n", decls) @@ -245,44 +233,49 @@ func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]* pathASTFile.Decls = decls } - return result, nil + return nil } // Copy from goimports-reviser // Group imports to std, third-party, company if exist, local -func (ft *Formatter) groupImports( - importsAST map[string]*ast.ImportSpec, +func (ft *Formatter) groupImportSpecs( + importSpecs []ast.Spec, moduleName string, -) (map[string][]string, error) { - result := make(map[string][]string) - result[stdImport] = make([]string, 0, 8) - result[thirdPartyImport] = make([]string, 0, 8) +) (map[string][]*ast.ImportSpec, error) { + result := make(map[string][]*ast.ImportSpec) + result[stdImport] = make([]*ast.ImportSpec, 0, 8) + result[thirdPartyImport] = make([]*ast.ImportSpec, 0, 8) if ft.companyPrefix != "" { - result[companyImport] = make([]string, 0, 8) + result[companyImport] = make([]*ast.ImportSpec, 0, 8) } - result[localImport] = make([]string, 0, 8) + result[localImport] = make([]*ast.ImportSpec, 0, 8) + + for _, importSpec := range importSpecs { + importSpec, ok := importSpec.(*ast.ImportSpec) + if !ok { + continue + } - for importNameAndPath, importAST := range importsAST { // "github.com/abc/xyz" -> github.com/abc/xyz - importPath := strings.Trim(importAST.Path.Value, "\"") + importPath := strings.Trim(importSpec.Path.Value, "\"") if _, ok := ft.stdPackages[importPath]; ok { - result[stdImport] = append(result[stdImport], importNameAndPath) + result[stdImport] = append(result[stdImport], importSpec) continue } if strings.HasPrefix(importPath, moduleName) { - result[localImport] = append(result[localImport], importNameAndPath) + result[localImport] = append(result[localImport], importSpec) continue } if ft.companyPrefix != "" && strings.HasPrefix(importPath, ft.companyPrefix) { - result[companyImport] = append(result[companyImport], importNameAndPath) + result[companyImport] = append(result[companyImport], importSpec) continue } - result[thirdPartyImport] = append(result[thirdPartyImport], importNameAndPath) + result[thirdPartyImport] = append(result[thirdPartyImport], importSpec) } ft.log("groupImports: std: %+v\n", result[stdImport]) @@ -292,13 +285,67 @@ func (ft *Formatter) groupImports( } ft.log("groupImports: local: %+v\n", result[localImport]) - // TODO: not sure if this match gofumpt output, but at lease it is sorted - sort.Strings(result[stdImport]) - sort.Strings(result[thirdPartyImport]) - if ft.companyPrefix != "" { - sort.Strings(result[companyImport]) + return result, nil +} + +func (ft *Formatter) formatImportSpecs( + importSpecs []ast.Spec, + groupedImportSpecs map[string][]*ast.ImportSpec, +) ([]ast.Spec, error) { + result := make([]ast.Spec, 0, len(importSpecs)) + + if stdImportSpecs, ok := groupedImportSpecs[stdImport]; ok && len(stdImportSpecs) != 0 { + for _, importSpec := range stdImportSpecs { + result = append(result, importSpec) + } + } + + if thirdPartImportSpecs, ok := groupedImportSpecs[thirdPartyImport]; ok && len(thirdPartImportSpecs) != 0 { + if len(result) != 0 { + result = append(result, &ast.ImportSpec{ + Path: &ast.BasicLit{ + Value: "", + Kind: token.STRING, + }, + }) + } + + for _, importSpec := range thirdPartImportSpecs { + result = append(result, importSpec) + } + } + + if ft.companyPrefix != "" { + if companyImportSpecs, ok := groupedImportSpecs[companyImport]; ok && len(companyImportSpecs) != 0 { + if len(result) != 0 { + result = append(result, &ast.ImportSpec{ + Path: &ast.BasicLit{ + Value: "", + Kind: token.STRING, + }, + }) + } + + for _, importSpec := range companyImportSpecs { + result = append(result, importSpec) + } + } + } + + if localImportSpecs, ok := groupedImportSpecs[localImport]; ok && len(localImportSpecs) != 0 { + if len(result) != 0 { + result = append(result, &ast.ImportSpec{ + Path: &ast.BasicLit{ + Value: "", + Kind: token.STRING, + }, + }) + } + + for _, importSpec := range localImportSpecs { + result = append(result, importSpec) + } } - sort.Strings(result[localImport]) return result, nil }