Skip to content
32 changes: 18 additions & 14 deletions internal/internal_nexus_task_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ func newNexusTaskHandler(

func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueResponse) (*workflowservice.RespondNexusTaskCompletedRequest, *workflowservice.RespondNexusTaskFailedRequest, error) {
failureReasonSupport := getEffectiveTemporalFailureResponses(task.GetRequest().GetCapabilities().GetTemporalFailureResponses())
pollerGroupId := task.GetPollerGroupId()
nctx, handlerErr := h.newNexusOperationContext(task)
if handlerErr != nil {
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport)
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
Expand All @@ -107,13 +108,13 @@ func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueRespo
return nil, nil, err
}
if handlerErr != nil {
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport)
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
return nil, failureRequest, nil
}
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport)
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
Expand All @@ -122,18 +123,19 @@ func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueRespo

func (h *nexusTaskHandler) ExecuteContext(nctx *NexusOperationContext, task *workflowservice.PollNexusTaskQueueResponse) (*workflowservice.RespondNexusTaskCompletedRequest, *workflowservice.RespondNexusTaskFailedRequest, error) {
failureReasonSupport := getEffectiveTemporalFailureResponses(task.GetRequest().GetCapabilities().GetTemporalFailureResponses())
pollerGroupId := task.GetPollerGroupId()
res, handlerErr, err := h.execute(nctx, task)
if err != nil {
return nil, nil, err
}
if handlerErr != nil {
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport)
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
return nil, failureRequest, nil
}
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport)
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -459,7 +461,7 @@ func (h *nexusTaskHandler) newNexusOperationContext(response *workflowservice.Po
}, nil
}

func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Response, failureReasonSupport bool) (*workflowservice.RespondNexusTaskCompletedRequest, error) {
func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Response, failureReasonSupport bool, pollerGroupId string) (*workflowservice.RespondNexusTaskCompletedRequest, error) {
// Handle conversion of Failure to OperationError for backwards compatibility with old servers.
if res.GetStartOperation().GetFailure() != nil && !failureReasonSupport {
// Convert to operation error for backwards compatibility.
Expand Down Expand Up @@ -488,18 +490,20 @@ func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Respo
}
}
return &workflowservice.RespondNexusTaskCompletedRequest{
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
Response: res,
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
Response: res,
PollerGroupId: pollerGroupId,
}, nil
}

func (h *nexusTaskHandler) fillInFailure(taskToken []byte, handlerError *nexus.HandlerError, failureReasonSupport bool) (*workflowservice.RespondNexusTaskFailedRequest, error) {
func (h *nexusTaskHandler) fillInFailure(taskToken []byte, handlerError *nexus.HandlerError, failureReasonSupport bool, pollerGroupId string) (*workflowservice.RespondNexusTaskFailedRequest, error) {
r := &workflowservice.RespondNexusTaskFailedRequest{
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
PollerGroupId: pollerGroupId,
}
if failureReasonSupport {
r.Failure = h.failureConverter.ErrorToFailure(handlerError)
Expand Down
26 changes: 17 additions & 9 deletions internal/internal_nexus_task_poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ import (

type nexusTaskPoller struct {
basePoller
namespace string
taskQueueName string
identity string
service workflowservice.WorkflowServiceClient
taskHandler *nexusTaskHandler
logger log.Logger
numPollerMetric *numPollerMetric
namespace string
taskQueueName string
identity string
service workflowservice.WorkflowServiceClient
taskHandler *nexusTaskHandler
logger log.Logger
numPollerMetric *numPollerMetric
pollerGroupTracker *pollerGroupTracker
}

type nexusTask struct {
Expand Down Expand Up @@ -53,7 +54,8 @@ func newNexusTaskPoller(
taskQueueName: params.TaskQueue,
identity: params.Identity,
logger: params.Logger,
numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeNexusTask),
numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeNexusTask),
pollerGroupTracker: newPollerGroupTracker(),
}
}

Expand All @@ -69,6 +71,10 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) {
traceLog(func() {
ntp.logger.Debug("nexusTaskPoller::Poll")
})

groupId := ntp.pollerGroupTracker.getNextGroupId()
defer ntp.pollerGroupTracker.release(groupId)

request := &workflowservice.PollNexusTaskQueueRequest{
Namespace: ntp.namespace,
TaskQueue: &taskqueuepb.TaskQueue{Name: ntp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL},
Expand All @@ -83,12 +89,14 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) {
ntp.workerDeploymentVersion,
),
WorkerInstanceKey: ntp.workerInstanceKey,
PollerGroupId: groupId,
}

response, err := ntp.pollNexusTaskQueue(ctx, request)
if err != nil {
return nil, err
}
ntp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos())
if response == nil || len(response.TaskToken) == 0 {
// No operation info is available on empty poll. Emit using base scope.
ntp.metricsHandler.Counter(metrics.NexusPollNoTaskCounter).Inc(1)
Expand Down Expand Up @@ -131,7 +139,7 @@ func (ntp *nexusTaskPoller) ProcessTask(task interface{}) error {
nctx, handlerErr := ntp.taskHandler.newNexusOperationContext(response)
if handlerErr != nil {
// context wasn't propagated to us, use a background context.
failedRequest, err := ntp.taskHandler.fillInFailure(response.TaskToken, handlerErr, getEffectiveTemporalFailureResponses(response.GetRequest().GetCapabilities().GetTemporalFailureResponses()))
failedRequest, err := ntp.taskHandler.fillInFailure(response.TaskToken, handlerErr, getEffectiveTemporalFailureResponses(response.GetRequest().GetCapabilities().GetTemporalFailureResponses()), response.GetPollerGroupId())
if err != nil {
return err
}
Expand Down
19 changes: 3 additions & 16 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1802,20 +1802,6 @@ func isCommandMatchEvent(d *commandpb.Command, e *historypb.HistoryEvent, obes [
return false
}

func isSearchAttributesMatched(attrFromEvent, attrFromCommand *commonpb.SearchAttributes) bool {
if attrFromEvent != nil && attrFromCommand != nil {
return reflect.DeepEqual(attrFromEvent.IndexedFields, attrFromCommand.IndexedFields)
}
return attrFromEvent == nil && attrFromCommand == nil
}

func isMemoMatched(attrFromEvent, attrFromCommand *commonpb.Memo) bool {
if attrFromEvent != nil && attrFromCommand != nil {
return reflect.DeepEqual(attrFromEvent.Fields, attrFromCommand.Fields)
}
return attrFromEvent == nil && attrFromCommand == nil
}

// return true if the check fails:
//
// namespace is not empty in command
Expand All @@ -1839,8 +1825,9 @@ func (wth *workflowTaskHandlerImpl) completeWorkflow(
// for query task
if task.Query != nil {
queryCompletedRequest := &workflowservice.RespondQueryTaskCompletedRequest{
TaskToken: task.TaskToken,
Namespace: wth.namespace,
TaskToken: task.TaskToken,
Namespace: wth.namespace,
PollerGroupId: task.GetPollerGroupId(),
}
var panicErr *PanicError
if errors.As(workflowContext.err, &panicErr) {
Expand Down
15 changes: 15 additions & 0 deletions internal/internal_task_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"reflect"
"strconv"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -43,6 +44,20 @@ const (
testNamespace = "test-namespace"
)

func isSearchAttributesMatched(attrFromEvent, attrFromCommand *commonpb.SearchAttributes) bool {
if attrFromEvent != nil && attrFromCommand != nil {
return reflect.DeepEqual(attrFromEvent.IndexedFields, attrFromCommand.IndexedFields)
}
return attrFromEvent == nil && attrFromCommand == nil
}

func isMemoMatched(attrFromEvent, attrFromCommand *commonpb.Memo) bool {
if attrFromEvent != nil && attrFromCommand != nil {
return reflect.DeepEqual(attrFromEvent.Fields, attrFromCommand.Fields)
}
return attrFromEvent == nil && attrFromCommand == nil
}

type (
TaskHandlersTestSuite struct {
suite.Suite
Expand Down
18 changes: 17 additions & 1 deletion internal/internal_task_pollers.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ type (
numNormalPollerMetric *numPollerMetric
numStickyPollerMetric *numPollerMetric

pollerGroupTracker *pollerGroupTracker
inboundPayloadVisitor PayloadVisitor
payloadVisitorConcurrency int
}
Expand Down Expand Up @@ -161,7 +162,7 @@ type (
payloadVisitorConcurrency int
}

// activityTaskPoller implements polling/processing a workflow task
// activityTaskPoller implements polling/processing an activity task
activityTaskPoller struct {
basePoller
namespace string
Expand All @@ -172,6 +173,7 @@ type (
logger log.Logger
activitiesPerSecond float64
numPollerMetric *numPollerMetric
pollerGroupTracker *pollerGroupTracker
}

historyIteratorImpl struct {
Expand Down Expand Up @@ -430,6 +432,7 @@ func (wtp *workflowTaskProcessor) createPoller(mode workflowTaskPollerMode) task
eagerActivityExecutor: wtp.eagerActivityExecutor,
numNormalPollerMetric: wtp.numNormalPollerMetric,
numStickyPollerMetric: wtp.numStickyPollerMetric,
pollerGroupTracker: newPollerGroupTracker(),
inboundPayloadVisitor: wtp.inboundPayloadVisitor,
payloadVisitorConcurrency: wtp.payloadVisitorConcurrency,
}
Expand Down Expand Up @@ -789,6 +792,7 @@ func (wtp *workflowTaskProcessor) reportGrpcMessageTooLarge(
Namespace: wtp.namespace,
Failure: wtp.failureConverter.ErrorToFailure(sendErr),
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_GRPC_MESSAGE_TOO_LARGE,
PollerGroupId: task.GetPollerGroupId(),
}
if err = visitProtoPayloads(ctx, wtp.outboundPayloadVisitor, request, wtp.payloadVisitorConcurrency); err != nil {
wtp.logger.Error("Failed to visit payloads for GRPC message too large query failure response.", tagError, err)
Expand Down Expand Up @@ -1141,6 +1145,9 @@ func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.Po
panic("unknown workflow task poller mode")
}

groupId := wtp.pollerGroupTracker.getNextGroupId()
defer wtp.pollerGroupTracker.release(groupId)

builtRequest := &workflowservice.PollWorkflowTaskQueueRequest{
Namespace: wtp.namespace,
TaskQueue: taskQueue,
Expand All @@ -1156,6 +1163,7 @@ func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.Po
wtp.workerDeploymentVersion,
),
WorkerInstanceKey: wtp.workerInstanceKey,
PollerGroupId: groupId,
}
if wtp.getCapabilities().BuildIdBasedVersioning {
//lint:ignore SA1019 ignore deprecated versioning APIs
Expand Down Expand Up @@ -1191,6 +1199,7 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error)
wtp.updateBacklog(request.TaskQueue.GetKind(), 0)
return nil, err
}
wtp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos())

if response == nil || len(response.TaskToken) == 0 {
// Emit using base scope as no workflow type information is available in the case of empty poll
Expand Down Expand Up @@ -1382,6 +1391,7 @@ func newActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserv
logger: params.Logger,
activitiesPerSecond: params.TaskQueueActivitiesPerSecond,
numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeActivityTask),
pollerGroupTracker: newPollerGroupTracker(),
}
}

Expand All @@ -1398,6 +1408,10 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error)
traceLog(func() {
atp.logger.Debug("activityTaskPoller::Poll")
})

groupId := atp.pollerGroupTracker.getNextGroupId()
defer atp.pollerGroupTracker.release(groupId)

request := &workflowservice.PollActivityTaskQueueRequest{
Namespace: atp.namespace,
TaskQueue: &taskqueuepb.TaskQueue{Name: atp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL},
Expand All @@ -1413,12 +1427,14 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error)
atp.workerDeploymentVersion,
),
WorkerInstanceKey: atp.workerInstanceKey,
PollerGroupId: groupId,
}

response, err := atp.pollActivityTaskQueue(ctx, request)
if err != nil {
return nil, err
}
atp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos())
if response == nil || len(response.TaskToken) == 0 {
// No activity info is available on empty poll. Emit using base scope.
atp.metricsHandler.Counter(metrics.ActivityPollNoTaskCounter).Inc(1)
Expand Down
2 changes: 1 addition & 1 deletion internal/internal_workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2775,7 +2775,7 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation(
response, failure, err := taskHandler.Execute(task)
if err != nil {
// No retries for operations, fail the operation immediately.
failure, err = taskHandler.fillInFailure(task.TaskToken, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "%s", err.Error()), false)
failure, err = taskHandler.fillInFailure(task.TaskToken, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "%s", err.Error()), false, "")
}
if failure != nil {
// Convert to a nexus HandlerError first to simulate the flow in the server.
Expand Down
Loading
Loading