feat: actually print file from ast

main
sudo pacman -Syu 2022-11-27 00:08:16 +07:00
parent ec9ec20189
commit 30f79e9322
No known key found for this signature in database
GPG Key ID: D6CB5C6C567C47B0
1 changed files with 64 additions and 37 deletions

View File

@ -1,10 +1,12 @@
package imports
import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"log"
"os"
@ -132,14 +134,7 @@ func (ft *Formatter) formatFile(path string) error {
}
ft.log("formatFile: moduleName: %+v\n", moduleName)
// Parse ast
pathASTFile, err := ft.parseAST(path, pathBytes)
if err != nil {
return err
}
ft.log("formatFile: pathASTFile: %+v\n", pathASTFile)
if err := ft.formatASTFile(pathASTFile, moduleName); err != nil {
if err := ft.formatImports(path, pathBytes, moduleName); err != nil {
return err
}
@ -150,36 +145,36 @@ func (ft *Formatter) formatFile(path string) error {
return nil
}
func (ft *Formatter) parseAST(path string, pathBytes []byte) (*ast.File, error) {
// Copy from goimports-reviser
func (ft *Formatter) formatImports(
path string,
pathBytes []byte,
moduleName string,
) error {
// Parse ast
fset := token.NewFileSet()
parserMode := parser.Mode(0)
parserMode |= parser.ParseComments
pathASTFile, err := parser.ParseFile(fset, path, pathBytes, parserMode)
astFile, err := parser.ParseFile(fset, path, pathBytes, parserMode)
if err != nil {
return nil, fmt.Errorf("parser: failed to parse file [%s]: %w", path, err)
return fmt.Errorf("parser: failed to parse file [%s]: %w", path, err)
}
// Ignore generated file
if isGoGenerated(pathASTFile) {
return nil, ErrGoGeneratedFile
if isGoGenerated(astFile) {
return ErrGoGeneratedFile
}
return pathASTFile, nil
}
// Copy from goimports-reviser
// If exist multi import, combine them into one
// This func will edit pathASTFile directly to combine
func (ft *Formatter) formatASTFile(
pathASTFile *ast.File,
moduleName string,
) error {
// Extract imports
importSpecs := make([]ast.Spec, 0, len(pathASTFile.Imports))
for _, importSpec := range pathASTFile.Imports {
ft.log("parseImportsAndCombine: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value)
importSpecs := make([]ast.Spec, 0, len(astFile.Imports))
for _, importSpec := range astFile.Imports {
if importSpec.Path.Value == "" {
continue
}
ft.log("formatImports: importSpec: %+v %+v\n", importSpec.Name.String(), importSpec.Path.Value)
importSpecs = append(importSpecs, importSpec)
}
@ -195,12 +190,13 @@ func (ft *Formatter) formatASTFile(
if err != nil {
return err
}
ft.mustLogImportSpecs("formatImports: formattedImportSpecs: ", formattedImportSpecs)
// Combine multi import decl
// Combine multi import decl into one
isMultiImportDecl := false
isExistFirstImportDecl := false
decls := make([]ast.Decl, 0, len(pathASTFile.Decls))
for _, decl := range pathASTFile.Decls {
decls := make([]ast.Decl, 0, len(astFile.Decls))
for _, decl := range astFile.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok {
decls = append(decls, decl)
@ -215,8 +211,8 @@ func (ft *Formatter) formatASTFile(
if isExistFirstImportDecl {
isMultiImportDecl = true
// TODO: explain this
storedGenDecl := decls[len(decls)-1].(*ast.GenDecl)
if storedGenDecl.Tok == token.IMPORT {
storedGenDecl, ok := decls[len(decls)-1].(*ast.GenDecl)
if ok && storedGenDecl.Tok == token.IMPORT {
storedGenDecl.Rparen = genDecl.End()
}
continue
@ -227,12 +223,20 @@ func (ft *Formatter) formatASTFile(
genDecl.Specs = formattedImportSpecs
decls = append(decls, genDecl)
}
ft.log("parseImportsAndCombine: decls: %+v\n", decls)
if isMultiImportDecl {
pathASTFile.Decls = decls
astFile.Decls = decls
}
// Print formatted bytes from formatted ast
var formattedBytes []byte
formattedBuffer := bytes.NewBuffer(formattedBytes)
if err := printer.Fprint(formattedBuffer, fset, astFile); err != nil {
return err
}
fmt.Println(formattedBuffer.String())
return nil
}
@ -278,16 +282,17 @@ func (ft *Formatter) groupImportSpecs(
result[thirdPartyImport] = append(result[thirdPartyImport], importSpec)
}
ft.log("groupImports: std: %+v\n", result[stdImport])
ft.log("groupImports: third-party: %+v\n", result[thirdPartyImport])
ft.logImportSpecs("stdImport", result[stdImport])
ft.logImportSpecs("thirdPartyImport", result[thirdPartyImport])
if ft.companyPrefix != "" {
ft.log("groupImports: company: %+v\n", result[companyImport])
ft.logImportSpecs("companyImport", result[companyImport])
}
ft.log("groupImports: local: %+v\n", result[localImport])
ft.logImportSpecs("localImport", result[localImport])
return result, nil
}
// Copy from goimports-reviser
func (ft *Formatter) formatImportSpecs(
importSpecs []ast.Spec,
groupedImportSpecs map[string][]*ast.ImportSpec,
@ -350,6 +355,7 @@ func (ft *Formatter) formatImportSpecs(
return result, nil
}
// Copy from goimports-reviser
func (ft *Formatter) moduleName(path string) (string, error) {
ft.muModuleNames.RLock()
if pkgName, ok := ft.moduleNames[path]; ok {
@ -413,3 +419,24 @@ func (ft *Formatter) log(format string, v ...any) {
log.Printf(format, v...)
}
}
func (ft *Formatter) logImportSpecs(logPrefix string, importSpecs []*ast.ImportSpec) {
if ft.isVerbose {
for _, importSpec := range importSpecs {
log.Printf("%s: importSpec: %+v %+v\n", logPrefix, importSpec.Name.String(), importSpec.Path.Value)
}
}
}
func (ft *Formatter) mustLogImportSpecs(logPrefix string, importSpecs []ast.Spec) {
if ft.isVerbose {
for _, importSpec := range importSpecs {
importSpec, ok := importSpec.(*ast.ImportSpec)
if !ok {
continue
}
log.Printf("%s: importSpec: %+v %+v\n", logPrefix, importSpec.Name.String(), importSpec.Path.Value)
}
}
}