diff --git a/pkg/sql/colexec/dedupjoin/join.go b/pkg/sql/colexec/dedupjoin/join.go index 75c378cf6c024..ea6b5a75235ff 100644 --- a/pkg/sql/colexec/dedupjoin/join.go +++ b/pkg/sql/colexec/dedupjoin/join.go @@ -16,6 +16,7 @@ package dedupjoin import ( "bytes" + "context" "strings" "time" @@ -33,6 +34,49 @@ import ( "github.com/matrixorigin/matrixone/pkg/vm/process" ) +// receiveWorkerMsg blocks until the channel yields a message or the context +// is canceled. Returns nil on close or cancellation. +func receiveWorkerMsg(ctx context.Context, ch chan *WorkerJoinMsg) *WorkerJoinMsg { + select { + case <-ctx.Done(): + return nil + case msg, ok := <-ch: + if !ok { + return nil + } + return msg + } +} + +// mergeCaptured folds a non-merger worker's captured state into the merger's. +// For each bucket set in msg.captured that the merger has not yet captured, +// the merger copies the per-column values from the worker's capturedVecs into +// its own and marks the bucket. First-wins semantics across workers: the +// merger retains whichever worker's values arrive first. +func (ctr *container) mergeCaptured(ap *DedupJoin, msg *WorkerJoinMsg, proc *process.Process) error { + if ctr.capturedVecs == nil || msg.capturedVecs == nil { + return nil + } + itr := msg.captured.Iterator() + for itr.HasNext() { + bucket := itr.Next() + if ctr.captured.Contains(bucket) { + continue + } + for cIdx := range ctr.capturedVecs { + if err := ctr.capturedVecs[cIdx].Copy( + msg.capturedVecs[cIdx], + int64(bucket), int64(bucket), + proc.Mp(), + ); err != nil { + return err + } + } + ctr.captured.Add(bucket) + } + return nil +} + const opName = "dedup_join" func (dedupJoin *DedupJoin) String(buf *bytes.Buffer) { @@ -234,19 +278,35 @@ func (ctr *container) finalize(ap *DedupJoin, proc *process.Process) error { if ap.NumCPU > 1 { if !ap.IsMerger { - ap.Channel <- ctr.matched + msg := &WorkerJoinMsg{matched: ctr.matched} + if len(ap.OldColCapturePlaceholderIdxList) > 0 { + // Transfer ownership of capture state to the merger; clear + // our references so cleanCaptured() does not double-free. + msg.captured = ctr.captured + msg.capturedVecs = ctr.capturedVecs + ctr.captured = nil + ctr.capturedVecs = nil + } + ap.Channel <- msg return nil - } else { - for cnt := 1; cnt < int(ap.NumCPU); cnt++ { - v := colexec.ReceiveBitmapFromChannel(proc.Ctx, ap.Channel) - if v != nil { - ctr.matched.Or(v) - } else { - return nil + } + for cnt := 1; cnt < int(ap.NumCPU); cnt++ { + msg := receiveWorkerMsg(proc.Ctx, ap.Channel) + if msg == nil { + return nil + } + if msg.matched != nil { + ctr.matched.Or(msg.matched) + } + if len(ap.OldColCapturePlaceholderIdxList) > 0 && msg.captured != nil { + if err := ctr.mergeCaptured(ap, msg, proc); err != nil { + freeCapturedVecs(msg.capturedVecs, proc) + return err } + freeCapturedVecs(msg.capturedVecs, proc) } - close(ap.Channel) } + close(ap.Channel) } if ap.OnDuplicateAction != plan.Node_UPDATE || ctr.mp.HashOnUnique() { diff --git a/pkg/sql/colexec/dedupjoin/join_test.go b/pkg/sql/colexec/dedupjoin/join_test.go index e2a892a95b298..219c7d269017e 100644 --- a/pkg/sql/colexec/dedupjoin/join_test.go +++ b/pkg/sql/colexec/dedupjoin/join_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/matrixorigin/matrixone/pkg/common/bitmap" "github.com/matrixorigin/matrixone/pkg/common/mpool" "github.com/matrixorigin/matrixone/pkg/container/batch" "github.com/matrixorigin/matrixone/pkg/container/types" @@ -642,3 +643,185 @@ func TestDedupJoinCaptureReset(t *testing.T) { proc.Free() require.Equal(t, int64(0), proc.Mp().CurrNB()) } + +// makeCaptureFixture constructs a merger container and a ready-to-send +// WorkerJoinMsg sharing the same bucket layout. Caller owns cleanup of both +// sides via Free of the returned vectors (merger's via its container, msg's +// via freeCapturedVecs or merger ownership transfer). +func makeCaptureFixture(t *testing.T, proc *process.Process, bucketCnt int) (*container, *WorkerJoinMsg) { + int32Typ := types.T_int32.ToType() + mkVec := func() *vector.Vector { + v := vector.NewOffHeapVecWithType(int32Typ) + require.NoError(t, vector.AppendMultiFixed(v, int32(0), true, bucketCnt, proc.Mp())) + return v + } + ctr := &container{ + capturedVecs: []*vector.Vector{mkVec()}, + captured: &bitmap.Bitmap{}, + matched: &bitmap.Bitmap{}, + } + ctr.captured.InitWithSize(int64(bucketCnt)) + ctr.matched.InitWithSize(int64(bucketCnt)) + + msg := &WorkerJoinMsg{ + matched: &bitmap.Bitmap{}, + captured: &bitmap.Bitmap{}, + capturedVecs: []*vector.Vector{mkVec()}, + } + msg.matched.InitWithSize(int64(bucketCnt)) + msg.captured.InitWithSize(int64(bucketCnt)) + return ctr, msg +} + +// writeBucketValue sets capturedVecs[0][bucket] = val and records the bucket +// in the accompanying captured bitmap. +func writeBucketValue(t *testing.T, vecs []*vector.Vector, captured *bitmap.Bitmap, bucket uint64, val int32, proc *process.Process) { + src := vector.NewOffHeapVecWithType(types.T_int32.ToType()) + defer src.Free(proc.Mp()) + require.NoError(t, vector.AppendFixed(src, val, false, proc.Mp())) + require.NoError(t, vecs[0].Copy(src, int64(bucket), 0, proc.Mp())) + captured.Add(bucket) +} + +// TestMergeCaptured_DisjointBuckets covers the common parallel case where +// merger and non-merger captured different buckets. After merge, the merger +// owns the union of both sides. +func TestMergeCaptured_DisjointBuckets(t *testing.T) { + proc, ctrl := newCaptureTestProc(t) + defer ctrl.Finish() + + ap := &DedupJoin{OldColCapturePlaceholderIdxList: []int32{1}, OldColCaptureProbeIdxList: []int32{1}} + ctr, msg := makeCaptureFixture(t, proc, 4) + + writeBucketValue(t, ctr.capturedVecs, ctr.captured, 0, 10, proc) + writeBucketValue(t, msg.capturedVecs, msg.captured, 2, 20, proc) + + require.NoError(t, ctr.mergeCaptured(ap, msg, proc)) + + require.True(t, ctr.captured.Contains(0)) + require.True(t, ctr.captured.Contains(2)) + require.False(t, ctr.captured.Contains(1)) + vals := vector.MustFixedColNoTypeCheck[int32](ctr.capturedVecs[0]) + require.Equal(t, int32(10), vals[0]) + require.Equal(t, int32(20), vals[2]) + + freeCapturedVecs(msg.capturedVecs, proc) + for _, v := range ctr.capturedVecs { + v.Free(proc.Mp()) + } + proc.Free() + require.Equal(t, int64(0), proc.Mp().CurrNB()) +} + +// TestMergeCaptured_FirstWinsOnConflict verifies that when merger and +// non-merger both captured the same bucket, the merger's value is retained. +func TestMergeCaptured_FirstWinsOnConflict(t *testing.T) { + proc, ctrl := newCaptureTestProc(t) + defer ctrl.Finish() + + ap := &DedupJoin{OldColCapturePlaceholderIdxList: []int32{1}, OldColCaptureProbeIdxList: []int32{1}} + ctr, msg := makeCaptureFixture(t, proc, 2) + + writeBucketValue(t, ctr.capturedVecs, ctr.captured, 0, 111, proc) + writeBucketValue(t, msg.capturedVecs, msg.captured, 0, 222, proc) + + require.NoError(t, ctr.mergeCaptured(ap, msg, proc)) + + require.True(t, ctr.captured.Contains(0)) + vals := vector.MustFixedColNoTypeCheck[int32](ctr.capturedVecs[0]) + require.Equal(t, int32(111), vals[0], "merger's existing capture must win") + + freeCapturedVecs(msg.capturedVecs, proc) + for _, v := range ctr.capturedVecs { + v.Free(proc.Mp()) + } + proc.Free() + require.Equal(t, int64(0), proc.Mp().CurrNB()) +} + +// TestMergeCaptured_EmptyWorkerMsg verifies a non-merger worker that captured +// nothing does not corrupt the merger state. +func TestMergeCaptured_EmptyWorkerMsg(t *testing.T) { + proc, ctrl := newCaptureTestProc(t) + defer ctrl.Finish() + + ap := &DedupJoin{OldColCapturePlaceholderIdxList: []int32{1}, OldColCaptureProbeIdxList: []int32{1}} + ctr, msg := makeCaptureFixture(t, proc, 2) + + writeBucketValue(t, ctr.capturedVecs, ctr.captured, 1, 77, proc) + + require.NoError(t, ctr.mergeCaptured(ap, msg, proc)) + + require.True(t, ctr.captured.Contains(1)) + require.False(t, ctr.captured.Contains(0)) + vals := vector.MustFixedColNoTypeCheck[int32](ctr.capturedVecs[0]) + require.Equal(t, int32(77), vals[1]) + + freeCapturedVecs(msg.capturedVecs, proc) + for _, v := range ctr.capturedVecs { + v.Free(proc.Mp()) + } + proc.Free() + require.Equal(t, int64(0), proc.Mp().CurrNB()) +} + +// TestWorkerJoinMsg_ChannelRoundTrip verifies the channel transport: +// non-merger sends a WorkerJoinMsg that transfers capture ownership; receiver +// reads it back and folds it in via mergeCaptured with no leaks. +func TestWorkerJoinMsg_ChannelRoundTrip(t *testing.T) { + proc, ctrl := newCaptureTestProc(t) + defer ctrl.Finish() + + ap := &DedupJoin{OldColCapturePlaceholderIdxList: []int32{1}, OldColCaptureProbeIdxList: []int32{1}} + ctr, msg := makeCaptureFixture(t, proc, 3) + + writeBucketValue(t, ctr.capturedVecs, ctr.captured, 0, 1, proc) + writeBucketValue(t, msg.capturedVecs, msg.captured, 1, 2, proc) + writeBucketValue(t, msg.capturedVecs, msg.captured, 2, 3, proc) + + ch := make(chan *WorkerJoinMsg, 1) + ch <- msg + close(ch) + + received := receiveWorkerMsg(context.Background(), ch) + require.NotNil(t, received) + require.Same(t, msg, received) + + require.NoError(t, ctr.mergeCaptured(ap, received, proc)) + freeCapturedVecs(received.capturedVecs, proc) + + require.True(t, ctr.captured.Contains(0)) + require.True(t, ctr.captured.Contains(1)) + require.True(t, ctr.captured.Contains(2)) + vals := vector.MustFixedColNoTypeCheck[int32](ctr.capturedVecs[0]) + require.Equal(t, int32(1), vals[0]) + require.Equal(t, int32(2), vals[1]) + require.Equal(t, int32(3), vals[2]) + + for _, v := range ctr.capturedVecs { + v.Free(proc.Mp()) + } + proc.Free() + require.Equal(t, int64(0), proc.Mp().CurrNB()) +} + +// TestReceiveWorkerMsg_ContextCancel verifies the receive helper respects +// context cancellation and returns nil (used to unblock the merger when a +// worker dies abnormally). +func TestReceiveWorkerMsg_ContextCancel(t *testing.T) { + ch := make(chan *WorkerJoinMsg) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + msg := receiveWorkerMsg(ctx, ch) + require.Nil(t, msg) +} + +// TestReceiveWorkerMsg_ChannelClose verifies that a closed channel returns nil. +func TestReceiveWorkerMsg_ChannelClose(t *testing.T) { + ch := make(chan *WorkerJoinMsg) + close(ch) + + msg := receiveWorkerMsg(context.Background(), ch) + require.Nil(t, msg) +} diff --git a/pkg/sql/colexec/dedupjoin/types.go b/pkg/sql/colexec/dedupjoin/types.go index e81d592726f43..b7583234ed292 100644 --- a/pkg/sql/colexec/dedupjoin/types.go +++ b/pkg/sql/colexec/dedupjoin/types.go @@ -36,6 +36,31 @@ const ( End ) +// WorkerJoinMsg carries per-worker state from non-merger workers to the +// merger worker at finalize time. Regular DEDUP JOIN only populates matched; +// the REPLACE INTO merged main-table scan path (OldColCapture) additionally +// populates captured and capturedVecs. +// +// Ownership: once a non-merger worker sends this message on the channel, it +// must relinquish its references to captured / capturedVecs so that the +// merger is the sole owner and is responsible for Free'ing capturedVecs. +type WorkerJoinMsg struct { + matched *bitmap.Bitmap + captured *bitmap.Bitmap + capturedVecs []*vector.Vector +} + +// freeCapturedVecs releases vectors owned by a WorkerJoinMsg. Intended to be +// called by the merger after it has finished merging captures out of the +// message (ownership was transferred from the sender). +func freeCapturedVecs(vecs []*vector.Vector, proc *process.Process) { + for _, v := range vecs { + if v != nil { + v.Free(proc.GetMPool()) + } + } +} + type evalVector struct { executor colexec.ExpressionExecutor vec *vector.Vector @@ -94,7 +119,7 @@ type DedupJoin struct { RuntimeFilterSpecs []*plan.RuntimeFilterSpec JoinMapTag int32 - Channel chan *bitmap.Bitmap + Channel chan *WorkerJoinMsg NumCPU uint64 IsMerger bool diff --git a/pkg/sql/compile/compile.go b/pkg/sql/compile/compile.go index 864ef356cb263..856f002faecdd 100644 --- a/pkg/sql/compile/compile.go +++ b/pkg/sql/compile/compile.go @@ -2647,18 +2647,10 @@ func (c *Compile) compileProbeSideForBroadcastJoin(node, left, right *plan.Node, } else { rs = c.newProbeScopeListForBroadcastJoin(probeScopes, true) currentFirstFlag := c.anal.isFirst - // OldColCapture (merged main-table scan for REPLACE INTO) keeps - // captured vectors in per-worker local state; the parallel finalize - // path only merges the matched bitmap, not capturedVecs. Force - // single-worker to avoid losing captures from non-merger workers. - hasCapture := node.DedupJoinCtx != nil && len(node.DedupJoinCtx.OldColCaptureList) > 0 for i := range rs { op := constructDedupJoin(node, leftTypes, rightTypes, c.proc) op.SetAnalyzeControl(c.anal.curNodeIdx, currentFirstFlag) rs[i].setRootOperator(op) - if hasCapture { - rs[i].NodeInfo.Mcpu = 1 - } } c.anal.isFirst = false } diff --git a/pkg/sql/compile/operator.go b/pkg/sql/compile/operator.go index 4a2835e5f5ee0..33a8348cc1f4a 100644 --- a/pkg/sql/compile/operator.go +++ b/pkg/sql/compile/operator.go @@ -546,7 +546,7 @@ func dupOperator(sourceOp vm.Operator, index int, maxParallel int) vm.Operator { t := sourceOp.(*dedupjoin.DedupJoin) op := dedupjoin.NewArgument() if t.Channel == nil { - t.Channel = make(chan *bitmap.Bitmap, maxParallel) + t.Channel = make(chan *dedupjoin.WorkerJoinMsg, maxParallel) } op.Channel = t.Channel op.NumCPU = uint64(maxParallel)