diff --git a/go.mod b/go.mod index a74abf2..12116a7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/haunt98/gofimports go 1.19 require ( + github.com/dave/dst v0.27.2 github.com/make-go-great/color-go v0.4.1 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e github.com/urfave/cli/v2 v2.23.7 diff --git a/go.sum b/go.sum index 314cb02..45dd33d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/dave/dst v0.27.2 h1:4Y5VFTkhGLC1oddtNwuxxe36pnyLxMFXT51FOzH8Ekc= +github.com/dave/dst v0.27.2/go.mod h1:jHh6EOibnHgcUW3WjKHisiooEkYwqpHLBSX1iOBhEyc= +github.com/dave/jennifer v1.5.0 h1:HmgPN93bVDpkQyYbqhCHj5QlgvUkvEOzMyEvKLgCRrg= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/make-go-great/color-go v0.4.1 h1:4HIat5r1oQneu5BSDUjHKN/9x4Ay2qEEkswdOgk/VRY= @@ -13,6 +16,7 @@ github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgc github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY= github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= diff --git a/internal/imports/formatter.go b/internal/imports/formatter.go index 4976a4b..bb63b0a 100644 --- a/internal/imports/formatter.go +++ b/internal/imports/formatter.go @@ -4,9 +4,7 @@ import ( "bytes" "errors" "fmt" - "go/ast" "go/parser" - "go/printer" "go/token" "io/fs" "os" @@ -14,6 +12,8 @@ import ( "strings" "sync" + "github.com/dave/dst" + "github.com/dave/dst/decorator" "github.com/pkg/diff" "golang.org/x/mod/modfile" "golang.org/x/tools/go/packages" @@ -31,6 +31,7 @@ var ( ErrEmptyPaths = errors.New("empty paths") ErrNotGoFile = errors.New("not go file") ErrGoGeneratedFile = errors.New("go generated file") + ErrAlreadyFormatted = errors.New("already formatted") ErrGoModNotExist = errors.New("go mod not exist") ErrGoModEmptyModule = errors.New("go mod empty module") ) @@ -99,6 +100,10 @@ func (ft *Formatter) Format(paths ...string) error { } default: if err := ft.formatFile(path); err != nil { + if ft.isIgnoreError(err) { + continue + } + return err } } @@ -123,8 +128,7 @@ func (ft *Formatter) formatDir(path string) error { } if err := ft.formatFile(path); err != nil { - if errors.Is(err, ErrNotGoFile) || - errors.Is(err, ErrGoGeneratedFile) { + if ft.isIgnoreError(err) { return nil } @@ -163,13 +167,17 @@ func (ft *Formatter) formatFile(path string) error { if err != nil { return err } - ft.log("formatFile: moduleName: %+v\n", moduleName) + ft.log("formatFile: moduleName: [%s]\n", moduleName) formattedBytes, err := ft.formatImports(path, pathBytes, moduleName) if err != nil { return err } + if bytes.Compare(pathBytes, formattedBytes) == 0 { + return ErrAlreadyFormatted + } + if ft.isList { fmt.Println("Formatted: ", path) } @@ -226,103 +234,49 @@ func (ft *Formatter) formatImports( return nil, ErrGoGeneratedFile } - // Extract imports - importSpecs := make([]ast.Spec, 0, len(astFile.Imports)) - for _, importSpec := range astFile.Imports { - importSpecs = append(importSpecs, importSpec) + dstFile, err := decorator.Parse(pathBytes) + if err != nil { + return nil, fmt.Errorf("decorator: failed to parse file [%s]: %w", path, err) } - ft.mustLogImportSpecs("formatImports: importSpecs", importSpecs) + ft.logDSTImportSpecs("formatImports: dstImportSpecs", dstFile.Imports) - groupedImportSpecs, err := ft.groupImportSpecs( - importSpecs, + groupedDSTImportSpecs, err := ft.groupDSTImportSpecs( + dstFile.Imports, moduleName, ) if err != nil { return nil, err } - formattedImportSpecs, err := ft.formatImportSpecs( - importSpecs, - groupedImportSpecs, - ) + formattedDSTImportSpecs, err := ft.formatDSTImportSpecs(groupedDSTImportSpecs) if err != nil { return nil, err } - ft.mustLogImportSpecs("formatImports: formattedImportSpecs: ", formattedImportSpecs) + ft.logDSTImportSpecs("formatImports: formattedDSTImportSpecs: ", formattedDSTImportSpecs) - // Combine multi import decl into one - isExistFirstImportDecl := false - decls := make([]ast.Decl, 0, len(astFile.Decls)) + dstFile.Imports = formattedDSTImportSpecs - for _, decl := range astFile.Decls { - genDecl, ok := decl.(*ast.GenDecl) - if !ok { - decls = append(decls, decl) - continue - } - - if genDecl.Tok != token.IMPORT { - decls = append(decls, decl) - continue - } - - // Ignore second import decl and more - if isExistFirstImportDecl { - continue - } - - // Ignore empty import - if len(genDecl.Specs) == 0 { - continue - } - - // First import decl take all - isExistFirstImportDecl = true - genDecl.Specs = formattedImportSpecs - decls = append(decls, genDecl) + var buf bytes.Buffer + if err := decorator.Fprint(&buf, dstFile); err != nil { + return nil, fmt.Errorf("decorator: failed to fprint [%s]: %w", path, err) } - // Update ast decls - astFile.Decls = decls - - // Print formatted bytes from formatted ast - var formattedBytes []byte - formattedBuffer := bytes.NewBuffer(formattedBytes) - - // Copy from goimports - printerMode := printer.UseSpaces - printerMode |= printer.TabIndent - printerCfg := &printer.Config{ - Mode: printerMode, - Tabwidth: 8, - } - if err := printerCfg.Fprint(formattedBuffer, fset, astFile); err != nil { - return nil, fmt.Errorf("printer: failed to fprint: %w", err) - } - - return formattedBuffer.Bytes(), nil + return buf.Bytes(), nil } -// Copy from goimports-reviser -// Group imports to std, third-party, company if exist, local -func (ft *Formatter) groupImportSpecs( - importSpecs []ast.Spec, +func (ft *Formatter) groupDSTImportSpecs( + importSpecs []*dst.ImportSpec, moduleName string, -) (map[string][]*ast.ImportSpec, error) { - result := make(map[string][]*ast.ImportSpec) - result[stdImport] = make([]*ast.ImportSpec, 0, 8) - result[thirdPartyImport] = make([]*ast.ImportSpec, 0, 8) +) (map[string][]*dst.ImportSpec, error) { + result := make(map[string][]*dst.ImportSpec) + result[stdImport] = make([]*dst.ImportSpec, 0, 8) + result[thirdPartyImport] = make([]*dst.ImportSpec, 0, 8) if ft.companyPrefix != "" { - result[companyImport] = make([]*ast.ImportSpec, 0, 8) + result[companyImport] = make([]*dst.ImportSpec, 0, 8) } - result[localImport] = make([]*ast.ImportSpec, 0, 8) + result[localImport] = make([]*dst.ImportSpec, 0, 8) for _, importSpec := range importSpecs { - importSpec, ok := importSpec.(*ast.ImportSpec) - if !ok { - continue - } - // "github.com/abc/xyz" -> github.com/abc/xyz importPath := strings.Trim(importSpec.Path.Value, `"`) @@ -345,61 +299,43 @@ func (ft *Formatter) groupImportSpecs( result[thirdPartyImport] = append(result[thirdPartyImport], importSpec) } - ft.logImportSpecs("stdImport", result[stdImport]) - ft.logImportSpecs("thirdPartyImport", result[thirdPartyImport]) + ft.logDSTImportSpecs("groupDSTImportSpecs: stdImport", result[stdImport]) + ft.logDSTImportSpecs("groupDSTImportSpecs: thirdPartyImport", result[thirdPartyImport]) if ft.companyPrefix != "" { - ft.logImportSpecs("companyImport", result[companyImport]) + ft.logDSTImportSpecs("groupDSTImportSpecs: companyImport", result[companyImport]) } - ft.logImportSpecs("localImport", result[localImport]) + ft.logDSTImportSpecs("groupDSTImportSpecs: localImport", result[localImport]) return result, nil } -// Copy from goimports-reviser -// Insert empty import (empty path) between groups -func (ft *Formatter) formatImportSpecs( - importSpecs []ast.Spec, - groupedImportSpecs map[string][]*ast.ImportSpec, -) ([]ast.Spec, error) { - result := make([]ast.Spec, 0, len(importSpecs)) +func (ft *Formatter) formatDSTImportSpecs(groupedImportSpecs map[string][]*dst.ImportSpec, +) ([]*dst.ImportSpec, error) { + result := make([]*dst.ImportSpec, 0, 32) appendToResultFn := func(groupImportType string) { - if specs, ok := groupedImportSpecs[groupImportType]; ok && len(specs) != 0 { - if len(result) != 0 { - result = append(result, &ast.ImportSpec{ - Path: &ast.BasicLit{ - Value: "", - Kind: token.STRING, - }, - }) - } - - for _, spec := range specs { - newSpec := &ast.ImportSpec{ - Path: &ast.BasicLit{ - Value: spec.Path.Value, - Kind: token.STRING, - }, - } - - if spec.Name != nil { - newSpec.Name = &ast.Ident{ - Name: spec.Name.Name, - } - } - - result = append(result, newSpec) - } + importSpecs, ok := groupedImportSpecs[groupImportType] + if !ok || len(importSpecs) == 0 { + return } + + for _, importSpec := range importSpecs { + importSpec.Decs.Before = dst.NewLine + importSpec.Decs.After = dst.NewLine + } + + importSpecs[len(importSpecs)-1].Decs.After = dst.EmptyLine + + result = append(result, importSpecs...) } appendToResultFn(stdImport) appendToResultFn(thirdPartyImport) - if ft.companyPrefix != "" { - appendToResultFn(companyImport) - } + appendToResultFn(companyImport) appendToResultFn(localImport) + result[len(result)-1].Decs.After = dst.NewLine + return result, nil } @@ -462,3 +398,9 @@ func (ft *Formatter) moduleName(path string) (string, error) { return result, nil } + +func (ft *Formatter) isIgnoreError(err error) bool { + return errors.Is(err, ErrNotGoFile) || + errors.Is(err, ErrGoGeneratedFile) || + errors.Is(err, ErrAlreadyFormatted) +} diff --git a/internal/imports/formatter_log.go b/internal/imports/formatter_log.go index 01aaa61..aa64bc8 100644 --- a/internal/imports/formatter_log.go +++ b/internal/imports/formatter_log.go @@ -1,8 +1,9 @@ package imports import ( - "go/ast" "log" + + "github.com/dave/dst" ) // Wrap log.Printf with verbose flag @@ -12,23 +13,10 @@ func (ft *Formatter) log(format string, v ...any) { } } -func (ft *Formatter) logImportSpecs(logPrefix string, importSpecs []*ast.ImportSpec) { +func (ft *Formatter) logDSTImportSpecs(logPrefix string, importSpecs []*dst.ImportSpec) { if ft.isVerbose { for _, importSpec := range importSpecs { - log.Printf("%s: importSpec: %+v %+v\n", logPrefix, importSpec.Name.String(), importSpec.Path.Value) - } - } -} - -func (ft *Formatter) mustLogImportSpecs(logPrefix string, importSpecs []ast.Spec) { - if ft.isVerbose { - for _, importSpec := range importSpecs { - importSpec, ok := importSpec.(*ast.ImportSpec) - if !ok { - continue - } - - log.Printf("%s: importSpec: %+v %+v\n", logPrefix, importSpec.Name.String(), importSpec.Path.Value) + log.Printf("%s: [%s] [%s] before %v after %v\n", logPrefix, importSpec.Name, importSpec.Path.Value, importSpec.Decs.Before, importSpec.Decs.After) } } }