diff --git a/Makefile b/Makefile index e07d6d4..9500ee0 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all test test-color coverage coverage-cli coverate-html lint format +.PHONY: all test test-color coverage coverage-cli coverate-html lint format build all: test-color lint format go mod tidy @@ -25,3 +25,6 @@ lint: format: go install mvdan.cc/gofumpt@latest gofumpt -l -w -extra . + +build: + go build ./cmd/gofimports diff --git a/go.mod b/go.mod index a42a881..746cba5 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.19 require ( github.com/make-go-great/color-go v0.4.1 github.com/urfave/cli/v2 v2.23.5 + golang.org/x/tools v0.3.0 ) require ( @@ -14,5 +15,6 @@ require ( github.com/mattn/go-isatty v0.0.14 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect - golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect + golang.org/x/mod v0.7.0 // indirect + golang.org/x/sys v0.2.0 // indirect ) diff --git a/go.sum b/go.sum index 982c58a..7941c86 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,13 @@ github.com/urfave/cli/v2 v2.23.5 h1:xbrU7tAYviSpqeR3X4nEFWUdB/uDZ6DE+HxmRU7Xtyw= github.com/urfave/cli/v2 v2.23.5/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= +golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/tools v0.3.0 h1:SrNbZl6ECOS1qFzgTdQfWXZM9XBkiA6tkFrH9YSTPHM= +golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= diff --git a/internal/cli/action.go b/internal/cli/action.go index becb516..9025eec 100644 --- a/internal/cli/action.go +++ b/internal/cli/action.go @@ -9,9 +9,11 @@ import ( type action struct { flags struct { - list bool - write bool - diff bool + companyPrefix string + list bool + write bool + diff bool + verbose bool } } @@ -23,6 +25,8 @@ func (a *action) getFlags(c *cli.Context) { a.flags.list = c.Bool(flagListName) a.flags.write = c.Bool(flagWriteName) a.flags.diff = c.Bool(flagDiffName) + a.flags.verbose = c.Bool(flagVerboseName) + a.flags.companyPrefix = c.String(flagCompanyPrefixName) } func (a *action) Run(c *cli.Context) error { @@ -40,14 +44,19 @@ func (a *action) Run(c *cli.Context) error { return a.RunHelp(c) } - f := imports.NewFormmater( + ft, err := imports.NewFormmater( imports.FormatterWithList(a.flags.list), imports.FormatterWithWrite(a.flags.write), imports.FormatterWithDiff(a.flags.diff), + imports.FormatterWithVerbose(a.flags.verbose), + imports.FormatterWithCompanyPrefix(a.flags.companyPrefix), ) + if err != nil { + return fmt.Errorf("imports: failed to new formatter: %w", err) + } args := c.Args().Slice() - if err := f.Format(args...); err != nil { + if err := ft.Format(args...); err != nil { return fmt.Errorf("imports formatter: failed to format %v: %w", args, err) } diff --git a/internal/cli/app.go b/internal/cli/app.go index 9b4c475..703211d 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -20,6 +20,12 @@ const ( flagDiffName = "diff" flagDiffUsage = "show diff" + + flagVerboseName = "verbose" + flagVerboseUsage = "show verbose output, for debug only" + + flagCompanyPrefixName = "company" + flagCompanyPrefixUsage = "company prefix, for example github.com/haunt98" ) var ( @@ -55,6 +61,14 @@ func NewApp() *App { Usage: flagDiffUsage, Aliases: flagDiffAliases, }, + &cli.BoolFlag{ + Name: flagVerboseName, + Usage: flagVerboseUsage, + }, + &cli.StringFlag{ + Name: flagCompanyPrefixName, + Usage: flagCompanyPrefixUsage, + }, }, Action: a.Run, } diff --git a/internal/imports/formatter.go b/internal/imports/formatter.go index fd62cfd..a26a2d7 100644 --- a/internal/imports/formatter.go +++ b/internal/imports/formatter.go @@ -3,28 +3,59 @@ package imports import ( "errors" "fmt" + "go/ast" "go/parser" "go/token" + "log" "os" "path/filepath" + "strings" + + "golang.org/x/tools/go/packages" ) -var ErrEmptyPaths = errors.New("empty paths") +const ( + // Use for group imports + stdImport = "std" + thirdPartyImport = "third-party" + companyImport = "company" + localImport = "local" +) + +var ( + ErrEmptyPaths = errors.New("empty paths") + ErrNotGoFile = errors.New("not go file") + ErrGoGeneratedFile = errors.New("go generated file") +) type Formatter struct { - isList bool - isWrite bool - isDiff bool + stdPackages map[string]struct{} + companyPrefix string + isList bool + isWrite bool + isDiff bool + isVerbose bool } -func NewFormmater(opts ...FormatterOptionFn) *Formatter { +func NewFormmater(opts ...FormatterOptionFn) (*Formatter, error) { ft := &Formatter{} for _, opt := range opts { opt(ft) } - return ft + stdPackages, err := packages.Load(nil, "std") + if err != nil { + return nil, fmt.Errorf("packages: failed to load std: %w", err) + } + + ft.stdPackages = make(map[string]struct{}) + for _, stdPackage := range stdPackages { + ft.stdPackages[stdPackage.PkgPath] = struct{}{} + } + ft.log("stdPackages: %+v\n", ft.stdPackages) + + return ft, nil } // Accept a list of files or directories @@ -57,17 +88,45 @@ func (ft *Formatter) formatDir(path string) error { } func (ft *Formatter) formatFile(path string) error { - // Check go file + // Return if not go file if !isGoFile(filepath.Base(path)) { - return nil + return ErrNotGoFile } + // Read file first pathBytes, err := os.ReadFile(path) if err != nil { return fmt.Errorf("os: failed to read file: [%s] %w", path, err) } // Parse ast + pathASTFile, err := ft.wrapParseAST(path, pathBytes) + if err != nil { + return err + } + + // Parse imports + importsAST, err := ft.parseImports(pathASTFile) + if err != nil { + return err + } + + // Return if empty imports + if len(importsAST) == 0 { + return nil + } + ft.log("importsAST: %+v\n", importsAST) + + groupImports, err := ft.groupImports(importsAST) + if err != nil { + return err + } + ft.log("groupImports: %+v\n", groupImports) + + return nil +} + +func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, error) { fset := token.NewFileSet() parserMode := parser.Mode(0) @@ -75,15 +134,82 @@ func (ft *Formatter) formatFile(path string) error { pathASTFile, err := parser.ParseFile(fset, path, pathBytes, parserMode) if err != nil { - return fmt.Errorf("parser: failed to parse file [%s]: %w", path, err) + return nil, fmt.Errorf("parser: failed to parse file [%s]: %w", path, err) } // Ignore generated file if isGoGenerated(pathASTFile) { - return nil + return nil, ErrGoGeneratedFile } - // TODO: fix imports - - return nil + return pathASTFile, nil +} + +// Copy from goimports-reviser +func (ft *Formatter) parseImports(pathASTFile *ast.File) (map[string]*ast.ImportSpec, error) { + result := make(map[string]*ast.ImportSpec) + + for _, decl := range pathASTFile.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + + if genDecl.Tok != token.IMPORT { + continue + } + + for _, spec := range genDecl.Specs { + importSpec, ok := spec.(*ast.ImportSpec) + if !ok { + continue + } + + var importNameAndPath string + if importSpec.Name != nil { + // Handle alias import + // xyz "github.com/abc/xyz/123" + importNameAndPath = importSpec.Name.String() + " " + importSpec.Path.Value + } else { + // Handle normal import + // "github.com/abc/xyz" + importNameAndPath = importSpec.Path.Value + } + + result[importNameAndPath] = importSpec + } + } + + return result, nil +} + +// Copy from goimports-reviser +// Group imports to std, third-party, company if exist, local +func (ft *Formatter) groupImports(importsAST map[string]*ast.ImportSpec) (map[string][]string, error) { + result := make(map[string][]string) + result[stdImport] = make([]string, 0, 8) + result[thirdPartyImport] = make([]string, 0, 8) + if ft.companyPrefix != "" { + result[companyImport] = make([]string, 0, 8) + } + result[localImport] = make([]string, 0, 8) + + for importNameAndPath, importAST := range importsAST { + // "github.com/abc/xyz" -> github.com/abc/xyz + importPath := strings.Trim(importAST.Path.Value, "\"") + + if _, ok := ft.stdPackages[importPath]; ok { + result[stdImport] = append(result[stdImport], importNameAndPath) + continue + } + } + + return result, nil +} + +// Wrap log.Printf with verbose flag +func (ft *Formatter) log(format string, v ...any) { + if ft.isVerbose { + log.Printf(format, v...) + } } diff --git a/internal/imports/formatter_option.go b/internal/imports/formatter_option.go index baec46c..c84e5c7 100644 --- a/internal/imports/formatter_option.go +++ b/internal/imports/formatter_option.go @@ -3,19 +3,31 @@ package imports type FormatterOptionFn func(*Formatter) func FormatterWithList(isList bool) FormatterOptionFn { - return func(f *Formatter) { - f.isList = isList + return func(ft *Formatter) { + ft.isList = isList } } func FormatterWithWrite(isWrite bool) FormatterOptionFn { - return func(f *Formatter) { - f.isWrite = isWrite + return func(ft *Formatter) { + ft.isWrite = isWrite } } func FormatterWithDiff(isDiff bool) FormatterOptionFn { - return func(f *Formatter) { - f.isDiff = isDiff + return func(ft *Formatter) { + ft.isDiff = isDiff + } +} + +func FormatterWithVerbose(isVerbose bool) FormatterOptionFn { + return func(ft *Formatter) { + ft.isVerbose = isVerbose + } +} + +func FormatterWithCompanyPrefix(companyPrefix string) FormatterOptionFn { + return func(ft *Formatter) { + ft.companyPrefix = companyPrefix } }