diff --git a/pkg/pb/plan/plan.pb.go b/pkg/pb/plan/plan.pb.go index d1ddffea76bf9..c40932503398d 100644 --- a/pkg/pb/plan/plan.pb.go +++ b/pkg/pb/plan/plan.pb.go @@ -5859,6 +5859,7 @@ type UpdateCtx struct { InsertCols []ColRef `protobuf:"bytes,7,rep,name=insert_cols,json=insertCols,proto3" json:"insert_cols"` DeleteCols []ColRef `protobuf:"bytes,8,rep,name=delete_cols,json=deleteCols,proto3" json:"delete_cols"` PartitionCols []ColRef `protobuf:"bytes,9,rep,name=partition_cols,json=partitionCols,proto3" json:"partition_cols"` + IsReplace bool `protobuf:"varint,10,opt,name=is_replace,json=isReplace,proto3" json:"is_replace,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -5932,6 +5933,13 @@ func (m *UpdateCtx) GetPartitionCols() []ColRef { return nil } +func (m *UpdateCtx) GetIsReplace() bool { + if m != nil { + return m.IsReplace + } + return false +} + type InsertCtx struct { Ref *ObjectRef `protobuf:"bytes,1,opt,name=ref,proto3" json:"ref,omitempty"` AddAffectedRows bool `protobuf:"varint,2,opt,name=add_affected_rows,json=addAffectedRows,proto3" json:"add_affected_rows,omitempty"` @@ -18840,6 +18848,16 @@ func (m *UpdateCtx) MarshalToSizedBuffer(dAtA []byte) (int, error) { i -= len(m.XXX_unrecognized) copy(dAtA[i:], m.XXX_unrecognized) } + if m.IsReplace { + i-- + if m.IsReplace { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x50 + } if len(m.PartitionCols) > 0 { for iNdEx := len(m.PartitionCols) - 1; iNdEx >= 0; iNdEx-- { { @@ -28439,6 +28457,9 @@ func (m *UpdateCtx) ProtoSize() (n int) { n += 1 + l + sovPlan(uint64(l)) } } + if m.IsReplace { + n += 2 + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -42466,6 +42487,26 @@ func (m *UpdateCtx) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex + case 10: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field IsReplace", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPlan + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.IsReplace = bool(v != 0) default: iNdEx = preIndex skippy, err := skipPlan(dAtA[iNdEx:]) diff --git a/pkg/sql/colexec/multi_update/affected_rows_test.go b/pkg/sql/colexec/multi_update/affected_rows_test.go new file mode 100644 index 0000000000000..3c034cae32874 --- /dev/null +++ b/pkg/sql/colexec/multi_update/affected_rows_test.go @@ -0,0 +1,110 @@ +// Copyright 2021-2024 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package multi_update + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// Validates that REPLACE INTO counts both inserted and deleted rows toward +// AffectedRows, while regular UPDATE / DELETE / INSERT are unchanged. +func TestAddAffectRows_PerAction(t *testing.T) { + cases := []struct { + name string + action actionType + isReplace bool + tableType UpdateTableType + insertRows uint64 + deleteRows uint64 + wantAfterInsFn uint64 // after only addInsertAffectRows + wantAfterDelFn uint64 // after addInsertAffectRows + addDeleteAffectRows + }{ + { + name: "INSERT main table", + action: actionInsert, + tableType: UpdateMainTable, + insertRows: 5, + deleteRows: 0, + wantAfterInsFn: 5, + wantAfterDelFn: 5, + }, + { + name: "DELETE main table", + action: actionDelete, + tableType: UpdateMainTable, + insertRows: 0, + deleteRows: 7, + wantAfterInsFn: 0, + wantAfterDelFn: 7, + }, + { + name: "UPDATE main table counts insert side only", + action: actionUpdate, + isReplace: false, + tableType: UpdateMainTable, + insertRows: 3, + deleteRows: 3, + wantAfterInsFn: 3, + wantAfterDelFn: 3, + }, + { + name: "REPLACE main table counts insert + delete", + action: actionUpdate, + isReplace: true, + tableType: UpdateMainTable, + insertRows: 4, + deleteRows: 2, + wantAfterInsFn: 4, + wantAfterDelFn: 6, + }, + { + name: "REPLACE unique index table is skipped", + action: actionUpdate, + isReplace: true, + tableType: UpdateUniqueIndexTable, + insertRows: 10, + deleteRows: 10, + wantAfterInsFn: 0, + wantAfterDelFn: 0, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + op := &MultiUpdate{ + MultiUpdateCtx: []*MultiUpdateCtx{{IsReplace: tc.isReplace}}, + } + op.addAffectedRowsFunc = op.doAddAffectedRows + op.ctr.action = tc.action + + op.addInsertAffectRows(tc.tableType, tc.insertRows) + require.Equal(t, tc.wantAfterInsFn, op.GetAffectedRows(), + "after addInsertAffectRows") + + op.addDeleteAffectRows(tc.tableType, tc.deleteRows) + require.Equal(t, tc.wantAfterDelFn, op.GetAffectedRows(), + "after addDeleteAffectRows") + }) + } +} + +// isReplace must tolerate an empty MultiUpdateCtx slice (defensive — shouldn't +// happen in practice but the operator must not panic on it). +func TestIsReplace_EmptyCtx(t *testing.T) { + op := &MultiUpdate{} + require.False(t, op.isReplace()) +} diff --git a/pkg/sql/colexec/multi_update/types.go b/pkg/sql/colexec/multi_update/types.go index 615c0b08495e7..af023ff19be00 100644 --- a/pkg/sql/colexec/multi_update/types.go +++ b/pkg/sql/colexec/multi_update/types.go @@ -110,6 +110,10 @@ type MultiUpdateCtx struct { InsertCols []int DeleteCols []int PartitionCols []int + // IsReplace marks this ctx as belonging to a REPLACE INTO statement's main + // table. Used by addDeleteAffectRows to count replaced-out rows toward + // affected_rows (MySQL semantics: affected_rows = inserted + deleted). + IsReplace bool } func (update MultiUpdate) TypeName() string { @@ -190,13 +194,8 @@ func (update *MultiUpdate) addInsertAffectRows(tableType UpdateTableType, rowCou if tableType != UpdateMainTable { return } - // For REPLACE INTO, we always count INSERT rows, regardless of update.ctr.action - // because REPLACE INTO should return at least the number of rows being inserted switch update.ctr.action { - case actionInsert: - update.addAffectedRowsFunc(rowCount) - case actionUpdate: - // For REPLACE INTO with both DELETE and INSERT, count INSERT rows + case actionInsert, actionUpdate: update.addAffectedRowsFunc(rowCount) } } @@ -205,19 +204,27 @@ func (update *MultiUpdate) addDeleteAffectRows(tableType UpdateTableType, rowCou if tableType != UpdateMainTable { return } - // For REPLACE INTO, we don't count DELETE rows in affected rows - // REPLACE INTO should only return the number of INSERT rows - // Only count DELETE rows for regular UPDATE operations switch update.ctr.action { case actionDelete: - // Regular DELETE operation, count it update.addAffectedRowsFunc(rowCount) case actionUpdate: - // For UPDATE operations (not REPLACE INTO), count DELETE rows - // But for REPLACE INTO, this should not be called or should be ignored - // REPLACE INTO uses actionUpdate but should only count INSERT - // So we don't count DELETE here for actionUpdate + // MySQL REPLACE semantics: affected_rows = inserted + deleted. + // Regular UPDATE still counts only the insert side (== rows_matched), + // so we only add the delete side when this MULTI_UPDATE was produced + // by REPLACE INTO. + if update.isReplace() { + update.addAffectedRowsFunc(rowCount) + } + } +} + +// isReplace reports whether this MultiUpdate was produced by REPLACE INTO. +// The flag lives on the main-table UpdateCtx (index 0) — see bind_replace.go. +func (update *MultiUpdate) isReplace() bool { + if len(update.MultiUpdateCtx) == 0 { + return false } + return update.MultiUpdateCtx[0].IsReplace } func (update *MultiUpdate) doAddAffectedRows(affectedRows uint64) { diff --git a/pkg/sql/compile/operator.go b/pkg/sql/compile/operator.go index 33a8348cc1f4a..052ba990d34b8 100644 --- a/pkg/sql/compile/operator.go +++ b/pkg/sql/compile/operator.go @@ -823,6 +823,7 @@ func constructMultiUpdate( InsertCols: insertCols, DeleteCols: deleteCols, PartitionCols: partitionCols, + IsReplace: updateCtx.IsReplace, } } arg.Action = action diff --git a/pkg/sql/compile/remoterun.go b/pkg/sql/compile/remoterun.go index b469e4e4787b8..966eb26a603e5 100644 --- a/pkg/sql/compile/remoterun.go +++ b/pkg/sql/compile/remoterun.go @@ -774,8 +774,9 @@ func convertToPipelineInstruction(op vm.Operator, proc *process.Process, ctx *sc updateCtxList := make([]*plan.UpdateCtx, len(t.MultiUpdateCtx)) for i, muCtx := range t.MultiUpdateCtx { updateCtxList[i] = &plan.UpdateCtx{ - ObjRef: muCtx.ObjRef, - TableDef: muCtx.TableDef, + ObjRef: muCtx.ObjRef, + TableDef: muCtx.TableDef, + IsReplace: muCtx.IsReplace, } updateCtxList[i].InsertCols = make([]plan.ColRef, len(muCtx.InsertCols)) @@ -1230,8 +1231,9 @@ func convertToVmOperator(opr *pipeline.Instruction, ctx *scopeContext, eng engin for i, muCtx := range t.UpdateCtxList { arg.MultiUpdateCtx[i] = &multi_update.MultiUpdateCtx{ - ObjRef: muCtx.ObjRef, - TableDef: muCtx.TableDef, + ObjRef: muCtx.ObjRef, + TableDef: muCtx.TableDef, + IsReplace: muCtx.IsReplace, } arg.MultiUpdateCtx[i].InsertCols = make([]int, len(muCtx.InsertCols)) diff --git a/pkg/sql/plan/bind_replace.go b/pkg/sql/plan/bind_replace.go index 6ff96f42fe4a1..1d916515b0374 100644 --- a/pkg/sql/plan/bind_replace.go +++ b/pkg/sql/plan/bind_replace.go @@ -27,6 +27,8 @@ import ( "github.com/matrixorigin/matrixone/pkg/sql/util" ) +const maxReplaceStaticFilterRows = 1024 + func (builder *QueryBuilder) bindReplace(stmt *tree.Replace, bindCtx *BindContext) (int32, error) { dmlCtx := NewDMLContext() // REPLACE has its own conflict handling; bypass the generic FK table rejection @@ -44,7 +46,109 @@ func (builder *QueryBuilder) bindReplace(stmt *tree.Replace, bindCtx *BindContex return 0, err } - return builder.appendDedupAndMultiUpdateNodesForBindReplace(bindCtx, dmlCtx, lastNodeID, colName2Idx, skipUniqueIdx) + staticFilterValues, err := builder.collectReplaceStaticFilterValues(stmt, dmlCtx.tableDefs[0]) + if err != nil { + return 0, err + } + + return builder.appendDedupAndMultiUpdateNodesForBindReplace( + bindCtx, + dmlCtx, + lastNodeID, + colName2Idx, + skipUniqueIdx, + staticFilterValues, + ) +} + +func (builder *QueryBuilder) collectReplaceStaticFilterValues(stmt *tree.Replace, tableDef *plan.TableDef) (map[string][]*plan.Expr, error) { + if stmt == nil || stmt.Rows == nil || stmt.Rows.Select == nil { + return nil, nil + } + + valuesClause, ok := stmt.Rows.Select.(*tree.ValuesClause) + if !ok || len(valuesClause.Rows) == 0 { + return nil, nil + } + if len(valuesClause.Rows) > maxReplaceStaticFilterRows { + return nil, nil + } + + insertColumns, err := builder.getInsertColsFromStmt(stmt.Columns, tableDef) + if err != nil { + return nil, err + } + + colCount := len(insertColumns) + for rowIdx, row := range valuesClause.Rows { + if len(row) != colCount { + return nil, moerr.NewWrongValueCountOnRow(builder.GetContext(), rowIdx+1) + } + } + + proc := builder.compCtx.GetProcess() + staticValues := make(map[string][]*plan.Expr, colCount) + for i, colName := range insertColumns { + colIdx, ok := tableDef.Name2ColIndex[colName] + if !ok { + return nil, moerr.NewInternalErrorf(builder.GetContext(), "replace static filter missing column %s", colName) + } + colDef := tableDef.Cols[colIdx] + colTyp := makeTypeByPlan2Type(colDef.Typ) + targetTyp := &plan.Expr{ + Typ: colDef.Typ, + Expr: &plan.Expr_T{ + T: &plan.TargetType{}, + }, + } + binder := NewDefaultBinder(builder.GetContext(), nil, nil, colDef.Typ, nil) + binder.builder = builder + + for _, row := range valuesClause.Rows { + astExpr := row[i] + if _, isDefault := astExpr.(*tree.DefaultVal); isDefault { + return nil, nil + } + + var valueExpr *plan.Expr + if nv, isNum := astExpr.(*tree.NumVal); isNum && !isEnumOrSetPlanType(&colDef.Typ) { + valueExpr, err = MakeInsertValueConstExpr(proc, nv, &colTyp) + if err != nil { + return nil, err + } + } + if valueExpr == nil { + valueExpr, err = binder.BindExpr(astExpr, 0, true) + if err != nil { + return nil, nil + } + if isEnumPlanType(&colDef.Typ) { + valueExpr, err = funcCastForEnumType(builder.GetContext(), valueExpr, colDef.Typ) + if err != nil { + return nil, err + } + } else if isSetPlanType(&colDef.Typ) { + valueExpr, err = funcCastForSetType(builder.GetContext(), valueExpr, colDef.Typ) + if err != nil { + return nil, err + } + } else if isGeometryPlanType(&colDef.Typ) { + valueExpr, err = funcCastForGeometryType(builder.GetContext(), valueExpr, colDef.Typ) + if err != nil { + return nil, err + } + } + } + + valueExpr, err = forceCastExpr2(builder.GetContext(), valueExpr, colTyp, targetTyp) + if err != nil { + return nil, err + } + staticValues[colName] = append(staticValues[colName], valueExpr) + } + } + + return staticValues, nil } func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( @@ -53,6 +157,7 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( lastNodeID int32, colName2Idx map[string]int32, skipUniqueIdx []bool, + staticFilterValues map[string][]*plan.Expr, ) (int32, error) { objRef := dmlCtx.objRefs[0] tableDef := dmlCtx.tableDefs[0] @@ -77,6 +182,165 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( }) } + colExpr := func(tag, pos int32, typ plan.Type) *plan.Expr { + return &plan.Expr{ + Typ: typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: tag, + ColPos: pos, + }, + }, + } + } + nullExpr := func(typ plan.Type) *plan.Expr { + return &plan.Expr{ + Typ: typ, + Expr: &plan.Expr_Lit{ + Lit: &plan.Literal{Isnull: true}, + }, + } + } + bindFn := func(name string, args ...*plan.Expr) (*plan.Expr, error) { + copiedArgs := make([]*plan.Expr, len(args)) + for i, arg := range args { + copiedArgs[i] = DeepCopyExpr(arg) + } + expr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), name, copiedArgs) + if err != nil { + return nil, err + } + if expr == nil || expr.Expr == nil { + return nil, moerr.NewInternalErrorf(builder.GetContext(), "bind function %s got nil expression", name) + } + return expr, nil + } + nullSafeEq := func(left, right *plan.Expr) (*plan.Expr, error) { + leftIsNull, err := bindFn("isnull", left) + if err != nil { + return nil, err + } + rightIsNull, err := bindFn("isnull", right) + if err != nil { + return nil, err + } + bothNull, err := bindFn("and", leftIsNull, rightIsNull) + if err != nil { + return nil, err + } + leftNotNull, err := bindFn("isnotnull", left) + if err != nil { + return nil, err + } + rightNotNull, err := bindFn("isnotnull", right) + if err != nil { + return nil, err + } + bothNotNull, err := bindFn("and", leftNotNull, rightNotNull) + if err != nil { + return nil, err + } + eq, err := bindFn("=", left, right) + if err != nil { + // Fallback to "not equal" to keep REPLACE semantics correct when + // typed equality cannot be bound for specific index key types. + return makePlan2BoolConstExprWithType(false), nil + } + notNullEq, err := bindFn("and", bothNotNull, eq) + if err != nil { + return nil, err + } + return bindFn("or", bothNull, notNullEq) + } + makeNeedRewriteIdxExpr := func(oldRowID, oldIdx, newIdx, oldMainPK, newMainPK *plan.Expr) (*plan.Expr, error) { + oldRowIDIsNull, err := bindFn("isnull", oldRowID) + if err != nil { + return nil, err + } + sameIdx, err := nullSafeEq(oldIdx, newIdx) + if err != nil { + return nil, err + } + sameMainPK, err := nullSafeEq(oldMainPK, newMainPK) + if err != nil { + return nil, err + } + sameIdxAndPK, err := bindFn("and", sameIdx, sameMainPK) + if err != nil { + return nil, err + } + notSame, err := bindFn("not", sameIdxAndPK) + if err != nil { + return nil, err + } + return bindFn("or", oldRowIDIsNull, notSame) + } + makeIfExpr := func(cond, whenTrue, whenFalse *plan.Expr) (*plan.Expr, error) { + return bindFn("if", cond, whenTrue, whenFalse) + } + buildStaticScanFilter := func(scanCol *plan.Expr, values []*plan.Expr) (*plan.Expr, error) { + nonNullVals := make([]*plan.Expr, 0, len(values)) + seenVals := make(map[string]struct{}, len(values)) + for _, value := range values { + if value == nil { + continue + } + if lit := value.GetLit(); lit != nil && lit.Isnull { + continue + } + key, marshalErr := value.Marshal() + if marshalErr == nil { + if _, ok := seenVals[string(key)]; ok { + continue + } + seenVals[string(key)] = struct{}{} + } + nonNullVals = append(nonNullVals, value) + } + if len(nonNullVals) == 0 { + return nil, nil + } + + if len(nonNullVals) == 1 { + filterExpr, err := bindFn("=", scanCol, nonNullVals[0]) + if err != nil { + return nil, nil + } + return filterExpr, nil + } + + inExpr := &plan.Expr{ + Typ: scanCol.Typ, + Expr: &plan.Expr_List{ + List: &plan.ExprList{ + List: nonNullVals, + }, + }, + } + filterExpr, err := bindFn("in", scanCol, inExpr) + if err == nil { + return filterExpr, nil + } + + // Fallback to OR-equality chain when IN is unsupported for this type. + filterExpr, err = bindFn("=", scanCol, nonNullVals[0]) + if err != nil { + // Filter pushdown is an optimization only. + return nil, nil + } + for i := 1; i < len(nonNullVals); i++ { + eqExpr, bindErr := bindFn("=", scanCol, nonNullVals[i]) + if bindErr != nil { + return nil, nil + } + filterExpr, bindErr = bindFn("or", filterExpr, eqExpr) + if bindErr != nil { + return nil, nil + } + } + return filterExpr, nil + } + idxObjRefs := make([]*plan.ObjectRef, len(tableDef.Indexes)) idxTableDefs := make([]*plan.TableDef, len(tableDef.Indexes)) @@ -339,13 +603,13 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( // handle primary/unique key confliction builder.addNameByColRef(scanTag, tableDef) - scanNodeID := builder.appendNode(&plan.Node{ + scanNode := &plan.Node{ NodeType: plan.Node_TABLE_SCAN, TableDef: tableDef, ObjRef: objRef, BindingTags: []int32{scanTag}, ScanSnapshot: bindCtx.snapshot, - }, bindCtx) + } pkPos := tableDef.Name2ColIndex[pkName] pkTyp := tableDef.Cols[pkPos].Typ @@ -358,6 +622,16 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( }, }, } + if len(tableDef.Pkey.Names) == 1 { + filterExpr, filterErr := buildStaticScanFilter(leftExpr, staticFilterValues[tableDef.Pkey.Names[0]]) + if filterErr != nil { + return 0, filterErr + } + if filterExpr != nil { + scanNode.FilterList = append(scanNode.FilterList, filterExpr) + } + } + scanNodeID := builder.appendNode(scanNode, bindCtx) rightExpr := &plan.Expr{ Typ: pkTyp, @@ -438,7 +712,7 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( // detect unique key confliction for i, idxDef := range tableDef.Indexes { - if !idxDef.Unique { + if !idxDef.Unique || skipUniqueIdx[i] { continue } @@ -466,6 +740,16 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( }, }, } + if len(idxDef.Parts) == 1 { + partName := catalog.ResolveAlias(idxDef.Parts[0]) + filterExpr, filterErr := buildStaticScanFilter(leftExpr, staticFilterValues[partName]) + if filterErr != nil { + return 0, filterErr + } + if filterExpr != nil { + idxScanNode.FilterList = append(idxScanNode.FilterList, filterExpr) + } + } rightExpr := &plan.Expr{ Typ: pkTyp, @@ -553,6 +837,16 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( }, }, } + if len(idxDef.Parts) == 1 { + partName := catalog.ResolveAlias(idxDef.Parts[0]) + filterExpr, filterErr := buildStaticScanFilter(leftExpr, staticFilterValues[partName]) + if filterErr != nil { + return 0, filterErr + } + if filterExpr != nil { + idxScanNode.FilterList = append(idxScanNode.FilterList, filterExpr) + } + } oldPkPos := oldColName2Idx[idxTableDefs[i].Name+"."+lookupColName] oldColName2Idx[idxTableDefs[i].Name+"."+lookupColName] = [2]int32{idxTag, idxTableDefs[i].Name2ColIndex[lookupColName]} @@ -661,31 +955,24 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( TableDef: tableDef, InsertCols: insertCols, DeleteCols: deleteCols, + IsReplace: true, }) } + newMainPkPos := colName2Idx[tableDef.Name+"."+tableDef.Pkey.PkeyColName] + newMainPkExpr := colExpr(fullProjTag, newMainPkPos, fullProjList[newMainPkPos].Typ) + oldMainPkPos := oldColName2Idx[tableDef.Name+"."+tableDef.Pkey.PkeyColName] + oldMainPkExpr := colExpr(oldMainPkPos[0], oldMainPkPos[1], fullProjList[oldMainPkPos[1]].Typ) + for i, idxDef := range tableDef.Indexes { insertCols := make([]plan.ColRef, 2) deleteCols := make([]plan.ColRef, 2) - newIdxPos := colName2Idx[idxDef.IndexTableName+"."+catalog.IndexTableIndexColName] - if indexTableStoresSerializedKey(idxDef) { - idxExpr := &plan.Expr{ - Typ: fullProjList[newIdxPos].Typ, - Expr: &plan.Expr_Col{ - Col: &plan.ColRef{ - RelPos: fullProjTag, - ColPos: newIdxPos, - }, - }, - } - newIdxPos = int32(len(finalProjList)) - finalProjList = append(finalProjList, idxExpr) - } + newIdxSourcePos := colName2Idx[idxDef.IndexTableName+"."+catalog.IndexTableIndexColName] + newIdxExpr := colExpr(fullProjTag, newIdxSourcePos, fullProjList[newIdxSourcePos].Typ) - oldRowIdPos := int32(len(finalProjList)) oldColRef := oldColName2Idx[idxDef.IndexTableName+"."+catalog.Row_ID] - rowIdExpr := &plan.Expr{ + oldRowIDExpr := &plan.Expr{ Typ: idxTableDefs[i].Cols[idxTableDefs[i].Name2ColIndex[catalog.Row_ID]].Typ, Expr: &plan.Expr_Col{ Col: &plan.ColRef{ @@ -694,13 +981,11 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( }, }, } - finalProjList = append(finalProjList, rowIdExpr) - oldIdxPos := int32(len(finalProjList)) lookupColName := indexLookupColumnName(idxDef) lookupColIdx := idxTableDefs[i].Name2ColIndex[lookupColName] oldColRef = oldColName2Idx[idxDef.IndexTableName+"."+lookupColName] - idxExpr := &plan.Expr{ + oldIdxExpr := &plan.Expr{ Typ: idxTableDefs[i].Cols[lookupColIdx].Typ, Expr: &plan.Expr_Col{ Col: &plan.ColRef{ @@ -709,17 +994,46 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( }, }, } - finalProjList = append(finalProjList, idxExpr) + + newIdxProjExpr := newIdxExpr + oldRowIDProjExpr := oldRowIDExpr + oldIdxProjExpr := oldIdxExpr + + needRewriteIdxExpr, err := makeNeedRewriteIdxExpr(oldRowIDExpr, oldIdxExpr, newIdxExpr, oldMainPkExpr, newMainPkExpr) + if err == nil { + newIdxProjExpr, err = makeIfExpr(needRewriteIdxExpr, newIdxExpr, nullExpr(newIdxExpr.Typ)) + } + if err == nil { + oldRowIDProjExpr, err = makeIfExpr(needRewriteIdxExpr, oldRowIDExpr, nullExpr(oldRowIDExpr.Typ)) + } + if err == nil { + oldIdxProjExpr, err = makeIfExpr(needRewriteIdxExpr, oldIdxExpr, nullExpr(oldIdxExpr.Typ)) + } + if err != nil { + // Conditional index rewrite is an optimization only. For types that + // cannot bind IF/equals safely (e.g. geometry), fall back to always + // rewriting this index row to keep REPLACE semantics correct. + newIdxProjExpr = newIdxExpr + oldRowIDProjExpr = oldRowIDExpr + oldIdxProjExpr = oldIdxExpr + } + + newIdxPos := int32(len(finalProjList)) + finalProjList = append(finalProjList, newIdxProjExpr) + oldRowIdPos := int32(len(finalProjList)) + finalProjList = append(finalProjList, oldRowIDProjExpr) + oldIdxPos := int32(len(finalProjList)) + finalProjList = append(finalProjList, oldIdxProjExpr) insertCols[0].RelPos = finalProjTag - insertCols[0].ColPos = int32(newIdxPos) + insertCols[0].ColPos = newIdxPos insertCols[1].RelPos = finalProjTag insertCols[1].ColPos = newPkIdx deleteCols[0].RelPos = finalProjTag deleteCols[0].ColPos = oldRowIdPos deleteCols[1].RelPos = finalProjTag - deleteCols[1].ColPos = int32(oldIdxPos) + deleteCols[1].ColPos = oldIdxPos updateCtxList = append(updateCtxList, &plan.UpdateCtx{ ObjRef: idxObjRefs[i], @@ -728,17 +1042,17 @@ func (builder *QueryBuilder) appendDedupAndMultiUpdateNodesForBindReplace( DeleteCols: deleteCols, }) - if idxDef.Unique { + if idxDef.Unique && !skipUniqueIdx[i] { lockTargets = append(lockTargets, &plan.LockTarget{ TableId: idxTableDefs[i].TblId, ObjRef: idxObjRefs[i], - PrimaryColIdxInBat: int32(newIdxPos), + PrimaryColIdxInBat: newIdxPos, PrimaryColRelPos: finalProjTag, PrimaryColTyp: finalProjList[newIdxPos].Typ, }, &plan.LockTarget{ TableId: idxTableDefs[i].TblId, ObjRef: idxObjRefs[i], - PrimaryColIdxInBat: int32(oldIdxPos), + PrimaryColIdxInBat: oldIdxPos, PrimaryColRelPos: finalProjTag, PrimaryColTyp: finalProjList[oldIdxPos].Typ, }) diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index cc51ff4029134..06ca1fd03e313 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -805,6 +805,230 @@ func TestReplacePlanStructure(t *testing.T) { assert.True(t, hasDedupJoin, "REPLACE plan should contain DEDUP JOIN node") } +func TestReplaceSkipUniqueDedupWhenUniqueColsAreDefaultNull(t *testing.T) { + mock := NewMockOptimizer(true) + + withUniqueValue, err := runOneStmt(mock, t, "REPLACE INTO dept VALUES (1, 'Sales', 'NY')") + if err != nil { + t.Fatalf("%+v", err) + } + withDefaultNull, err := runOneStmt(mock, t, "REPLACE INTO dept(deptno, loc) VALUES (1, 'NY')") + if err != nil { + t.Fatalf("%+v", err) + } + + countDedup := func(q *plan.Query) int { + cnt := 0 + for _, node := range q.Nodes { + if node.NodeType == plan.Node_JOIN && node.JoinType == plan.Node_DEDUP { + cnt++ + } + } + return cnt + } + + // pk dedup + unique(dname) dedup + assert.Equal(t, 2, countDedup(withUniqueValue.GetQuery())) + // only pk dedup: unique(dname) is default NULL, dedup should be skipped. + assert.Equal(t, 1, countDedup(withDefaultNull.GetQuery())) +} + +func TestReplaceIndexUpdateUsesConditionalProjection(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, "REPLACE INTO dept VALUES (1, 'Sales', 'NY')") + if err != nil { + t.Fatalf("%+v", err) + } + + query := logicPlan.GetQuery() + assert.NotNil(t, query) + + var muNode *plan.Node + for _, node := range query.Nodes { + if node.NodeType == plan.Node_MULTI_UPDATE { + muNode = node + break + } + } + if !assert.NotNil(t, muNode) { + return + } + if !assert.NotEmpty(t, muNode.Children) { + return + } + + child := query.Nodes[muNode.Children[0]] + if child.NodeType == plan.Node_LOCK_OP { + if !assert.NotEmpty(t, child.Children) { + return + } + child = query.Nodes[child.Children[0]] + } + if !assert.Equal(t, plan.Node_PROJECT, child.NodeType) { + return + } + finalProj := child + + for _, ctx := range muNode.UpdateCtxList { + if ctx.TableDef == nil { + continue + } + if !(catalog.IsUniqueIndexTable(ctx.TableDef.Name) || catalog.IsSecondaryIndexTable(ctx.TableDef.Name)) { + continue + } + if !assert.NotEmpty(t, ctx.InsertCols) || len(ctx.DeleteCols) < 2 { + continue + } + + newIdxExpr := finalProj.ProjectList[ctx.InsertCols[0].ColPos] + oldRowIDExpr := finalProj.ProjectList[ctx.DeleteCols[0].ColPos] + oldIdxExpr := finalProj.ProjectList[ctx.DeleteCols[1].ColPos] + + assert.NotNil(t, newIdxExpr.GetF(), "index insert key should be guarded by IF expression") + assert.Equal(t, "if", newIdxExpr.GetF().Func.ObjName) + assert.NotNil(t, oldRowIDExpr.GetF(), "index delete rowid should be guarded by IF expression") + assert.Equal(t, "if", oldRowIDExpr.GetF().Func.ObjName) + assert.NotNil(t, oldIdxExpr.GetF(), "index delete key should be guarded by IF expression") + assert.Equal(t, "if", oldIdxExpr.GetF().Func.ObjName) + } +} + +func TestReplaceStaticScanFilterPushdownForValues(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, "REPLACE INTO dept VALUES (1, 'Sales', 'NY'), (2, 'HR', 'LA')") + if err != nil { + t.Fatalf("%+v", err) + } + + query := logicPlan.GetQuery() + assert.NotNil(t, query) + + mainScanFiltered := 0 + uniqueScanFiltered := 0 + for _, node := range query.Nodes { + if node.NodeType != plan.Node_TABLE_SCAN || len(node.FilterList) == 0 || node.TableDef == nil { + continue + } + if node.TableDef.Name == "dept" { + mainScanFiltered++ + } + if catalog.IsUniqueIndexTable(node.TableDef.Name) { + uniqueScanFiltered++ + } + } + + assert.Equal(t, 1, mainScanFiltered, "pk dedup scan should carry static filter for VALUES REPLACE") + assert.GreaterOrEqual(t, uniqueScanFiltered, 1, "single-part unique dedup scan should carry static filter") +} + +func TestReplaceStaticScanFilterPushdownForExplicitColumns(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, "REPLACE INTO dept(deptno, dname, loc) VALUES (1, 'Sales', 'NY')") + if err != nil { + t.Fatalf("%+v", err) + } + + query := logicPlan.GetQuery() + assert.NotNil(t, query) + + mainScanFiltered := 0 + uniqueScanFiltered := 0 + for _, node := range query.Nodes { + if node.NodeType != plan.Node_TABLE_SCAN || len(node.FilterList) == 0 || node.TableDef == nil { + continue + } + if node.TableDef.Name == "dept" { + mainScanFiltered++ + } + if catalog.IsUniqueIndexTable(node.TableDef.Name) { + uniqueScanFiltered++ + } + } + + assert.Equal(t, 1, mainScanFiltered, "explicit-column REPLACE should push down static filter for pk dedup scan") + assert.GreaterOrEqual(t, uniqueScanFiltered, 1, "explicit-column REPLACE should push down static filter for unique dedup scan") +} + +func TestReplaceStaticScanFilterUsesInWithDeduplicatedValues(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, "REPLACE INTO dept VALUES (1, 'Sales', 'NY'), (1, 'Sales2', 'SF'), (2, 'HR', 'LA')") + if err != nil { + t.Fatalf("%+v", err) + } + + query := logicPlan.GetQuery() + assert.NotNil(t, query) + + foundInFilter := false + foundInListForm := false + foundDedupedTwoValues := false + for _, node := range query.Nodes { + if node.NodeType != plan.Node_TABLE_SCAN || len(node.FilterList) == 0 { + continue + } + for _, filter := range node.FilterList { + f := filter.GetF() + if f == nil || f.Func == nil || f.Func.ObjName != "in" { + continue + } + foundInFilter = true + if len(f.Args) != 2 { + continue + } + if f.Args[1].GetList() == nil { + continue + } + foundInListForm = true + if len(f.Args[1].GetList().List) == 2 { + foundDedupedTwoValues = true + } + } + } + + assert.True(t, foundInFilter, "static scan filter should prefer IN when multiple values exist") + if foundInListForm { + assert.True(t, foundDedupedTwoValues, "duplicate key values should be deduplicated in IN list") + } +} + +func TestReplaceStaticScanFilterPushdownForIndexRowidLookup(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, "REPLACE INTO dept VALUES (1, 'Sales', 'NY'), (2, 'HR', 'LA')") + if err != nil { + t.Fatalf("%+v", err) + } + + query := logicPlan.GetQuery() + assert.NotNil(t, query) + + uniqueScanFiltered := make(map[string]int) + for _, node := range query.Nodes { + if node.NodeType != plan.Node_TABLE_SCAN || len(node.FilterList) == 0 || node.TableDef == nil { + continue + } + if catalog.IsUniqueIndexTable(node.TableDef.Name) { + uniqueScanFiltered[node.TableDef.Name]++ + } + } + + hasRowidLookupFiltered := false + for _, filteredCount := range uniqueScanFiltered { + // The same unique index table should be filtered in both: + // 1) unique dedup scan + // 2) index rowid lookup scan + if filteredCount >= 2 { + hasRowidLookupFiltered = true + break + } + } + assert.True(t, hasRowidLookupFiltered, "unique index rowid lookup scan should carry static filter pushdown") +} + func TestReplaceSelfRefPlanStructure(t *testing.T) { mock := NewMockOptimizer(true) diff --git a/pkg/sql/plan/build_util.go b/pkg/sql/plan/build_util.go index 94da5522edeb6..c8b0b17c86b6d 100644 --- a/pkg/sql/plan/build_util.go +++ b/pkg/sql/plan/build_util.go @@ -691,6 +691,9 @@ func convertValueIntoBool(name string, args []*Expr, isLogic bool) error { return nil } for _, arg := range args { + if arg == nil { + return moerr.NewInternalErrorNoCtxf("convertValueIntoBool got nil argument for function %s", name) + } if arg.Typ.Id == int32(types.T_bool) { continue } diff --git a/pkg/sql/plan/build_util_test.go b/pkg/sql/plan/build_util_test.go index 84691e101c9c5..680971d00c4f3 100644 --- a/pkg/sql/plan/build_util_test.go +++ b/pkg/sql/plan/build_util_test.go @@ -217,3 +217,14 @@ func TestBuildDefaultExprGeometryAllowsNullDefault(t *testing.T) { require.NoError(t, err) require.NotNil(t, def) } + +func TestConvertValueIntoBoolNilArg(t *testing.T) { + args := []*Expr{ + makePlan2BoolConstExprWithType(true), + nil, + } + + err := convertValueIntoBool("and", args, true) + require.Error(t, err) + require.Contains(t, err.Error(), "nil argument") +} diff --git a/pkg/sql/plan/deepcopy.go b/pkg/sql/plan/deepcopy.go index 547ebe70cf17f..2ad5afa4d072b 100644 --- a/pkg/sql/plan/deepcopy.go +++ b/pkg/sql/plan/deepcopy.go @@ -80,6 +80,7 @@ func DeepCopyUpdateCtxList(updateCtxList []*plan.UpdateCtx) []*plan.UpdateCtx { InsertCols: slices.Clone(ctx.InsertCols), DeleteCols: slices.Clone(ctx.DeleteCols), PartitionCols: slices.Clone(ctx.PartitionCols), + IsReplace: ctx.IsReplace, } } diff --git a/pkg/sql/plan/function/operatorSet.go b/pkg/sql/plan/function/operatorSet.go index a835a8bd1a9bc..2d6e1c323ae13 100644 --- a/pkg/sql/plan/function/operatorSet.go +++ b/pkg/sql/plan/function/operatorSet.go @@ -320,7 +320,7 @@ var ( types.T_int8, types.T_int16, types.T_int32, types.T_int64, types.T_uint8, types.T_uint16, types.T_uint32, types.T_uint64, types.T_float32, types.T_float64, - types.T_uuid, + types.T_uuid, types.T_Rowid, types.T_bool, types.T_date, types.T_datetime, types.T_bit, types.T_varchar, types.T_char, types.T_blob, types.T_text, types.T_json, @@ -413,6 +413,8 @@ func iffFn(parameters []*vector.Vector, result vector.FunctionResultWrapper, pro return generalIffFn[float64](parameters, result, proc, length, selectList) case types.T_uuid: return generalIffFn[types.Uuid](parameters, result, proc, length, selectList) + case types.T_Rowid: + return generalIffFn[types.Rowid](parameters, result, proc, length, selectList) case types.T_bool: return generalIffFn[bool](parameters, result, proc, length, selectList) case types.T_date: @@ -438,7 +440,7 @@ func iffFn(parameters []*vector.Vector, result vector.FunctionResultWrapper, pro } func generalIffFn[T constraints.Integer | constraints.Float | bool | types.Date | types.Datetime | - types.Decimal64 | types.Decimal128 | types.Decimal256 | types.Timestamp | types.Uuid](vecs []*vector.Vector, result vector.FunctionResultWrapper, _ *process.Process, length int, selectList *FunctionSelectList) error { + types.Decimal64 | types.Decimal128 | types.Decimal256 | types.Timestamp | types.Uuid | types.Rowid](vecs []*vector.Vector, result vector.FunctionResultWrapper, _ *process.Process, length int, selectList *FunctionSelectList) error { p1 := vector.GenerateFunctionFixedTypeParameter[bool](vecs[0]) p2 := vector.GenerateFunctionFixedTypeParameter[T](vecs[1]) p3 := vector.GenerateFunctionFixedTypeParameter[T](vecs[2]) diff --git a/pkg/sql/plan/function/operatorSet_test.go b/pkg/sql/plan/function/operatorSet_test.go index efddbc4a6dedc..b6e822cbf3491 100644 --- a/pkg/sql/plan/function/operatorSet_test.go +++ b/pkg/sql/plan/function/operatorSet_test.go @@ -432,6 +432,37 @@ func Test_IffCheck_MixedTypes(t *testing.T) { } } +func Test_Iff_Rowid(t *testing.T) { + inputs := []types.Type{ + types.T_bool.ToType(), + types.T_Rowid.ToType(), + types.T_Rowid.ToType(), + } + result := iffCheck(nil, inputs) + require.NotEqual(t, failedFunctionParametersWrong, result.status, "iffCheck should accept rowid branches") + + proc := testutil.NewProcess(t) + rid1 := types.Rowid([24]byte{1}) + rid2 := types.Rowid([24]byte{2}) + tc := tcTemp{ + info: "if(cond, rowid_a, rowid_b)", + inputs: []FunctionTestInput{ + NewFunctionTestInput(types.T_bool.ToType(), + []bool{true, false, false, true}, []bool{false, false, false, false}), + NewFunctionTestInput(types.T_Rowid.ToType(), + []types.Rowid{rid1, rid1, rid1, rid1}, []bool{false, false, false, false}), + NewFunctionTestInput(types.T_Rowid.ToType(), + []types.Rowid{rid2, rid2, rid2, rid2}, []bool{false, false, false, false}), + }, + expect: NewFunctionTestResult(types.T_Rowid.ToType(), false, + []types.Rowid{rid1, rid2, rid2, rid1}, []bool{false, false, false, false}), + } + + tcc := NewFunctionTestCase(proc, tc.inputs, tc.expect, iffFn) + succeed, info := tcc.Run() + require.True(t, succeed, tc.info, info) +} + func Test_CaseWhen_WithNullAndStringComparison(t *testing.T) { // Test CASE WHEN with NULL value compared to string // This should not error, matching MySQL behavior diff --git a/pkg/sql/plan/opt_misc.go b/pkg/sql/plan/opt_misc.go index 3e97e71f30707..d49044db46500 100644 --- a/pkg/sql/plan/opt_misc.go +++ b/pkg/sql/plan/opt_misc.go @@ -228,6 +228,9 @@ func (builder *QueryBuilder) canRemoveProject(parentType plan.Node_NodeType, nod } func exprCanRemoveProject(expr *Expr) bool { + if expr == nil || expr.Expr == nil { + return false + } switch ne := expr.Expr.(type) { case *plan.Expr_F: if ne.F.Func.ObjName == "sleep" { diff --git a/proto/plan.proto b/proto/plan.proto index 5c931a42e1de3..31cf5180d7360 100644 --- a/proto/plan.proto +++ b/proto/plan.proto @@ -623,6 +623,10 @@ message UpdateCtx { repeated ColRef insert_cols = 7 [(gogoproto.nullable) = false]; repeated ColRef delete_cols = 8 [(gogoproto.nullable) = false]; repeated ColRef partition_cols = 9 [(gogoproto.nullable) = false]; + // is_replace marks UpdateCtx instances produced by REPLACE INTO. + // Used by the multi_update operator to count delete rows toward affected_rows + // (MySQL REPLACE semantics: affected_rows = inserted + deleted). + bool is_replace = 10; } message InsertCtx {