Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions pkg/sql/colexec/dedupjoin/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package dedupjoin

import (
"bytes"
"context"
"strings"
"time"

Expand All @@ -33,6 +34,49 @@ import (
"github.com/matrixorigin/matrixone/pkg/vm/process"
)

// receiveWorkerMsg waits for the next worker message on ch, giving up as soon
// as ctx is done. The result is nil when ctx is canceled first, when ch has
// been closed, or when a nil message was sent.
func receiveWorkerMsg(ctx context.Context, ch chan *WorkerJoinMsg) *WorkerJoinMsg {
	var received *WorkerJoinMsg
	select {
	case received = <-ch:
		// A receive from a closed channel yields the zero value, so
		// received stays nil in that case.
	case <-ctx.Done():
		// Canceled before anything arrived; received is still nil.
	}
	return received
}

// mergeCaptured folds a non-merger worker's captured state into the merger's.
//
// For every bucket set in msg.captured that is not yet set in ctr.captured,
// the value at that bucket index is copied, column by column, from
// msg.capturedVecs into ctr.capturedVecs, and the bucket is then marked in
// ctr.captured. Buckets the merger already holds are left untouched, so the
// first value to reach the merger for a bucket wins.
//
// The call is a no-op returning nil when either side lacks capture state
// (nil message, nil bitmap, or nil vector slice). ap is currently unused.
// msg is only read here; releasing msg.capturedVecs is left to the caller.
//
// An error from Vector.Copy is returned unchanged. In that case the bucket
// being copied is not marked as captured, although some of its columns may
// already have been overwritten.
func (ctr *container) mergeCaptured(ap *DedupJoin, msg *WorkerJoinMsg, proc *process.Process) error {
	// Nothing to merge unless both sides carry complete capture state. The
	// bitmaps are checked alongside the vectors because both are
	// dereferenced below.
	if msg == nil || msg.captured == nil || msg.capturedVecs == nil {
		return nil
	}
	if ctr.captured == nil || ctr.capturedVecs == nil {
		return nil
	}

	for itr := msg.captured.Iterator(); itr.HasNext(); {
		bucket := itr.Next()
		if ctr.captured.Contains(bucket) {
			// First-wins: keep the value the merger already has.
			continue
		}
		row := int64(bucket)
		for cIdx, dst := range ctr.capturedVecs {
			if err := dst.Copy(msg.capturedVecs[cIdx], row, row, proc.Mp()); err != nil {
				return err
			}
		}
		// Mark the bucket only after every column has been copied.
		ctr.captured.Add(bucket)
	}
	return nil
}

// opName is the name string for the dedup join operator.
const opName = "dedup_join"

func (dedupJoin *DedupJoin) String(buf *bytes.Buffer) {
Expand Down Expand Up @@ -234,19 +278,35 @@ func (ctr *container) finalize(ap *DedupJoin, proc *process.Process) error {

if ap.NumCPU > 1 {
if !ap.IsMerger {
ap.Channel <- ctr.matched
msg := &WorkerJoinMsg{matched: ctr.matched}
if len(ap.OldColCapturePlaceholderIdxList) > 0 {
// Transfer ownership of capture state to the merger; clear
// our references so cleanCaptured() does not double-free.
msg.captured = ctr.captured
msg.capturedVecs = ctr.capturedVecs
ctr.captured = nil
ctr.capturedVecs = nil
}
ap.Channel <- msg
return nil
} else {
for cnt := 1; cnt < int(ap.NumCPU); cnt++ {
v := colexec.ReceiveBitmapFromChannel(proc.Ctx, ap.Channel)
if v != nil {
ctr.matched.Or(v)
} else {
return nil
}
for cnt := 1; cnt < int(ap.NumCPU); cnt++ {
msg := receiveWorkerMsg(proc.Ctx, ap.Channel)
if msg == nil {
return nil
}
if msg.matched != nil {
ctr.matched.Or(msg.matched)
}
if len(ap.OldColCapturePlaceholderIdxList) > 0 && msg.captured != nil {
if err := ctr.mergeCaptured(ap, msg, proc); err != nil {
freeCapturedVecs(msg.capturedVecs, proc)
return err
}
freeCapturedVecs(msg.capturedVecs, proc)
}
close(ap.Channel)
}
close(ap.Channel)
}

if ap.OnDuplicateAction != plan.Node_UPDATE || ctr.mp.HashOnUnique() {
Expand Down
183 changes: 183 additions & 0 deletions pkg/sql/colexec/dedupjoin/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/golang/mock/gomock"
"github.com/matrixorigin/matrixone/pkg/common/bitmap"
"github.com/matrixorigin/matrixone/pkg/common/mpool"
"github.com/matrixorigin/matrixone/pkg/container/batch"
"github.com/matrixorigin/matrixone/pkg/container/types"
Expand Down Expand Up @@ -642,3 +643,185 @@ func TestDedupJoinCaptureReset(t *testing.T) {
proc.Free()
require.Equal(t, int64(0), proc.Mp().CurrNB())
}

// makeCaptureFixture constructs a merger container and a ready-to-send
// WorkerJoinMsg sharing the same bucket layout. Each side gets a single int32
// column vector with bucketCnt entries appended with the null flag set, plus
// matched and captured bitmaps sized to bucketCnt.
//
// The caller owns cleanup of both sides: the merger's vectors via Free on the
// container's capturedVecs, and the message's vectors via freeCapturedVecs or
// a merger ownership transfer.
func makeCaptureFixture(t *testing.T, proc *process.Process, bucketCnt int) (*container, *WorkerJoinMsg) {
	// Report setup failures at the calling test's line, not in this helper.
	t.Helper()

	newColumn := func() *vector.Vector {
		col := vector.NewOffHeapVecWithType(types.T_int32.ToType())
		require.NoError(t, vector.AppendMultiFixed(col, int32(0), true, bucketCnt, proc.Mp()))
		return col
	}
	newBitmap := func() *bitmap.Bitmap {
		bm := &bitmap.Bitmap{}
		bm.InitWithSize(int64(bucketCnt))
		return bm
	}

	ctr := &container{
		capturedVecs: []*vector.Vector{newColumn()},
		captured:     newBitmap(),
		matched:      newBitmap(),
	}
	msg := &WorkerJoinMsg{
		matched:      newBitmap(),
		captured:     newBitmap(),
		capturedVecs: []*vector.Vector{newColumn()},
	}
	return ctr, msg
}

// writeBucketValue stores val at index bucket of vecs[0] and marks bucket in
// the accompanying captured bitmap. The value is staged in a temporary
// one-element vector, which is freed before returning.
func writeBucketValue(t *testing.T, vecs []*vector.Vector, captured *bitmap.Bitmap, bucket uint64, val int32, proc *process.Process) {
	// Report failures at the calling test's line, not in this helper.
	t.Helper()

	staging := vector.NewOffHeapVecWithType(types.T_int32.ToType())
	defer staging.Free(proc.Mp())

	require.NoError(t, vector.AppendFixed(staging, val, false, proc.Mp()))
	// Copy element 0 of the staging vector into slot bucket of the target.
	require.NoError(t, vecs[0].Copy(staging, int64(bucket), 0, proc.Mp()))
	captured.Add(bucket)
}

// TestMergeCaptured_DisjointBuckets exercises the typical parallel scenario:
// the merger captured bucket 0 and the worker captured bucket 2. After the
// merge the merger must hold both buckets with their respective values, and
// releasing everything must bring the pool back to zero bytes.
func TestMergeCaptured_DisjointBuckets(t *testing.T) {
	proc, ctrl := newCaptureTestProc(t)
	defer ctrl.Finish()

	op := &DedupJoin{
		OldColCapturePlaceholderIdxList: []int32{1},
		OldColCaptureProbeIdxList:       []int32{1},
	}
	merger, workerMsg := makeCaptureFixture(t, proc, 4)

	writeBucketValue(t, merger.capturedVecs, merger.captured, 0, 10, proc)
	writeBucketValue(t, workerMsg.capturedVecs, workerMsg.captured, 2, 20, proc)

	err := merger.mergeCaptured(op, workerMsg, proc)
	require.NoError(t, err)

	// Union of both sides; bucket 1 was captured by nobody.
	require.True(t, merger.captured.Contains(0))
	require.True(t, merger.captured.Contains(2))
	require.False(t, merger.captured.Contains(1))
	got := vector.MustFixedColNoTypeCheck[int32](merger.capturedVecs[0])
	require.Equal(t, int32(10), got[0])
	require.Equal(t, int32(20), got[2])

	// Release both sides and verify nothing is left allocated.
	freeCapturedVecs(workerMsg.capturedVecs, proc)
	for i := range merger.capturedVecs {
		merger.capturedVecs[i].Free(proc.Mp())
	}
	proc.Free()
	require.Equal(t, int64(0), proc.Mp().CurrNB())
}

// TestMergeCaptured_FirstWinsOnConflict checks the conflict rule: when the
// merger and the worker both captured bucket 0, the merge keeps the merger's
// value (111) and ignores the worker's (222).
func TestMergeCaptured_FirstWinsOnConflict(t *testing.T) {
	proc, ctrl := newCaptureTestProc(t)
	defer ctrl.Finish()

	op := &DedupJoin{
		OldColCapturePlaceholderIdxList: []int32{1},
		OldColCaptureProbeIdxList:       []int32{1},
	}
	merger, workerMsg := makeCaptureFixture(t, proc, 2)

	// Both sides claim the same bucket with different values.
	writeBucketValue(t, merger.capturedVecs, merger.captured, 0, 111, proc)
	writeBucketValue(t, workerMsg.capturedVecs, workerMsg.captured, 0, 222, proc)

	err := merger.mergeCaptured(op, workerMsg, proc)
	require.NoError(t, err)

	require.True(t, merger.captured.Contains(0))
	got := vector.MustFixedColNoTypeCheck[int32](merger.capturedVecs[0])
	require.Equal(t, int32(111), got[0], "merger's existing capture must win")

	// Release both sides and verify nothing is left allocated.
	freeCapturedVecs(workerMsg.capturedVecs, proc)
	for i := range merger.capturedVecs {
		merger.capturedVecs[i].Free(proc.Mp())
	}
	proc.Free()
	require.Equal(t, int64(0), proc.Mp().CurrNB())
}

// TestMergeCaptured_EmptyWorkerMsg checks that merging a worker message with
// no captured buckets leaves the merger's state as it was: bucket 1 keeps its
// value and bucket 0 stays uncaptured.
func TestMergeCaptured_EmptyWorkerMsg(t *testing.T) {
	proc, ctrl := newCaptureTestProc(t)
	defer ctrl.Finish()

	op := &DedupJoin{
		OldColCapturePlaceholderIdxList: []int32{1},
		OldColCaptureProbeIdxList:       []int32{1},
	}
	merger, workerMsg := makeCaptureFixture(t, proc, 2)

	// Only the merger has a capture; the worker message stays empty.
	writeBucketValue(t, merger.capturedVecs, merger.captured, 1, 77, proc)

	err := merger.mergeCaptured(op, workerMsg, proc)
	require.NoError(t, err)

	require.True(t, merger.captured.Contains(1))
	require.False(t, merger.captured.Contains(0))
	got := vector.MustFixedColNoTypeCheck[int32](merger.capturedVecs[0])
	require.Equal(t, int32(77), got[1])

	// Release both sides and verify nothing is left allocated.
	freeCapturedVecs(workerMsg.capturedVecs, proc)
	for i := range merger.capturedVecs {
		merger.capturedVecs[i].Free(proc.Mp())
	}
	proc.Free()
	require.Equal(t, int64(0), proc.Mp().CurrNB())
}

// TestWorkerJoinMsg_ChannelRoundTrip covers the channel transport end to end:
// a WorkerJoinMsg is pushed through a channel, read back with
// receiveWorkerMsg, merged via mergeCaptured and then freed. The merger must
// end up with all three buckets and the pool must be empty afterwards.
func TestWorkerJoinMsg_ChannelRoundTrip(t *testing.T) {
	proc, ctrl := newCaptureTestProc(t)
	defer ctrl.Finish()

	op := &DedupJoin{
		OldColCapturePlaceholderIdxList: []int32{1},
		OldColCaptureProbeIdxList:       []int32{1},
	}
	merger, sent := makeCaptureFixture(t, proc, 3)

	writeBucketValue(t, merger.capturedVecs, merger.captured, 0, 1, proc)
	writeBucketValue(t, sent.capturedVecs, sent.captured, 1, 2, proc)
	writeBucketValue(t, sent.capturedVecs, sent.captured, 2, 3, proc)

	// Queue the message, then close the channel; the buffered message must
	// still be delivered.
	pipe := make(chan *WorkerJoinMsg, 1)
	pipe <- sent
	close(pipe)

	got := receiveWorkerMsg(context.Background(), pipe)
	require.NotNil(t, got)
	require.Same(t, sent, got)

	require.NoError(t, merger.mergeCaptured(op, got, proc))
	freeCapturedVecs(got.capturedVecs, proc)

	require.True(t, merger.captured.Contains(0))
	require.True(t, merger.captured.Contains(1))
	require.True(t, merger.captured.Contains(2))
	merged := vector.MustFixedColNoTypeCheck[int32](merger.capturedVecs[0])
	require.Equal(t, int32(1), merged[0])
	require.Equal(t, int32(2), merged[1])
	require.Equal(t, int32(3), merged[2])

	// Release the merger side and verify nothing is left allocated.
	for i := range merger.capturedVecs {
		merger.capturedVecs[i].Free(proc.Mp())
	}
	proc.Free()
	require.Equal(t, int64(0), proc.Mp().CurrNB())
}

// TestReceiveWorkerMsg_ContextCancel checks that receiveWorkerMsg returns nil
// instead of blocking when the context is already canceled and nothing is
// ever sent on the channel (the situation the merger is in when a worker
// dies abnormally).
func TestReceiveWorkerMsg_ContextCancel(t *testing.T) {
	canceled, cancel := context.WithCancel(context.Background())
	cancel()

	// Unbuffered and never written to: only cancellation can unblock.
	idle := make(chan *WorkerJoinMsg)

	require.Nil(t, receiveWorkerMsg(canceled, idle))
}

// TestReceiveWorkerMsg_ChannelClose checks that receiveWorkerMsg returns nil
// when the channel has been closed without any message.
func TestReceiveWorkerMsg_ChannelClose(t *testing.T) {
	closed := make(chan *WorkerJoinMsg)
	close(closed)

	got := receiveWorkerMsg(context.Background(), closed)
	require.Nil(t, got)
}
27 changes: 26 additions & 1 deletion pkg/sql/colexec/dedupjoin/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,31 @@ const (
End
)

// WorkerJoinMsg carries per-worker state from non-merger workers to the
// merger worker at finalize time. Regular DEDUP JOIN only populates matched;
// the REPLACE INTO merged main-table scan path (OldColCapture) additionally
// populates captured and capturedVecs.
//
// Ownership: once a non-merger worker sends this message on the channel, it
// must relinquish its references to captured / capturedVecs so that the
// merger is the sole owner and is responsible for Free'ing capturedVecs.
type WorkerJoinMsg struct {
	// matched is the sending worker's matched bitmap; the merger ORs it
	// into its own matched bitmap.
	matched *bitmap.Bitmap
	// captured marks the buckets for which capturedVecs holds a value
	// captured by the sending worker. Nil when capture is not in use.
	captured *bitmap.Bitmap
	// capturedVecs holds the captured values, one vector per captured
	// column, indexed by bucket. Nil when capture is not in use.
	capturedVecs []*vector.Vector
}

// freeCapturedVecs frees every non-nil vector in vecs using proc's memory
// pool. It is meant for the merger to call once it has finished merging
// captures out of a WorkerJoinMsg, whose vectors it owns after the sender
// handed them over. A nil or empty slice is a no-op.
func freeCapturedVecs(vecs []*vector.Vector, proc *process.Process) {
	for i := range vecs {
		vec := vecs[i]
		if vec == nil {
			continue
		}
		vec.Free(proc.GetMPool())
	}
}

type evalVector struct {
executor colexec.ExpressionExecutor
vec *vector.Vector
Expand Down Expand Up @@ -94,7 +119,7 @@ type DedupJoin struct {
RuntimeFilterSpecs []*plan.RuntimeFilterSpec
JoinMapTag int32

Channel chan *bitmap.Bitmap
Channel chan *WorkerJoinMsg
NumCPU uint64
IsMerger bool

Expand Down
8 changes: 0 additions & 8 deletions pkg/sql/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2647,18 +2647,10 @@ func (c *Compile) compileProbeSideForBroadcastJoin(node, left, right *plan.Node,
} else {
rs = c.newProbeScopeListForBroadcastJoin(probeScopes, true)
currentFirstFlag := c.anal.isFirst
// OldColCapture (merged main-table scan for REPLACE INTO) keeps
// captured vectors in per-worker local state; the parallel finalize
// path only merges the matched bitmap, not capturedVecs. Force
// single-worker to avoid losing captures from non-merger workers.
hasCapture := node.DedupJoinCtx != nil && len(node.DedupJoinCtx.OldColCaptureList) > 0
for i := range rs {
op := constructDedupJoin(node, leftTypes, rightTypes, c.proc)
op.SetAnalyzeControl(c.anal.curNodeIdx, currentFirstFlag)
rs[i].setRootOperator(op)
if hasCapture {
rs[i].NodeInfo.Mcpu = 1
}
}
c.anal.isFirst = false
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/compile/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ func dupOperator(sourceOp vm.Operator, index int, maxParallel int) vm.Operator {
t := sourceOp.(*dedupjoin.DedupJoin)
op := dedupjoin.NewArgument()
if t.Channel == nil {
t.Channel = make(chan *bitmap.Bitmap, maxParallel)
t.Channel = make(chan *dedupjoin.WorkerJoinMsg, maxParallel)
}
op.Channel = t.Channel
op.NumCPU = uint64(maxParallel)
Expand Down
Loading