gofimports/internal/imports/formatter.go

488 lines
11 KiB
Go
Raw Permalink Normal View History

2022-11-24 17:14:47 +00:00
package imports
import (
2022-11-26 17:08:16 +00:00
"bytes"
"errors"
"fmt"
"go/parser"
"go/token"
2022-11-27 17:54:56 +00:00
"io/fs"
"os"
"path/filepath"
"strings"
2022-11-26 05:34:28 +00:00
"sync"
2023-01-17 04:45:17 +00:00
"github.com/dave/dst"
"github.com/dave/dst/decorator"
2022-11-26 18:07:15 +00:00
"github.com/pkg/diff"
2022-11-26 05:39:34 +00:00
"golang.org/x/mod/modfile"
2023-01-17 10:30:58 +00:00
"golang.org/x/sync/errgroup"
"golang.org/x/tools/go/packages"
)
// Names of the import groups. They are used as map keys when import specs are
// grouped, and the groups are emitted in the order std, third-party, company,
// local.
const (
// Use for group imports
// stdImport holds imports found in the set of standard library packages.
stdImport = "std"
// thirdPartyImport holds every import that matches no other group.
thirdPartyImport = "third-party"
// companyImport holds imports whose path starts with a configured company prefix.
companyImport = "company"
// localImport holds imports that belong to the module of the file being formatted.
localImport = "local"
)
// Sentinel errors used by this package. Compare against them with errors.Is.
var (
	// ErrEmptyPaths is returned by Format when it is called without any path.
	ErrEmptyPaths = errors.New("empty paths")
	// ErrNotGoFile reports that a path was rejected by the Go-file check.
	ErrNotGoFile = errors.New("not go file")
	// ErrGoGeneratedFile reports that a file was detected as generated code.
	ErrGoGeneratedFile = errors.New("go generated file")
	// ErrAlreadyFormatted reports that formatting produced bytes identical to the input.
	ErrAlreadyFormatted = errors.New("already formatted")
	// ErrEmptyImport reports that a file has no imports or no declarations.
	ErrEmptyImport = errors.New("empty import")
	// ErrGoModNotExist reports that no go.mod file was located for a path.
	ErrGoModNotExist = errors.New("go mod not exist")
	// ErrGoModEmptyModule reports that a go.mod file declares an empty module path.
	ErrGoModEmptyModule = errors.New("go mod empty module")
	// ErrNotBytesBuffer reports that the buffer pool returned a value that is not a *bytes.Buffer.
	ErrNotBytesBuffer = errors.New("not bytes.Buffer")
	// ErrNotDSTGenDecl reports that the first declaration of a file is not a *dst.GenDecl.
	ErrNotDSTGenDecl = errors.New("not dst.GenDecl")
)
// bufPool hands out reusable *bytes.Buffer values so that a fresh buffer does
// not have to be allocated every time one is needed.
// https://pkg.go.dev/sync#Pool
var bufPool = sync.Pool{
// New supplies an empty buffer when the pool has none available.
New: func() any {
return new(bytes.Buffer)
},
}
2023-01-17 16:58:44 +00:00
// stdPackages -> save std packages for later search.
// moduleNames -> map path to its go.mod module name.
// formattedPaths -> make sure we not format path more than 1 time.
2022-11-24 17:14:47 +00:00
type Formatter struct {
2022-11-26 05:34:28 +00:00
stdPackages map[string]struct{}
2022-11-26 05:39:34 +00:00
moduleNames map[string]string
2022-11-26 05:34:28 +00:00
formattedPaths map[string]struct{}
companyPrefixes map[string]struct{}
2023-01-17 10:30:58 +00:00
eg errgroup.Group
2022-11-26 05:39:34 +00:00
muModuleNames sync.RWMutex
2022-11-26 05:34:28 +00:00
muFormattedPaths sync.RWMutex
isList bool
isWrite bool
isDiff bool
isVerbose bool
isStock bool
2022-11-24 17:14:47 +00:00
}
func NewFormmater(opts ...FormatterOptionFn) (*Formatter, error) {
ft := &Formatter{}
2022-11-24 17:14:47 +00:00
for _, opt := range opts {
opt(ft)
}
stdPackages, err := packages.Load(nil, "std")
if err != nil {
return nil, fmt.Errorf("packages: failed to load std: %w", err)
}
ft.stdPackages = make(map[string]struct{})
for _, stdPackage := range stdPackages {
ft.stdPackages[stdPackage.PkgPath] = struct{}{}
}
2022-11-26 05:39:34 +00:00
ft.moduleNames = make(map[string]string)
2022-11-26 05:34:28 +00:00
ft.formattedPaths = make(map[string]struct{})
return ft, nil
}
// Accept a list of files or directories
func (ft *Formatter) Format(paths ...string) error {
if len(paths) == 0 {
return ErrEmptyPaths
2022-11-24 17:14:47 +00:00
}
// Logic switch case copy from goimports, gofumpt
for _, path := range paths {
2023-01-17 10:30:58 +00:00
path := strings.TrimSpace(path)
2022-11-26 05:34:28 +00:00
if path == "" {
continue
}
switch dir, err := os.Stat(path); {
case err != nil:
return fmt.Errorf("os: failed to stat: [%s] %w", path, err)
case dir.IsDir():
if err := ft.formatDir(path); err != nil {
return err
}
default:
2023-01-17 10:30:58 +00:00
ft.eg.Go(func() error {
if err := ft.formatFile(path); err != nil {
if ft.isIgnoreError(err) {
return nil
}
return err
2023-01-17 04:45:17 +00:00
}
2023-01-17 10:30:58 +00:00
return nil
})
}
}
2023-01-17 10:30:58 +00:00
if err := ft.eg.Wait(); err != nil {
return err
}
return nil
}
2022-11-27 17:54:56 +00:00
// Copy from gofumpt
func (ft *Formatter) formatDir(path string) error {
2022-11-27 17:54:56 +00:00
if err := filepath.WalkDir(path, func(path string, dirEntry fs.DirEntry, err error) error {
if filepath.Base(path) == "vendor" {
return filepath.SkipDir
}
if err != nil {
return err
}
if dirEntry.IsDir() {
// Get module name ASAP to cache it
moduleName, err := ft.moduleName(path)
if err != nil {
return err
}
ft.log("formatFile: moduleName: [%s]\n", moduleName)
2022-11-27 17:54:56 +00:00
return nil
}
2023-01-17 10:30:58 +00:00
ft.eg.Go(func() error {
if err := ft.formatFile(path); err != nil {
if ft.isIgnoreError(err) {
return nil
}
return err
}
2023-01-17 10:30:58 +00:00
return nil
})
return nil
2022-11-27 17:54:56 +00:00
}); err != nil {
return fmt.Errorf("filepath: failed to walk dir: [%s] %w", path, err)
}
return nil
2022-11-24 17:14:47 +00:00
}
func (ft *Formatter) formatFile(path string) error {
2022-11-26 05:34:28 +00:00
ft.muFormattedPaths.RLock()
if _, ok := ft.formattedPaths[path]; ok {
ft.muFormattedPaths.RUnlock()
return nil
}
ft.muFormattedPaths.RUnlock()
// Return if not go file
if !isGoFile(filepath.Base(path)) {
return ErrNotGoFile
}
// Read file first
pathBytes, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("os: failed to read file: [%s] %w", path, err)
}
// Get module name of path
moduleName, err := ft.moduleName(path)
if err != nil {
return err
}
2023-01-17 04:45:17 +00:00
ft.log("formatFile: moduleName: [%s]\n", moduleName)
2022-11-26 18:07:15 +00:00
formattedBytes, err := ft.formatImports(path, pathBytes, moduleName)
if err != nil {
2022-11-26 05:34:28 +00:00
return err
}
if bytes.Equal(pathBytes, formattedBytes) {
2023-01-17 04:45:17 +00:00
return ErrAlreadyFormatted
}
2022-11-26 18:07:15 +00:00
if ft.isList {
2022-11-27 17:39:59 +00:00
fmt.Println("Formatted: ", path)
2022-11-26 18:07:15 +00:00
}
if ft.isWrite {
2022-11-27 17:39:59 +00:00
if err := os.WriteFile(path, formattedBytes, 0o600); err != nil {
return fmt.Errorf("os: failed to write file: [%s] %w", path, err)
}
2022-11-26 18:07:15 +00:00
}
if ft.isDiff {
2022-11-26 18:08:34 +00:00
if err := diff.Text(path+" before", path+" after", pathBytes, formattedBytes, os.Stdout); err != nil {
2022-11-26 18:07:15 +00:00
return fmt.Errorf("diff: failed to slices: %w", err)
}
}
2022-11-26 05:34:28 +00:00
ft.muFormattedPaths.Lock()
ft.formattedPaths[path] = struct{}{}
ft.muFormattedPaths.Unlock()
return nil
}
2023-01-17 04:53:49 +00:00
// Copy from goimports, gofumpt, goimports-reviser.
// First parse ast.
// Then group imports.
// Then format imports.
2023-01-17 16:58:44 +00:00
// Then print to bytes.
2022-11-26 17:08:16 +00:00
func (ft *Formatter) formatImports(
path string,
pathBytes []byte,
moduleName string,
2022-11-26 18:07:15 +00:00
) ([]byte, error) {
2022-11-26 17:08:16 +00:00
// Parse ast
fset := token.NewFileSet()
2022-11-28 04:14:32 +00:00
// Copy from gofumpt
2022-11-28 03:35:50 +00:00
parserMode := parser.Mode(0)
parserMode |= parser.ParseComments
parserMode |= parser.SkipObjectResolution
2022-11-26 17:08:16 +00:00
astFile, err := parser.ParseFile(fset, path, pathBytes, parserMode)
if err != nil {
2022-11-26 18:07:15 +00:00
return nil, fmt.Errorf("parser: failed to parse file [%s]: %w", path, err)
}
// Ignore generated file
2022-11-26 17:08:16 +00:00
if isGoGenerated(astFile) {
2022-11-26 18:07:15 +00:00
return nil, ErrGoGeneratedFile
}
dec := decorator.NewDecorator(fset)
dstFile, err := dec.DecorateFile(astFile)
2023-01-17 04:45:17 +00:00
if err != nil {
return nil, fmt.Errorf("decorator: failed to parse file [%s]: %w", path, err)
}
2023-01-17 10:20:05 +00:00
if len(dstFile.Imports) == 0 || len(dstFile.Decls) == 0 {
2023-01-17 04:53:49 +00:00
return nil, ErrEmptyImport
}
2023-01-17 04:45:17 +00:00
ft.logDSTImportSpecs("formatImports: dstImportSpecs", dstFile.Imports)
2023-01-17 04:54:57 +00:00
groupedDSTImportSpecs, err := ft.groupDSTImportSpecs(dstFile.Imports, moduleName)
if err != nil {
2022-11-26 18:07:15 +00:00
return nil, err
}
2023-01-17 04:45:17 +00:00
formattedDSTImportSpecs, err := ft.formatDSTImportSpecs(groupedDSTImportSpecs)
if err != nil {
2022-11-26 18:07:15 +00:00
return nil, err
}
2023-01-17 04:45:17 +00:00
ft.logDSTImportSpecs("formatImports: formattedDSTImportSpecs: ", formattedDSTImportSpecs)
2023-01-17 10:20:05 +00:00
// First update
2023-01-17 04:45:17 +00:00
dstFile.Imports = formattedDSTImportSpecs
2022-11-28 04:14:32 +00:00
2023-02-25 16:53:39 +00:00
genDecl, ok := dstFile.Decls[0].(*dst.GenDecl)
if !ok {
return nil, ErrNotDSTGenDecl
}
formattedGenSpecs := make([]dst.Spec, 0, len(genDecl.Specs))
2023-01-17 10:20:05 +00:00
// Append all imports first
for _, importSpec := range formattedDSTImportSpecs {
formattedGenSpecs = append(formattedGenSpecs, importSpec)
}
// Append all non imports later
2023-02-25 16:53:39 +00:00
for _, genSpec := range genDecl.Specs {
2023-01-17 10:20:05 +00:00
if _, ok := genSpec.(*dst.ImportSpec); !ok {
formattedGenSpecs = append(formattedGenSpecs, genSpec)
continue
}
}
// Second update
2023-02-25 16:53:39 +00:00
genDecl.Specs = formattedGenSpecs
2023-01-17 10:20:05 +00:00
2023-02-25 16:53:39 +00:00
b, ok := bufPool.Get().(*bytes.Buffer)
if !ok {
return nil, ErrNotBytesBuffer
}
b.Reset()
defer bufPool.Put(b)
if err := decorator.Fprint(b, dstFile); err != nil {
2023-01-17 04:45:17 +00:00
return nil, fmt.Errorf("decorator: failed to fprint [%s]: %w", path, err)
2022-11-26 17:08:16 +00:00
}
result := make([]byte, b.Len())
copy(result, b.Bytes())
return result, nil
}
2023-01-17 04:54:57 +00:00
func (ft *Formatter) groupDSTImportSpecs(importSpecs []*dst.ImportSpec, moduleName string) (map[string][]*dst.ImportSpec, error) {
2023-01-17 04:45:17 +00:00
result := make(map[string][]*dst.ImportSpec)
result[stdImport] = make([]*dst.ImportSpec, 0, 8)
result[thirdPartyImport] = make([]*dst.ImportSpec, 0, 8)
if !ft.isStock {
// Only split company, local imports if not stock
// Otherwise everything is third party imports
if len(ft.companyPrefixes) != 0 {
result[companyImport] = make([]*dst.ImportSpec, 0, 8)
}
result[localImport] = make([]*dst.ImportSpec, 0, 8)
}
for _, importSpec := range importSpecs {
// "github.com/abc/xyz" -> github.com/abc/xyz
importPath := strings.Trim(importSpec.Path.Value, `"`)
if _, ok := ft.stdPackages[importPath]; ok {
result[stdImport] = append(result[stdImport], importSpec)
continue
}
if !ft.isStock {
if strings.HasPrefix(importPath, moduleName) {
result[localImport] = append(result[localImport], importSpec)
continue
}
if len(ft.companyPrefixes) != 0 {
existImport := false
for companyPrefix := range ft.companyPrefixes {
if strings.HasPrefix(importPath, companyPrefix) {
result[companyImport] = append(result[companyImport], importSpec)
existImport = true
break
}
}
if existImport {
continue
}
}
}
result[thirdPartyImport] = append(result[thirdPartyImport], importSpec)
}
2023-01-17 04:45:17 +00:00
ft.logDSTImportSpecs("groupDSTImportSpecs: stdImport", result[stdImport])
ft.logDSTImportSpecs("groupDSTImportSpecs: thirdPartyImport", result[thirdPartyImport])
if len(ft.companyPrefixes) != 0 {
2023-01-17 04:45:17 +00:00
ft.logDSTImportSpecs("groupDSTImportSpecs: companyImport", result[companyImport])
}
2023-01-17 04:45:17 +00:00
ft.logDSTImportSpecs("groupDSTImportSpecs: localImport", result[localImport])
return result, nil
}
2023-01-17 04:45:17 +00:00
func (ft *Formatter) formatDSTImportSpecs(groupedImportSpecs map[string][]*dst.ImportSpec,
) ([]*dst.ImportSpec, error) {
result := make([]*dst.ImportSpec, 0, 32)
appendToResultFn := func(groupImportType string) {
2023-01-17 04:45:17 +00:00
importSpecs, ok := groupedImportSpecs[groupImportType]
if !ok || len(importSpecs) == 0 {
return
}
2023-01-17 04:45:17 +00:00
for _, importSpec := range importSpecs {
importSpec.Decs.Before = dst.NewLine
importSpec.Decs.After = dst.NewLine
}
2022-11-28 03:55:16 +00:00
2023-01-17 04:45:17 +00:00
importSpecs[len(importSpecs)-1].Decs.After = dst.EmptyLine
2022-11-28 03:55:16 +00:00
2023-01-17 04:45:17 +00:00
result = append(result, importSpecs...)
}
appendToResultFn(stdImport)
appendToResultFn(thirdPartyImport)
2023-01-17 04:45:17 +00:00
appendToResultFn(companyImport)
appendToResultFn(localImport)
2022-11-25 19:18:02 +00:00
2023-01-17 04:47:51 +00:00
if len(result) == 0 {
return result, nil
}
2023-01-17 04:45:17 +00:00
result[len(result)-1].Decs.After = dst.NewLine
return result, nil
}
2022-11-26 17:08:16 +00:00
// Copy from goimports-reviser
2022-11-28 05:22:10 +00:00
// Get module name from go.mod of path
// If current path doesn't have go.mod, recursive find its parent path
2022-11-26 05:39:34 +00:00
func (ft *Formatter) moduleName(path string) (string, error) {
ft.muModuleNames.RLock()
if pkgName, ok := ft.moduleNames[path]; ok {
ft.muModuleNames.RUnlock()
2022-11-26 05:34:28 +00:00
return pkgName, nil
}
2022-11-26 05:39:34 +00:00
ft.muModuleNames.RUnlock()
2022-11-26 05:34:28 +00:00
// Copy from goimports-reviser
// Check path/go.mod first
// If not exist -> check ../go.mod
// Assume path is dir path, maybe wrong but it is ok for now
dirPath := filepath.Clean(path)
var goModPath string
for {
ft.muModuleNames.RLock()
if pkgName, ok := ft.moduleNames[dirPath]; ok {
ft.muModuleNames.RUnlock()
return pkgName, nil
}
ft.muModuleNames.RUnlock()
2022-11-26 05:34:28 +00:00
goModPath = filepath.Join(dirPath, "go.mod")
fileInfo, err := os.Stat(goModPath)
if err == nil && !fileInfo.IsDir() {
break
}
// Check ..
if dirPath == filepath.Dir(dirPath) {
// Reach root
break
}
dirPath = filepath.Dir(dirPath)
}
if goModPath == "" {
return "", ErrGoModNotExist
}
ft.log("moduleName: goModPath: %+v\n", goModPath)
2022-11-26 05:34:28 +00:00
2022-11-26 05:39:34 +00:00
goModPathBytes, err := os.ReadFile(goModPath)
if err != nil {
return "", fmt.Errorf("os: failed to read file: [%s] %w", goModPath, err)
}
goModFile, err := modfile.Parse(goModPath, goModPathBytes, nil)
if err != nil {
return "", fmt.Errorf("modfile: failed to parse: [%s] %w", goModPath, err)
}
result := goModFile.Module.Mod.Path
if result == "" {
return "", ErrGoModEmptyModule
}
2022-11-26 05:40:34 +00:00
ft.muModuleNames.Lock()
ft.moduleNames[path] = result
ft.muModuleNames.Unlock()
2022-11-26 05:39:34 +00:00
return result, nil
2022-11-26 05:34:28 +00:00
}
2023-01-17 04:45:17 +00:00
func (ft *Formatter) isIgnoreError(err error) bool {
return errors.Is(err, ErrNotGoFile) ||
errors.Is(err, ErrGoGeneratedFile) ||
2023-01-17 04:53:49 +00:00
errors.Is(err, ErrAlreadyFormatted) ||
errors.Is(err, ErrEmptyImport)
2023-01-17 04:45:17 +00:00
}