feat: support batch insert (wip)
parent
8be88da353
commit
07a4563762
|
@ -18,12 +18,16 @@ import (
|
||||||
const (
|
const (
|
||||||
dialectMySQL = "mysql"
|
dialectMySQL = "mysql"
|
||||||
|
|
||||||
stmtInsert = "INSERT INTO %s (%s) VALUES (%s);"
|
// 2^16
|
||||||
|
maxQuestionMarks = 65536
|
||||||
|
|
||||||
|
stmtInsert = "INSERT INTO %s (%s) VALUES %s;"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrNotSupportDialect = errors.New("not support dialect")
|
ErrNotSupportDialect = errors.New("not support dialect")
|
||||||
ErrTableNotExist = errors.New("table not exist")
|
ErrTableNotExist = errors.New("table not exist")
|
||||||
|
ErrMaximumQuestionMarks = errors.New("maximum question marks")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Populator interface {
|
type Populator interface {
|
||||||
|
@ -94,32 +98,19 @@ func NewPopulator(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *populator) Insert(ctx context.Context, tableName string, numberRecord int) error {
|
func (p *populator) Insert(ctx context.Context, tableName string, numberRecord int) error {
|
||||||
table, ok := p.tables[tableName]
|
columnNames, questionMarks, argFns, err := p.prepareInsert(tableName)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return fmt.Errorf("table [%s] not exist: %w", tableName, ErrTableNotExist)
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
columnNames := make([]string, 0, len(table.Columns))
|
|
||||||
questionMarks := make([]string, 0, len(table.Columns))
|
|
||||||
argFns := make([]func() any, 0, len(table.Columns))
|
|
||||||
for _, column := range table.Columns {
|
|
||||||
dt, err := ParseDatabaseType(column.Type)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse database type [%s]: %w", column.Type, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
columnNames = append(columnNames, column.Name)
|
|
||||||
questionMarks = append(questionMarks, "?")
|
|
||||||
argFns = append(argFns, dt.Generate)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
queryInsert := fmt.Sprintf(stmtInsert,
|
queryInsert := fmt.Sprintf(stmtInsert,
|
||||||
tableName,
|
tableName,
|
||||||
strings.Join(columnNames, ", "),
|
strings.Join(columnNames, ", "),
|
||||||
strings.Join(questionMarks, ", "),
|
fmt.Sprintf("(%s)", strings.Join(questionMarks, ", ")),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i := 0; i < numberRecord; i++ {
|
for i := 0; i < numberRecord; i++ {
|
||||||
|
// Generate each time insert for different value
|
||||||
args := make([]any, 0, len(argFns))
|
args := make([]any, 0, len(argFns))
|
||||||
for _, argFn := range argFns {
|
for _, argFn := range argFns {
|
||||||
args = append(args, argFn())
|
args = append(args, argFn())
|
||||||
|
@ -138,3 +129,100 @@ func (p *populator) Insert(ctx context.Context, tableName string, numberRecord i
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *populator) BatchInsert(ctx context.Context, tableName string, numberRecord int) error {
|
||||||
|
columnNames, questionMarks, argFns, err := p.prepareInsert(tableName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columnNames) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
numberRecordEachBatch := maxQuestionMarks / len(questionMarks)
|
||||||
|
if numberRecordEachBatch == 0 {
|
||||||
|
return fmt.Errorf("maxium question marks [%d]: %w", len(questionMarks), ErrMaximumQuestionMarks)
|
||||||
|
}
|
||||||
|
|
||||||
|
numberBatch := numberRecord/numberRecordEachBatch + 1
|
||||||
|
numberRecordLastBatch := numberRecord - (numberBatch-1)*numberRecordEachBatch
|
||||||
|
|
||||||
|
generateQueryArgsInsertFn := func(tempNumberRecord int) (string, []any) {
|
||||||
|
valuesQuestionMarks := make([]string, 0, tempNumberRecord)
|
||||||
|
argsInsert := make([]any, 0, tempNumberRecord*len(argFns))
|
||||||
|
for i := 0; i < tempNumberRecord; i++ {
|
||||||
|
valuesQuestionMarks = append(valuesQuestionMarks, fmt.Sprintf("(%s)", strings.Join(questionMarks, ", ")))
|
||||||
|
|
||||||
|
// Generate each time insert for different value
|
||||||
|
args := make([]any, 0, len(argFns))
|
||||||
|
for _, argFn := range argFns {
|
||||||
|
args = append(args, argFn())
|
||||||
|
}
|
||||||
|
argsInsert = append(argsInsert, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
queryInsert := fmt.Sprintf(stmtInsert,
|
||||||
|
tableName,
|
||||||
|
strings.Join(columnNames, ", "),
|
||||||
|
strings.Join(valuesQuestionMarks, ", "),
|
||||||
|
)
|
||||||
|
|
||||||
|
return queryInsert, argsInsert
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numberBatch-1; i++ {
|
||||||
|
queryInsert, argsInsert := generateQueryArgsInsertFn(numberRecordEachBatch)
|
||||||
|
|
||||||
|
if p.verbose {
|
||||||
|
fmt.Println(i, queryInsert, argsInsert)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.dryRun {
|
||||||
|
if _, err := p.db.ExecContext(ctx, queryInsert, argsInsert...); err != nil {
|
||||||
|
return fmt.Errorf("database: failed to exec [%s]: %w", queryInsert, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Last batch
|
||||||
|
queryInsert, argsInsert := generateQueryArgsInsertFn(numberRecordLastBatch)
|
||||||
|
|
||||||
|
if p.verbose {
|
||||||
|
fmt.Println(numberBatch-1, queryInsert, argsInsert)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.dryRun {
|
||||||
|
if _, err := p.db.ExecContext(ctx, queryInsert, argsInsert...); err != nil {
|
||||||
|
return fmt.Errorf("database: failed to exec [%s]: %w", queryInsert, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return columnNames, questionMarks, argFns
|
||||||
|
func (p *populator) prepareInsert(tableName string) ([]string, []string, []func() any, error) {
|
||||||
|
table, ok := p.tables[tableName]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, nil, fmt.Errorf("table [%s] not exist: %w", tableName, ErrTableNotExist)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNames := make([]string, 0, len(table.Columns))
|
||||||
|
questionMarks := make([]string, 0, len(table.Columns))
|
||||||
|
argFns := make([]func() any, 0, len(table.Columns))
|
||||||
|
for _, column := range table.Columns {
|
||||||
|
dt, err := ParseDatabaseType(column.Type)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to parse database type [%s]: %w", column.Type, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNames = append(columnNames, column.Name)
|
||||||
|
questionMarks = append(questionMarks, "?")
|
||||||
|
argFns = append(argFns, dt.Generate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnNames, questionMarks, argFns, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue