diff --git a/internal/imports/formatter.go b/internal/imports/formatter.go index 27f661d..4ed4c02 100644 --- a/internal/imports/formatter.go +++ b/internal/imports/formatter.go @@ -11,6 +11,7 @@ import ( "path/filepath" "sort" "strings" + "sync" "golang.org/x/tools/go/packages" ) @@ -27,15 +28,23 @@ var ( ErrEmptyPaths = errors.New("empty paths") ErrNotGoFile = errors.New("not go file") ErrGoGeneratedFile = errors.New("go generated file") + ErrGoModNotExist = errors.New("go mod not exist") ) +// stdPackages -> save std packages for later search +// packageNames -> map path to its go.mod package name +// formattedPaths -> make sure we not format path more than 1 time type Formatter struct { - stdPackages map[string]struct{} - companyPrefix string - isList bool - isWrite bool - isDiff bool - isVerbose bool + stdPackages map[string]struct{} + packageNames map[string]string + formattedPaths map[string]struct{} + companyPrefix string + muPackageNames sync.RWMutex + muFormattedPaths sync.RWMutex + isList bool + isWrite bool + isDiff bool + isVerbose bool } func NewFormmater(opts ...FormatterOptionFn) (*Formatter, error) { @@ -56,6 +65,9 @@ func NewFormmater(opts ...FormatterOptionFn) (*Formatter, error) { } ft.log("stdPackages: %+v\n", ft.stdPackages) + ft.packageNames = make(map[string]string) + ft.formattedPaths = make(map[string]struct{}) + return ft, nil } @@ -67,6 +79,11 @@ func (ft *Formatter) Format(paths ...string) error { // Logic switch case copy from goimports, gofumpt for _, path := range paths { + path = strings.TrimSpace(path) + if path == "" { + continue + } + switch dir, err := os.Stat(path); { case err != nil: return fmt.Errorf("os: failed to stat: [%s] %w", path, err) @@ -89,6 +106,13 @@ func (ft *Formatter) formatDir(path string) error { } func (ft *Formatter) formatFile(path string) error { + ft.muFormattedPaths.RLock() + if _, ok := ft.formattedPaths[path]; ok { + ft.muFormattedPaths.RUnlock() + return nil + } + ft.muFormattedPaths.RUnlock() + // Return if not go file if !isGoFile(filepath.Base(path)) { return ErrNotGoFile @@ -118,12 +142,23 @@ func (ft *Formatter) formatFile(path string) error { } ft.log("importsAST: %+v\n", importsAST) + // TODO: Find dir go.mod package name + pkgName, err := ft.packageName(path) + if err != nil { + return err + } + ft.log("pkgName: %+v\n", pkgName) + groupImports, err := ft.groupImports(importsAST) if err != nil { return err } ft.log("groupImports: %+v\n", groupImports) + ft.muFormattedPaths.Lock() + ft.formattedPaths[path] = struct{}{} + ft.muFormattedPaths.Unlock() + return nil } @@ -216,6 +251,44 @@ func (ft *Formatter) groupImports(importsAST map[string]*ast.ImportSpec) (map[st return result, nil } +func (ft *Formatter) packageName(path string) (string, error) { + ft.muPackageNames.RLock() + if pkgName, ok := ft.packageNames[path]; ok { + ft.muPackageNames.RUnlock() + return pkgName, nil + } + ft.muPackageNames.RUnlock() + + // Copy from goimports-reviser + // Check path/go.mod first + // If not exist -> check ../go.mod + // Assume path is dir path, maybe wrong but it is ok for now + dirPath := filepath.Clean(path) + var goModPath string + for { + goModPath = filepath.Join(dirPath, "go.mod") + fileInfo, err := os.Stat(goModPath) + if err == nil && !fileInfo.IsDir() { + break + } + + // Check .. + if dirPath == filepath.Dir(dirPath) { + // Reach root + break + } + + dirPath = filepath.Dir(dirPath) + } + + if goModPath == "" { + return "", ErrGoModNotExist + } + ft.log("goModPath: %+v\n", goModPath) + + return "", nil +} + // Wrap log.Printf with verbose flag func (ft *Formatter) log(format string, v ...any) { if ft.isVerbose {