diff --git a/internal/populatedb/populatedb.go b/internal/populatedb/populatedb.go index 4199d86..d9c397c 100644 --- a/internal/populatedb/populatedb.go +++ b/internal/populatedb/populatedb.go @@ -18,12 +18,16 @@ import ( const ( dialectMySQL = "mysql" - stmtInsert = "INSERT INTO %s (%s) VALUES (%s);" + // 2^16 + maxQuestionMarks = 65536 + + stmtInsert = "INSERT INTO %s (%s) VALUES %s;" ) var ( - ErrNotSupportDialect = errors.New("not support dialect") - ErrTableNotExist = errors.New("table not exist") + ErrNotSupportDialect = errors.New("not support dialect") + ErrTableNotExist = errors.New("table not exist") + ErrMaximumQuestionMarks = errors.New("maximum question marks") ) type Populator interface { @@ -94,32 +98,19 @@ func NewPopulator( } func (p *populator) Insert(ctx context.Context, tableName string, numberRecord int) error { - table, ok := p.tables[tableName] - if !ok { - return fmt.Errorf("table [%s] not exist: %w", tableName, ErrTableNotExist) - } - - columnNames := make([]string, 0, len(table.Columns)) - questionMarks := make([]string, 0, len(table.Columns)) - argFns := make([]func() any, 0, len(table.Columns)) - for _, column := range table.Columns { - dt, err := ParseDatabaseType(column.Type) - if err != nil { - return fmt.Errorf("failed to parse database type [%s]: %w", column.Type, err) - } - - columnNames = append(columnNames, column.Name) - questionMarks = append(questionMarks, "?") - argFns = append(argFns, dt.Generate) + columnNames, questionMarks, argFns, err := p.prepareInsert(tableName) + if err != nil { + return err } queryInsert := fmt.Sprintf(stmtInsert, tableName, strings.Join(columnNames, ", "), - strings.Join(questionMarks, ", "), + fmt.Sprintf("(%s)", strings.Join(questionMarks, ", ")), ) for i := 0; i < numberRecord; i++ { + // Generate each time insert for different value args := make([]any, 0, len(argFns)) for _, argFn := range argFns { args = append(args, argFn()) @@ -138,3 +129,100 @@ func (p *populator) Insert(ctx context.Context, tableName string, numberRecord i return nil } + +func (p *populator) BatchInsert(ctx context.Context, tableName string, numberRecord int) error { + columnNames, questionMarks, argFns, err := p.prepareInsert(tableName) + if err != nil { + return err + } + + if len(columnNames) == 0 { + return nil + } + + numberRecordEachBatch := maxQuestionMarks / len(questionMarks) + if numberRecordEachBatch == 0 { + return fmt.Errorf("maxium question marks [%d]: %w", len(questionMarks), ErrMaximumQuestionMarks) + } + + numberBatch := numberRecord/numberRecordEachBatch + 1 + numberRecordLastBatch := numberRecord - (numberBatch-1)*numberRecordEachBatch + + generateQueryArgsInsertFn := func(tempNumberRecord int) (string, []any) { + valuesQuestionMarks := make([]string, 0, tempNumberRecord) + argsInsert := make([]any, 0, tempNumberRecord*len(argFns)) + for i := 0; i < tempNumberRecord; i++ { + valuesQuestionMarks = append(valuesQuestionMarks, fmt.Sprintf("(%s)", strings.Join(questionMarks, ", "))) + + // Generate each time insert for different value + args := make([]any, 0, len(argFns)) + for _, argFn := range argFns { + args = append(args, argFn()) + } + argsInsert = append(argsInsert, args...) + } + + queryInsert := fmt.Sprintf(stmtInsert, + tableName, + strings.Join(columnNames, ", "), + strings.Join(valuesQuestionMarks, ", "), + ) + + return queryInsert, argsInsert + } + + for i := 0; i < numberBatch-1; i++ { + queryInsert, argsInsert := generateQueryArgsInsertFn(numberRecordEachBatch) + + if p.verbose { + fmt.Println(i, queryInsert, argsInsert) + } + + if !p.dryRun { + if _, err := p.db.ExecContext(ctx, queryInsert, argsInsert...); err != nil { + return fmt.Errorf("database: failed to exec [%s]: %w", queryInsert, err) + } + } + } + + { + // Last batch + queryInsert, argsInsert := generateQueryArgsInsertFn(numberRecordLastBatch) + + if p.verbose { + fmt.Println(numberBatch-1, queryInsert, argsInsert) + } + + if !p.dryRun { + if _, err := p.db.ExecContext(ctx, queryInsert, argsInsert...); err != nil { + return fmt.Errorf("database: failed to exec [%s]: %w", queryInsert, err) + } + } + } + + return nil +} + +// Return columnNames, questionMarks, argFns +func (p *populator) prepareInsert(tableName string) ([]string, []string, []func() any, error) { + table, ok := p.tables[tableName] + if !ok { + return nil, nil, nil, fmt.Errorf("table [%s] not exist: %w", tableName, ErrTableNotExist) + } + + columnNames := make([]string, 0, len(table.Columns)) + questionMarks := make([]string, 0, len(table.Columns)) + argFns := make([]func() any, 0, len(table.Columns)) + for _, column := range table.Columns { + dt, err := ParseDatabaseType(column.Type) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to parse database type [%s]: %w", column.Type, err) + } + + columnNames = append(columnNames, column.Name) + questionMarks = append(questionMarks, "?") + argFns = append(argFns, dt.Generate) + } + + return columnNames, questionMarks, argFns, nil +}