2022-11-24 17:14:47 +00:00
|
|
|
package imports
|
|
|
|
|
2022-11-24 18:17:35 +00:00
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
2022-11-25 19:10:19 +00:00
|
|
|
"go/ast"
|
2022-11-24 18:17:35 +00:00
|
|
|
"go/parser"
|
|
|
|
"go/token"
|
2022-11-25 19:10:19 +00:00
|
|
|
"log"
|
2022-11-24 18:17:35 +00:00
|
|
|
"os"
|
|
|
|
"path/filepath"
|
2022-11-25 19:10:19 +00:00
|
|
|
"strings"
|
|
|
|
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Keys of the import groups produced by groupImports.
const (
	// stdImport is the group key for standard-library packages.
	stdImport = "std"
	// thirdPartyImport is the group key for imports not classified elsewhere.
	thirdPartyImport = "third-party"
	// companyImport is the group key used only when a company prefix is configured.
	companyImport = "company"
	// localImport is the group key for local imports (presumably packages of
	// the module being formatted — confirm).
	localImport = "local"
)
|
|
|
|
|
2022-11-25 19:10:19 +00:00
|
|
|
// Sentinel errors reported by Formatter; match them with errors.Is.
var (
	// ErrEmptyPaths is returned by Format when it is called with no paths.
	ErrEmptyPaths = errors.New("empty paths")

	// ErrNotGoFile is returned when a file's base name is not a Go source file name.
	ErrNotGoFile = errors.New("not go file")

	// ErrGoGeneratedFile is returned when a parsed file is detected as generated code.
	ErrGoGeneratedFile = errors.New("go generated file")
)
|
2022-11-24 18:17:35 +00:00
|
|
|
|
2022-11-24 17:14:47 +00:00
|
|
|
// Formatter classifies the imports of Go source files into groups.
// Build one with NewFormmater: the zero value has a nil stdPackages set, so
// no import would be recognised as standard library.
type Formatter struct {
	// stdPackages is a set of standard-library import paths.
	stdPackages map[string]struct{}

	// companyPrefix, when non-empty, enables a separate "company" import group.
	companyPrefix string

	// Behaviour flags. Only isVerbose (which gates log output) is read in the
	// code shown here; isList, isWrite and isDiff presumably select output
	// modes and are set through FormatterOptionFn options — confirm.
	isList, isWrite, isDiff, isVerbose bool
}
|
|
|
|
|
2022-11-25 19:10:19 +00:00
|
|
|
func NewFormmater(opts ...FormatterOptionFn) (*Formatter, error) {
|
2022-11-24 18:17:35 +00:00
|
|
|
ft := &Formatter{}
|
2022-11-24 17:14:47 +00:00
|
|
|
|
|
|
|
for _, opt := range opts {
|
2022-11-24 18:17:35 +00:00
|
|
|
opt(ft)
|
|
|
|
}
|
|
|
|
|
2022-11-25 19:10:19 +00:00
|
|
|
stdPackages, err := packages.Load(nil, "std")
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("packages: failed to load std: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
ft.stdPackages = make(map[string]struct{})
|
|
|
|
for _, stdPackage := range stdPackages {
|
|
|
|
ft.stdPackages[stdPackage.PkgPath] = struct{}{}
|
|
|
|
}
|
|
|
|
ft.log("stdPackages: %+v\n", ft.stdPackages)
|
|
|
|
|
|
|
|
return ft, nil
|
2022-11-24 18:17:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Accept a list of files or directories
|
|
|
|
func (ft *Formatter) Format(paths ...string) error {
|
|
|
|
if len(paths) == 0 {
|
|
|
|
return ErrEmptyPaths
|
2022-11-24 17:14:47 +00:00
|
|
|
}
|
|
|
|
|
2022-11-24 18:17:35 +00:00
|
|
|
// Logic switch case copy from goimports, gofumpt
|
|
|
|
for _, path := range paths {
|
|
|
|
switch dir, err := os.Stat(path); {
|
|
|
|
case err != nil:
|
|
|
|
return fmt.Errorf("os: failed to stat: [%s] %w", path, err)
|
|
|
|
case dir.IsDir():
|
|
|
|
if err := ft.formatDir(path); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
default:
|
|
|
|
if err := ft.formatFile(path); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ft *Formatter) formatDir(path string) error {
|
|
|
|
return nil
|
2022-11-24 17:14:47 +00:00
|
|
|
}
|
|
|
|
|
2022-11-24 18:17:35 +00:00
|
|
|
func (ft *Formatter) formatFile(path string) error {
|
2022-11-25 19:10:19 +00:00
|
|
|
// Return if not go file
|
2022-11-24 18:17:35 +00:00
|
|
|
if !isGoFile(filepath.Base(path)) {
|
2022-11-25 19:10:19 +00:00
|
|
|
return ErrNotGoFile
|
2022-11-24 18:17:35 +00:00
|
|
|
}
|
|
|
|
|
2022-11-25 19:10:19 +00:00
|
|
|
// Read file first
|
2022-11-24 18:17:35 +00:00
|
|
|
pathBytes, err := os.ReadFile(path)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("os: failed to read file: [%s] %w", path, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Parse ast
|
2022-11-25 19:10:19 +00:00
|
|
|
pathASTFile, err := ft.wrapParseAST(path, pathBytes)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Parse imports
|
|
|
|
importsAST, err := ft.parseImports(pathASTFile)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Return if empty imports
|
|
|
|
if len(importsAST) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
ft.log("importsAST: %+v\n", importsAST)
|
|
|
|
|
|
|
|
groupImports, err := ft.groupImports(importsAST)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
ft.log("groupImports: %+v\n", groupImports)
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ft *Formatter) wrapParseAST(path string, pathBytes []byte) (*ast.File, error) {
|
2022-11-24 18:17:35 +00:00
|
|
|
fset := token.NewFileSet()
|
|
|
|
|
|
|
|
parserMode := parser.Mode(0)
|
|
|
|
parserMode |= parser.ParseComments
|
|
|
|
|
|
|
|
pathASTFile, err := parser.ParseFile(fset, path, pathBytes, parserMode)
|
|
|
|
if err != nil {
|
2022-11-25 19:10:19 +00:00
|
|
|
return nil, fmt.Errorf("parser: failed to parse file [%s]: %w", path, err)
|
2022-11-24 18:17:35 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// Ignore generated file
|
|
|
|
if isGoGenerated(pathASTFile) {
|
2022-11-25 19:10:19 +00:00
|
|
|
return nil, ErrGoGeneratedFile
|
2022-11-24 18:17:35 +00:00
|
|
|
}
|
|
|
|
|
2022-11-25 19:10:19 +00:00
|
|
|
return pathASTFile, nil
|
|
|
|
}
|
2022-11-24 18:27:01 +00:00
|
|
|
|
2022-11-25 19:10:19 +00:00
|
|
|
// Copy from goimports-reviser
|
|
|
|
func (ft *Formatter) parseImports(pathASTFile *ast.File) (map[string]*ast.ImportSpec, error) {
|
|
|
|
result := make(map[string]*ast.ImportSpec)
|
|
|
|
|
|
|
|
for _, decl := range pathASTFile.Decls {
|
|
|
|
genDecl, ok := decl.(*ast.GenDecl)
|
|
|
|
if !ok {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if genDecl.Tok != token.IMPORT {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, spec := range genDecl.Specs {
|
|
|
|
importSpec, ok := spec.(*ast.ImportSpec)
|
|
|
|
if !ok {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
var importNameAndPath string
|
|
|
|
if importSpec.Name != nil {
|
|
|
|
// Handle alias import
|
|
|
|
// xyz "github.com/abc/xyz/123"
|
|
|
|
importNameAndPath = importSpec.Name.String() + " " + importSpec.Path.Value
|
|
|
|
} else {
|
|
|
|
// Handle normal import
|
|
|
|
// "github.com/abc/xyz"
|
|
|
|
importNameAndPath = importSpec.Path.Value
|
|
|
|
}
|
|
|
|
|
|
|
|
result[importNameAndPath] = importSpec
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Copy from goimports-reviser
|
|
|
|
// Group imports to std, third-party, company if exist, local
|
|
|
|
func (ft *Formatter) groupImports(importsAST map[string]*ast.ImportSpec) (map[string][]string, error) {
|
|
|
|
result := make(map[string][]string)
|
|
|
|
result[stdImport] = make([]string, 0, 8)
|
|
|
|
result[thirdPartyImport] = make([]string, 0, 8)
|
|
|
|
if ft.companyPrefix != "" {
|
|
|
|
result[companyImport] = make([]string, 0, 8)
|
|
|
|
}
|
|
|
|
result[localImport] = make([]string, 0, 8)
|
|
|
|
|
|
|
|
for importNameAndPath, importAST := range importsAST {
|
|
|
|
// "github.com/abc/xyz" -> github.com/abc/xyz
|
|
|
|
importPath := strings.Trim(importAST.Path.Value, "\"")
|
|
|
|
|
|
|
|
if _, ok := ft.stdPackages[importPath]; ok {
|
|
|
|
result[stdImport] = append(result[stdImport], importNameAndPath)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Wrap log.Printf with verbose flag
|
|
|
|
func (ft *Formatter) log(format string, v ...any) {
|
|
|
|
if ft.isVerbose {
|
|
|
|
log.Printf(format, v...)
|
|
|
|
}
|
2022-11-24 17:14:47 +00:00
|
|
|
}
|