Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions internal/internal_task_pollers.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,24 +310,12 @@ func (bp *basePoller) doPoll(pollFunc func(ctx context.Context) (taskForWorker,
}()

if bp.workerPollCompleteOnShutdown != nil && bp.workerPollCompleteOnShutdown.Load() {
// Don't kill the gRPC stream. After ShutdownWorker, the server returns empty responses.
select {
case <-doneC:
return result, err
case <-bp.stopC:
// TEMP FIX: Give the server a reasonable window to complete the poll after
// ShutdownWorker. Fall back to cancelling the poll if it takes too
// long, e.g. when the gRPC connection was closed before Stop().
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
select {
case <-doneC:
case <-timer.C:
cancel()
<-doneC
}
return result, err
}
// Don't cancel the gRPC stream. After ShutdownWorker, the server
// completes the poll with an empty response. The poll is bounded
// by the gRPC timeout (pollTaskServiceTimeOut). Stop() waits for
// all pollers to finish before proceeding to task drain.
<-doneC
return result, err
}

// Legacy: cancel in-flight polls immediately on shutdown
Expand Down
54 changes: 32 additions & 22 deletions internal/internal_worker_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ type (
lastPollTaskErrLock sync.Mutex

noRepoll atomic.Bool
pollerWG sync.WaitGroup
}

eagerOrPolledTask interface {
Expand Down Expand Up @@ -391,6 +392,7 @@ func (bw *baseWorker) Start() {

for i := 0; i < taskWorker.pollerCount; i++ {
bw.stopWG.Add(1)
bw.pollerWG.Add(1)
go bw.runPoller(taskWorker)
}

Expand All @@ -403,6 +405,15 @@ func (bw *baseWorker) Start() {
}
}

// When all pollers have exited, close taskQueueCh so the dispatcher
// knows no more polled tasks will arrive and can drain what remains.
bw.stopWG.Add(1)
go func() {
defer bw.stopWG.Done()
bw.pollerWG.Wait()
close(bw.taskQueueCh)
}()

bw.stopWG.Add(1)
go bw.runTaskDispatcher()

Expand All @@ -428,6 +439,7 @@ func (bw *baseWorker) isStop() bool {

func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) {
defer bw.stopWG.Done()
defer bw.pollerWG.Done()
// Note: With poller autoscaling, this metric doesn't make a lot of sense since the number of pollers can go up and down.
bw.metricsHandler.Counter(metrics.PollerStartCounter).Inc(1)

Expand Down Expand Up @@ -561,24 +573,17 @@ func (bw *baseWorker) processTaskAsync(eagerOrPolled eagerOrPolledTask) {
func (bw *baseWorker) runTaskDispatcher() {
defer bw.stopWG.Done()

for {
// wait for new task or worker stop
select {
case <-bw.stopCh:
// Currently we can drop any tasks received when closing.
// https://github.com/temporalio/sdk-go/issues/1197
return
case task := <-bw.taskQueueCh:
// for non-polled-task (local activity result as task or eager task), we don't need to rate limit
_, isPolledTask := task.(*polledTask)
if isPolledTask && bw.taskLimiter.Wait(bw.limiterContext) != nil {
if bw.isStop() {
bw.releaseSlot(task.getPermit(), SlotReleaseReasonUnused)
return
}
}
bw.processTaskAsync(task)
for task := range bw.taskQueueCh {
// For non-polled-task (local activity result as task or eager task),
// we don't need to rate limit. During shutdown the limiter context
// is cancelled, so Wait returns immediately — we still process the
// task rather than dropping it.
if _, isPolledTask := task.(*polledTask); isPolledTask {
// Ignore error: during shutdown the limiter context is
// cancelled, but we still process remaining tasks.
_ = bw.taskLimiter.Wait(bw.limiterContext)
}
bw.processTaskAsync(task)
}
}

Expand Down Expand Up @@ -639,11 +644,10 @@ func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPe
taskWorker.pollerAutoscalerReportHandle.handleTask(task)
}

select {
case bw.taskQueueCh <- &polledTask{task: task, permit: slotPermit}:
didSendTask = true
case <-bw.stopCh:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking the stop channel is still used elsewhere since we removed it in two spots

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep! still used in plenty of other places

}
// The dispatcher is guaranteed to be alive: it only exits after
// taskQueueCh is closed, which happens after all pollers finish.
bw.taskQueueCh <- &polledTask{task: task, permit: slotPermit}
didSendTask = true
}
}

Expand Down Expand Up @@ -703,6 +707,12 @@ func (bw *baseWorker) Stop() {
close(bw.stopCh)
bw.limiterContextCancel()

// Wait for pollers to finish. (pollTaskServiceTimeOut) bounds this if the connection is broken.
bw.pollerWG.Wait()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unbounded pollerWG.Wait() bypasses user-configured stopTimeout

Medium Severity

bw.pollerWG.Wait() in Stop() blocks without any timeout, and runs before awaitWaitGroup(&bw.stopWG, bw.options.stopTimeout). Combined with doPoll now waiting unconditionally on <-doneC (bounded only by pollTaskServiceTimeOut = 70s), Stop() can block for up to 70 seconds before the user's stopTimeout even begins counting. Previously, a 5-second fallback cancellation bounded this. In failure scenarios (broken gRPC connection, unresponsive server), total Stop() duration becomes ~70s + stopTimeout instead of just stopTimeout.

Additional Locations (1)
Fix in Cursor Fix in Web


// Wait for task processing to complete. The dispatcher
// drains taskQueueCh (closed after pollers finish above) and
// processTaskAsync goroutines are tracked in stopWG.
if success := awaitWaitGroup(&bw.stopWG, bw.options.stopTimeout); !success {
traceLog(func() {
bw.logger.Info("Worker graceful stop timed out.", "Stop timeout", bw.options.stopTimeout)
Expand Down
98 changes: 98 additions & 0 deletions internal/internal_worker_base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,104 @@ type noopTaskProcessor struct{}

func (noopTaskProcessor) ProcessTask(any) error { return nil }

// TestTaskNotDroppedDuringShutdown verifies the two-stage shutdown: when a
// poller receives a task during shutdown, the task is still dispatched and
// processed rather than silently dropped.
func TestTaskNotDroppedDuringShutdown(t *testing.T) {
taskProcessed := make(chan struct{})
pollStarted := make(chan struct{})

// A poller that blocks until returnTask is closed, then returns a task
// exactly once. Subsequent polls return nil so the poller can exit.
tp := &shutdownTaskPoller{
pollStarted: pollStarted,
returnTask: make(chan struct{}),
task: &testTask{},
}

processor := &recordingTaskProcessor{
processed: taskProcessed,
}

bw := newBaseWorker(baseWorkerOptions{
slotSupplier: &testSlotSupplier{},
maxTaskPerSecond: 1000,
taskPollers: []scalableTaskPoller{
{taskPollerType: "test", pollerCount: 1, taskPoller: tp},
},
taskProcessor: processor,
workerType: "ShutdownTest",
logger: ilog.NewNopLogger(),
stopTimeout: 5 * time.Second,
metricsHandler: metrics.NopHandler,
})

bw.Start()

// Wait for the poller to be actively polling.
<-pollStarted

// Release the poller so it returns a task, then stop the worker.
// The poller returns a task and then nil on subsequent polls,
// allowing it to exit via noRepoll/stopCh during Stop().
close(tp.returnTask)

// Stop exercises the real shutdown path: noRepoll, close(stopCh),
// limiterContextCancel, and awaitWaitGroup.
stopDone := make(chan struct{})
go func() {
bw.Stop()
close(stopDone)
}()

select {
case <-taskProcessed:
// Success: the task was dispatched and processed during shutdown
case <-time.After(5 * time.Second):
t.Fatal("task polled during shutdown was not processed (dropped)")
}

select {
case <-stopDone:
// Stop completed cleanly
case <-time.After(5 * time.Second):
t.Fatal("Stop() did not return in time")
}
}

// shutdownTaskPoller blocks until returnTask is closed, then returns a task
// exactly once. Subsequent polls return nil.
type shutdownTaskPoller struct {
pollStarted chan struct{}
returnTask chan struct{}
task taskForWorker
returned atomic.Bool
}

func (p *shutdownTaskPoller) PollTask() (taskForWorker, error) {
select {
case p.pollStarted <- struct{}{}:
default:
}
<-p.returnTask
if p.returned.CompareAndSwap(false, true) {
return p.task, nil
}
return nil, nil
}

type recordingTaskProcessor struct {
processed chan struct{}
}

func (p *recordingTaskProcessor) ProcessTask(any) error {
select {
case p.processed <- struct{}{}:
default:
}
return nil
}

func (s *PollScalerReportHandleSuite) TestAutoscaleDownOnTimeoutWithCapability() {
targetSuggestion := 0
ps := newPollScalerReportHandle(pollScalerReportHandleOptions{
Expand Down
4 changes: 3 additions & 1 deletion test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4303,7 +4303,9 @@ func (ts *IntegrationTestSuite) testUpdateOrderingCancel(cancelWf bool) {
}()
}

// The server does not support admitted updates, so we send the update in a separate goroutine
// The server does not support admitted updates, so we send the update in a separate goroutine.
// Keep this shorter than the activity's ScheduleToCloseTimeout (5s) so the new worker
// has time to execute activities before they time out.
time.Sleep(5 * time.Second)
// Now create a new worker on that same task queue to resume the work of the
// workflow
Expand Down
Loading