diff --git a/internal/imports/formatter.go b/internal/imports/formatter.go index aca2cb8..a1a283e 100644 --- a/internal/imports/formatter.go +++ b/internal/imports/formatter.go @@ -65,7 +65,7 @@ func NewFormmater(opts ...FormatterOptionFn) (*Formatter, error) { for _, stdPackage := range stdPackages { ft.stdPackages[stdPackage.PkgPath] = struct{}{} } - ft.log("stdPackages: %+v\n", ft.stdPackages) + ft.log("NewFormmater: stdPackages: %+v\n", ft.stdPackages) ft.moduleNames = make(map[string]string) ft.formattedPaths = make(map[string]struct{}) @@ -133,7 +133,7 @@ func (ft *Formatter) formatFile(path string) error { } // Parse imports - importsAST, err := ft.parseImports(pathASTFile) + importsAST, err := ft.parseImportsAndCombine(pathASTFile) if err != nil { return err } @@ -142,19 +142,19 @@ func (ft *Formatter) formatFile(path string) error { if len(importsAST) == 0 { return nil } - ft.log("importsAST: %+v\n", importsAST) + ft.log("formatFile: importsAST: %+v\n", importsAST) moduleName, err := ft.moduleName(path) if err != nil { return err } - ft.log("moduleName: %+v\n", moduleName) + ft.log("formatFile: moduleName: %+v\n", moduleName) groupImports, err := ft.groupImports(importsAST, moduleName) if err != nil { return err } - ft.log("groupImports: %+v\n", groupImports) + ft.log("formatFile: groupImports: %+v\n", groupImports) ft.muFormattedPaths.Lock() ft.formattedPaths[path] = struct{}{} @@ -183,38 +183,66 @@ func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, err } // Copy from goimports-reviser -func (ft *Formatter) parseImports(pathASTFile *ast.File) (map[string]*ast.ImportSpec, error) { +// If exist multi import, combine them into one +// This func will edit pathASTFile directly to combine +func (ft *Formatter) parseImportsAndCombine(pathASTFile *ast.File) (map[string]*ast.ImportSpec, error) { result := make(map[string]*ast.ImportSpec) + // Extract imports + importSpecs := make([]ast.Spec, 0, len(pathASTFile.Imports)) + for _, importSpec := range pathASTFile.Imports { + ft.log("parseImportsAndCombine: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value) + importSpecs = append(importSpecs, importSpec) + + var importNameAndPath string + if importSpec.Name != nil { + // Handle alias import + // xyz "github.com/abc/xyz/123" + importNameAndPath = importSpec.Name.String() + " " + importSpec.Path.Value + } else { + // Handle normal import + // "github.com/abc/xyz" + importNameAndPath = importSpec.Path.Value + } + + result[importNameAndPath] = importSpec + } + + // Combine multi import decl + isMultiImportDecl := false + isExistFirstImportDecl := false + decls := make([]ast.Decl, 0, len(pathASTFile.Decls)) for _, decl := range pathASTFile.Decls { genDecl, ok := decl.(*ast.GenDecl) if !ok { + decls = append(decls, decl) continue } if genDecl.Tok != token.IMPORT { + decls = append(decls, decl) continue } - for _, spec := range genDecl.Specs { - importSpec, ok := spec.(*ast.ImportSpec) - if !ok { - continue + if isExistFirstImportDecl { + isMultiImportDecl = true + // TODO: explain this + storedGenDecl := decls[len(decls)-1].(*ast.GenDecl) + if storedGenDecl.Tok == token.IMPORT { + storedGenDecl.Rparen = genDecl.End() } - - var importNameAndPath string - if importSpec.Name != nil { - // Handle alias import - // xyz "github.com/abc/xyz/123" - importNameAndPath = importSpec.Name.String() + " " + importSpec.Path.Value - } else { - // Handle normal import - // "github.com/abc/xyz" - importNameAndPath = importSpec.Path.Value - } - - result[importNameAndPath] = importSpec + continue } + + // First import decl take all + isExistFirstImportDecl = true + genDecl.Specs = importSpecs + decls = append(decls, genDecl) + } + ft.log("parseImportsAndCombine: decls: %+v\n", decls) + + if isMultiImportDecl { + pathASTFile.Decls = decls } return result, nil @@ -257,12 +285,12 @@ func (ft *Formatter) groupImports( result[thirdPartyImport] = append(result[thirdPartyImport], importNameAndPath) } - ft.log("std %+v\n", result[stdImport]) - ft.log("third-party %+v\n", result[thirdPartyImport]) + ft.log("groupImports: std: %+v\n", result[stdImport]) + ft.log("groupImports: third-party: %+v\n", result[thirdPartyImport]) if ft.companyPrefix != "" { - ft.log("company %+v\n", result[companyImport]) + ft.log("groupImports: company: %+v\n", result[companyImport]) } - ft.log("local %+v\n", result[localImport]) + ft.log("groupImports: local: %+v\n", result[localImport]) // TODO: not sure if this match gofumpt output, but at lease it is sorted sort.Strings(result[stdImport]) @@ -308,7 +336,7 @@ func (ft *Formatter) moduleName(path string) (string, error) { if goModPath == "" { return "", ErrGoModNotExist } - ft.log("goModPath: %+v\n", goModPath) + ft.log("moduleName: goModPath: %+v\n", goModPath) goModPathBytes, err := os.ReadFile(goModPath) if err != nil {