-
Notifications
You must be signed in to change notification settings - Fork 285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bugfix: error image when use null value as image query condition in insert on duplicate #704 #725
base: master
Are you sure you want to change the base?
Changes from 5 commits
abdd938
4debbde
42fb93a
56f2645
c52fd82
321906e
b60e020
5675b29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,68 +97,120 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a | |
if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { | ||
return "", nil, err | ||
} | ||
var selectArgs []driver.Value | ||
|
||
// Reset primary keys map | ||
u.BeforeImageSqlPrimaryKeys = make(map[string]bool, len(metaData.Indexs)) | ||
|
||
pkIndexMap := u.getPkIndex(insertStmt, metaData) | ||
var pkIndexArray []int | ||
for _, val := range pkIndexMap { | ||
tmpVal := val | ||
pkIndexArray = append(pkIndexArray, tmpVal) | ||
pkIndexArray = append(pkIndexArray, val) | ||
} | ||
insertRows, err := getInsertRows(insertStmt, pkIndexArray) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
insertNum := len(insertRows) | ||
paramMap, err := u.buildImageParameters(insertStmt, args, insertRows) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
if len(paramMap) == 0 || len(metaData.Indexs) == 0 { | ||
return "", nil, nil | ||
} | ||
hasPK := false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
goto 不太好理解? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
那这里还需要改为goto吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不要使用 goto 吧 |
||
for _, index := range metaData.Indexs { | ||
if strings.EqualFold("PRIMARY", index.Name) { | ||
hasPK = true | ||
break | ||
} | ||
} | ||
if !hasPK { | ||
return "", nil, nil | ||
} | ||
var sql strings.Builder | ||
sql.WriteString("SELECT * FROM " + metaData.TableName + " ") | ||
|
||
sql := strings.Builder{} | ||
sql.WriteString("SELECT * FROM " + metaData.TableName + " ") | ||
var selectArgs []driver.Value | ||
isContainWhere := false | ||
for i := 0; i < insertNum; i++ { | ||
finalI := i | ||
paramAppenderTempList := make([]driver.Value, 0) | ||
hasConditions := false | ||
for i := 0; i < len(insertRows); i++ { | ||
var rowConditions = make([]string, 0, cap(insertRows[i])) | ||
var rowArgs = make([]driver.Value, 0, cap(insertRows[i])) | ||
usedParams := make(map[string]bool, len(paramMap)) | ||
// First try unique indexes | ||
for _, index := range metaData.Indexs { | ||
//unique index | ||
if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false { | ||
if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) { | ||
continue | ||
} | ||
columnIsNull := true | ||
uniqueList := make([]string, 0) | ||
for _, columnMeta := range index.Columns { | ||
columnName := columnMeta.ColumnName | ||
imageParameters, ok := paramMap[columnName] | ||
if !ok && columnMeta.ColumnDef != nil { | ||
if strings.EqualFold("PRIMARY", index.Name) { | ||
u.BeforeImageSqlPrimaryKeys[columnName] = true | ||
} | ||
uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ") | ||
columnIsNull = false | ||
continue | ||
if !isIndexValueNotNull(index, paramMap, i) { | ||
continue | ||
} | ||
var indexConditions []string | ||
var indexArgs []driver.Value | ||
allColumnsPresent := true | ||
for _, colMeta := range index.Columns { | ||
columnName := colMeta.ColumnName | ||
if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil { | ||
indexConditions = append(indexConditions, columnName+" = ? ") | ||
indexArgs = append(indexArgs, params[i]) | ||
usedParams[columnName] = true | ||
} else if colMeta.ColumnDef != nil { | ||
indexConditions = append(indexConditions, columnName+" = DEFAULT("+columnName+")") | ||
} else { | ||
allColumnsPresent = false | ||
break | ||
} | ||
if strings.EqualFold("PRIMARY", index.Name) { | ||
u.BeforeImageSqlPrimaryKeys[columnName] = true | ||
} | ||
if allColumnsPresent && len(indexConditions) > 0 { | ||
rowConditions = append(rowConditions, "("+strings.Join(indexConditions, " and ")+")") | ||
rowArgs = append(rowArgs, indexArgs...) | ||
hasConditions = true | ||
} | ||
} | ||
// Then try primary key | ||
for _, index := range metaData.Indexs { | ||
if !strings.EqualFold("PRIMARY", index.Name) { | ||
continue | ||
} | ||
var pkConditions []string | ||
var pkArgs []driver.Value | ||
for _, colMeta := range index.Columns { | ||
columnName := colMeta.ColumnName | ||
u.BeforeImageSqlPrimaryKeys[columnName] = true | ||
if params, ok := paramMap[columnName]; ok && len(params) > i && params[i] != nil && !usedParams[columnName] { | ||
pkConditions = append(pkConditions, columnName+" = ? ") | ||
pkArgs = append(pkArgs, params[i]) | ||
} | ||
columnIsNull = false | ||
uniqueList = append(uniqueList, columnName+" = ? ") | ||
paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI]) | ||
} | ||
|
||
if !columnIsNull { | ||
if isContainWhere { | ||
sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ") | ||
} else { | ||
sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ") | ||
isContainWhere = true | ||
if len(pkConditions) > 0 { | ||
rowConditions = append(rowConditions, "("+strings.Join(pkConditions, " and ")+")") | ||
rowArgs = append(rowArgs, pkArgs...) | ||
hasConditions = true | ||
} | ||
} | ||
if len(rowConditions) > 0 { | ||
if !isContainWhere { | ||
sql.WriteString("WHERE ") | ||
isContainWhere = true | ||
} else { | ||
sql.WriteString(" OR ") | ||
} | ||
for j, condition := range rowConditions { | ||
if j > 0 { | ||
sql.WriteString(" OR ") | ||
} | ||
sql.WriteString(condition + " ") | ||
} | ||
selectArgs = append(selectArgs, rowArgs...) | ||
} | ||
selectArgs = append(selectArgs, paramAppenderTempList...) | ||
} | ||
log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String()) | ||
return sql.String(), selectArgs, nil | ||
if !hasConditions { | ||
return "", nil, nil | ||
} | ||
sqlStr := sql.String() | ||
log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr) | ||
return sqlStr, selectArgs, nil | ||
|
||
} | ||
|
||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { | ||
|
@@ -168,14 +220,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e | |
log.Errorf("build prepare stmt: %+v", err) | ||
return nil, err | ||
} | ||
|
||
defer stmt.Close() | ||
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O | ||
metaData := execCtx.MetaDataMap[tableName] | ||
rows, err := stmt.Query(selectArgs) | ||
if err != nil { | ||
log.Errorf("stmt query: %+v", err) | ||
return nil, err | ||
} | ||
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O | ||
metaData := execCtx.MetaDataMap[tableName] | ||
defer rows.Close() | ||
image, err := u.buildRecordImages(rows, &metaData) | ||
if err != nil { | ||
return nil, err | ||
|
@@ -185,11 +237,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e | |
|
||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) { | ||
selectSQL, selectArgs := u.BeforeSelectSql, u.Args | ||
|
||
var beforeImage *types.RecordImage | ||
if len(beforeImages) > 0 { | ||
beforeImage = beforeImages[0] | ||
} | ||
if beforeImage == nil || len(beforeImage.Rows) == 0 { | ||
return selectSQL, selectArgs | ||
} | ||
primaryValueMap := make(map[string][]interface{}) | ||
for _, row := range beforeImage.Rows { | ||
for _, col := range row.Columns { | ||
|
@@ -198,25 +252,46 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co | |
} | ||
} | ||
} | ||
|
||
var afterImageSql strings.Builder | ||
var primaryValues []driver.Value | ||
afterImageSql.WriteString(selectSQL) | ||
for i := 0; i < len(beforeImage.Rows); i++ { | ||
wherePrimaryList := make([]string, 0) | ||
for name, value := range primaryValueMap { | ||
if !u.BeforeImageSqlPrimaryKeys[name] { | ||
wherePrimaryList = append(wherePrimaryList, name+" = ? ") | ||
primaryValues = append(primaryValues, value[i]) | ||
if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) { | ||
return selectSQL, selectArgs | ||
} | ||
var primaryValues []driver.Value | ||
usedPrimaryKeys := make(map[string]bool) | ||
for name := range primaryValueMap { | ||
if !u.BeforeImageSqlPrimaryKeys[name] { | ||
usedPrimaryKeys[name] = true | ||
for i := 0; i < len(beforeImage.Rows); i++ { | ||
if value := primaryValueMap[name][i]; value != nil { | ||
if dv, ok := value.(driver.Value); ok { | ||
primaryValues = append(primaryValues, dv) | ||
} else { | ||
primaryValues = append(primaryValues, value) | ||
} | ||
} | ||
} | ||
} | ||
if len(wherePrimaryList) != 0 { | ||
afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ") | ||
} | ||
if len(primaryValues) > 0 { | ||
afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ") | ||
} | ||
finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues)) | ||
copy(finalArgs, selectArgs) | ||
copy(finalArgs[len(selectArgs):], primaryValues) | ||
sqlStr := afterImageSql.String() | ||
log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr) | ||
return sqlStr, finalArgs | ||
} | ||
|
||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string { | ||
var conditions []string | ||
for name := range primaryValueMap { | ||
if !usedPrimaryKeys[name] { | ||
conditions = append(conditions, name+" = ? ") | ||
} | ||
} | ||
selectArgs = append(selectArgs, primaryValues...) | ||
log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String()) | ||
return afterImageSql.String(), selectArgs | ||
return conditions | ||
} | ||
|
||
func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { | ||
|
@@ -243,11 +318,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e | |
|
||
// build sql params | ||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) { | ||
var ( | ||
parameterMap = make(map[string][]driver.Value) | ||
) | ||
parameterMap := make(map[string][]driver.Value) | ||
insertColumns := getInsertColumns(insert) | ||
var placeHolderIndex = 0 | ||
placeHolderIndex := 0 | ||
|
||
for _, row := range insertRows { | ||
if len(row) != len(insertColumns) { | ||
log.Errorf("insert row's column size not equal to insert column size") | ||
|
@@ -256,13 +330,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast. | |
for i, col := range insertColumns { | ||
columnName := executor.DelEscape(col, types.DBTypeMySQL) | ||
val := row[i] | ||
rStr, ok := val.(string) | ||
if ok && strings.EqualFold(rStr, SqlPlaceholder) { | ||
objects := args[placeHolderIndex] | ||
parameterMap[columnName] = append(parameterMap[col], objects) | ||
if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) { | ||
if placeHolderIndex >= len(args) { | ||
return nil, fmt.Errorf("not enough parameters for placeholders") | ||
} | ||
parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex]) | ||
placeHolderIndex++ | ||
} else { | ||
parameterMap[columnName] = append(parameterMap[col], val) | ||
parameterMap[columnName] = append(parameterMap[columnName], val) | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,6 +143,26 @@ func TestInsertOnDuplicateBuildBeforeImageSQL(t *testing.T) { | |
expectQuery1: "SELECT * FROM t_user WHERE (name = ? and age = ? ) OR (name = ? and age = ? ) ", | ||
expectQueryArgs1: []driver.Value{"Jack1", 81, "Michal", int64(35)}, | ||
}, | ||
// Test case for null unique index | ||
{ | ||
execCtx: &types.ExecContext{ | ||
Query: "insert into t_unique(id, a, b) values(1, NULL, 2) on duplicate key update b = 5", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. values(1, NULL, 2) 换成占位符 ? |
||
MetaDataMap: map[string]types.TableMeta{"t_unique": tableMeta1}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MetaDataMap 会在后面使用,确保这个 key 被用到,可以 debug 看看是否走到了你预想的步骤 |
||
}, | ||
sourceQueryArgs: []driver.Value{1, nil, 2}, | ||
expectQuery1: "SELECT * FROM t_unique WHERE (id = ? ) ", | ||
expectQueryArgs1: []driver.Value{1}, | ||
}, | ||
// Test case for null primary key | ||
{ | ||
execCtx: &types.ExecContext{ | ||
Query: "insert into t_unique(id, b) values(NULL, 2) on duplicate key update b = 5", | ||
MetaDataMap: map[string]types.TableMeta{"t_unique": tableMeta1}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. t_unique 改为 t_user |
||
}, | ||
sourceQueryArgs: []driver.Value{nil, 2}, | ||
expectQuery1: "SELECT * FROM t_unique WHERE (b = ? ) ", | ||
expectQueryArgs1: []driver.Value{2}, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use make and cap