-
Notifications
You must be signed in to change notification settings - Fork 860
[WIP] Fix message ordering when processing function approval responses #7158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
10ba264
30a50bb
ba5e9bf
e65dd44
e58a3ae
e0bb70d
1bf5d84
7dbbdbb
370eaec
89ccabe
fb92a70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -295,10 +295,10 @@ public override async Task<ChatResponse> GetResponseAsync( | |
| // approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message | ||
| // list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if | ||
| // the inner client had returned them directly. | ||
| (responseMessages, var notInvokedApprovals) = ProcessFunctionApprovalResponses( | ||
| (responseMessages, var notInvokedApprovals, var resultInsertionIndex) = ProcessFunctionApprovalResponses( | ||
| originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId: null, functionCallContentFallbackMessageId: null); | ||
| (IList<ChatMessage>? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) = | ||
| await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: false, cancellationToken); | ||
| await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, resultInsertionIndex, isStreaming: false, cancellationToken); | ||
|
|
||
| if (invokedApprovedFunctionApprovalResponses is not null) | ||
| { | ||
|
|
@@ -381,7 +381,7 @@ public override async Task<ChatResponse> GetResponseAsync( | |
|
|
||
| // Add the responses from the function calls into the augmented history and also into the tracked | ||
| // list of response messages. | ||
| var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken); | ||
| var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, insertionIndex: -1, isStreaming: false, cancellationToken); | ||
| responseMessages.AddRange(modeAndMessages.MessagesAdded); | ||
| consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; | ||
|
|
||
|
|
@@ -447,7 +447,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA | |
| // approval requests, we need to process them now. This entails removing these manufactured approval requests from the chat message | ||
| // list and replacing them with the appropriate FunctionCallContents and FunctionResultContents that would have been generated if | ||
| // the inner client had returned them directly. | ||
| var (preDownstreamCallHistory, notInvokedApprovals) = ProcessFunctionApprovalResponses( | ||
| var (preDownstreamCallHistory, notInvokedApprovals, resultInsertionIndex) = ProcessFunctionApprovalResponses( | ||
| originalMessages, !string.IsNullOrWhiteSpace(options?.ConversationId), toolMessageId, functionCallContentFallbackMessageId); | ||
| if (preDownstreamCallHistory is not null) | ||
| { | ||
|
|
@@ -460,7 +460,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA | |
|
|
||
| // Invoke approved approval responses, which generates some additional FRC wrapped in ChatMessage. | ||
| (IList<ChatMessage>? invokedApprovedFunctionApprovalResponses, bool shouldTerminate, consecutiveErrorCount) = | ||
| await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, isStreaming: true, cancellationToken); | ||
| await InvokeApprovedFunctionApprovalResponsesAsync(notInvokedApprovals, toolMap, originalMessages, options, consecutiveErrorCount, resultInsertionIndex, isStreaming: true, cancellationToken); | ||
|
|
||
| if (invokedApprovedFunctionApprovalResponses is not null) | ||
| { | ||
|
|
@@ -604,7 +604,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA | |
| FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadConversationId); | ||
|
|
||
| // Process all of the functions, adding their results into the history. | ||
| var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken); | ||
| var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, toolMap, functionCallContents!, iteration, consecutiveErrorCount, insertionIndex: -1, isStreaming: true, cancellationToken); | ||
| responseMessages.AddRange(modeAndMessages.MessagesAdded); | ||
| consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; | ||
|
|
||
|
|
@@ -880,13 +880,14 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont | |
| /// <param name="functionCallContents">The function call contents representing the functions to be invoked.</param> | ||
| /// <param name="iteration">The iteration number of how many roundtrips have been made to the inner client.</param> | ||
| /// <param name="consecutiveErrorCount">The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors.</param> | ||
| /// <param name="insertionIndex">The index at which to insert the function result messages, or -1 to append to the end.</param> | ||
| /// <param name="isStreaming">Whether the function calls are being processed in a streaming context.</param> | ||
| /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param> | ||
| /// <returns>A value indicating how the caller should proceed.</returns> | ||
| private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList<ChatMessage> MessagesAdded)> ProcessFunctionCallsAsync( | ||
| List<ChatMessage> messages, ChatOptions? options, | ||
| Dictionary<string, AITool>? toolMap, List<FunctionCallContent> functionCallContents, int iteration, int consecutiveErrorCount, | ||
| bool isStreaming, CancellationToken cancellationToken) | ||
| int insertionIndex, bool isStreaming, CancellationToken cancellationToken) | ||
| { | ||
| // We must add a response for every tool call, regardless of whether we successfully executed it or not. | ||
| // If we successfully execute it, we'll add the result. If we don't, we'll add an error. | ||
|
|
@@ -905,7 +906,16 @@ private bool ShouldTerminateLoopBasedOnHandleableFunctions(List<FunctionCallCont | |
| IList<ChatMessage> addedMessages = CreateResponseMessages([result]); | ||
| ThrowIfNoFunctionResultsAdded(addedMessages); | ||
| UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount); | ||
| messages.AddRange(addedMessages); | ||
|
|
||
| // Insert at the specified position or append if no valid insertion index | ||
| if (insertionIndex >= 0 && insertionIndex <= messages.Count) | ||
| { | ||
| messages.InsertRange(insertionIndex, addedMessages); | ||
| } | ||
| else | ||
| { | ||
| messages.AddRange(addedMessages); | ||
| } | ||
|
|
||
| return (result.Terminate, consecutiveErrorCount, addedMessages); | ||
| } | ||
|
|
@@ -950,7 +960,16 @@ select ProcessFunctionCallAsync( | |
| IList<ChatMessage> addedMessages = CreateResponseMessages(results.ToArray()); | ||
| ThrowIfNoFunctionResultsAdded(addedMessages); | ||
| UpdateConsecutiveErrorCountOrThrow(addedMessages, ref consecutiveErrorCount); | ||
| messages.AddRange(addedMessages); | ||
|
|
||
| // Insert at the specified position or append if no valid insertion index | ||
| if (insertionIndex >= 0 && insertionIndex <= messages.Count) | ||
| { | ||
| messages.InsertRange(insertionIndex, addedMessages); | ||
| } | ||
| else | ||
| { | ||
| messages.AddRange(addedMessages); | ||
| } | ||
|
|
||
| return (shouldTerminate, consecutiveErrorCount, addedMessages); | ||
| } | ||
|
|
@@ -1248,12 +1267,13 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul | |
| /// 3. Genreate failed <see cref="FunctionResultContent"/> for any rejected <see cref="FunctionApprovalResponseContent"/>. | ||
| /// 4. add all the new content items to <paramref name="originalMessages"/> and return them as the pre-invocation history. | ||
| /// </summary> | ||
| private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResultWithRequestMessage>? approvals) ProcessFunctionApprovalResponses( | ||
| private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResultWithRequestMessage>? approvals, int insertionIndex) ProcessFunctionApprovalResponses( | ||
| List<ChatMessage> originalMessages, bool hasConversationId, string? toolMessageId, string? functionCallContentFallbackMessageId) | ||
| { | ||
| // Extract any approval responses where we need to execute or reject the function calls. | ||
| // The original messages are also modified to remove all approval requests and responses. | ||
| var notInvokedResponses = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages); | ||
| var (notInvokedApprovalsResult, notInvokedRejectionsResult, insertionIndex) = ExtractAndRemoveApprovalRequestsAndResponses(originalMessages); | ||
| var notInvokedResponses = (approvals: notInvokedApprovalsResult, rejections: notInvokedRejectionsResult); | ||
|
|
||
| // Wrap the function call content in message(s). | ||
| ICollection<ChatMessage>? allPreDownstreamCallMessages = ConvertToFunctionCallContentMessages( | ||
|
|
@@ -1269,25 +1289,54 @@ private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResult | |
| // Add all the FCC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response. | ||
| // Also, if we are not dealing with a service thread (i.e. we don't have a conversation ID), add them | ||
| // into the original messages list so that they are passed to the inner client and can be used to generate a result. | ||
| // Insert at the position where the approval request was originally located to preserve message ordering. | ||
| List<ChatMessage>? preDownstreamCallHistory = null; | ||
| if (allPreDownstreamCallMessages is not null) | ||
| { | ||
| preDownstreamCallHistory = [.. allPreDownstreamCallMessages]; | ||
| if (!hasConversationId) | ||
| { | ||
| originalMessages.AddRange(preDownstreamCallHistory); | ||
| // If we have a valid insertion index, insert at that position. Otherwise, append to the end. | ||
| if (insertionIndex >= 0 && insertionIndex <= originalMessages.Count) | ||
| { | ||
| originalMessages.InsertRange(insertionIndex, preDownstreamCallHistory); | ||
| } | ||
| else | ||
| { | ||
| originalMessages.AddRange(preDownstreamCallHistory); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Add all the FRC that we generated to the pre-downstream-call history so that they can be returned to the caller as part of the next response. | ||
| // Also, add them into the original messages list so that they are passed to the inner client and can be used to generate a result. | ||
| // Insert immediately after the FCC messages to preserve message ordering. | ||
| if (rejectedPreDownstreamCallResultsMessage is not null) | ||
| { | ||
| (preDownstreamCallHistory ??= []).Add(rejectedPreDownstreamCallResultsMessage); | ||
| originalMessages.Add(rejectedPreDownstreamCallResultsMessage); | ||
|
|
||
| // Calculate the insertion position: right after the FCC messages we just inserted | ||
| // Only add the FCC count if they were actually inserted (!hasConversationId) | ||
| int rejectedInsertionIndex = insertionIndex >= 0 && insertionIndex <= originalMessages.Count | ||
| ? insertionIndex + (!hasConversationId ? (allPreDownstreamCallMessages?.Count ?? 0) : 0) | ||
| : originalMessages.Count; | ||
|
|
||
| if (rejectedInsertionIndex >= 0 && rejectedInsertionIndex <= originalMessages.Count) | ||
| { | ||
| originalMessages.Insert(rejectedInsertionIndex, rejectedPreDownstreamCallResultsMessage); | ||
| } | ||
| else | ||
| { | ||
| originalMessages.Add(rejectedPreDownstreamCallResultsMessage); | ||
| } | ||
| } | ||
|
|
||
| return (preDownstreamCallHistory, notInvokedResponses.approvals); | ||
| // Calculate the insertion index for function result content (after the FCC messages and rejected FRC messages) | ||
| int resultInsertionIndex = insertionIndex >= 0 && insertionIndex <= originalMessages.Count && !hasConversationId | ||
| ? insertionIndex + (allPreDownstreamCallMessages?.Count ?? 0) + (rejectedPreDownstreamCallResultsMessage is not null ? 1 : 0) | ||
| : -1; | ||
|
|
||
| return (preDownstreamCallHistory, notInvokedResponses.approvals, resultInsertionIndex); | ||
| } | ||
|
|
||
| /// <summary> | ||
|
|
@@ -1299,13 +1348,14 @@ private static (List<ChatMessage>? preDownstreamCallHistory, List<ApprovalResult | |
| /// We can then use the metadata from these messages when we re-create the FunctionCallContent messages/updates to return to the caller. This way, when we finally do return | ||
| /// the FuncionCallContent to users it's part of a message/update that contains the same metadata as originally returned to the downstream service. | ||
| /// </remarks> | ||
| private static (List<ApprovalResultWithRequestMessage>? approvals, List<ApprovalResultWithRequestMessage>? rejections) ExtractAndRemoveApprovalRequestsAndResponses( | ||
| private static (List<ApprovalResultWithRequestMessage>? approvals, List<ApprovalResultWithRequestMessage>? rejections, int insertionIndex) ExtractAndRemoveApprovalRequestsAndResponses( | ||
| List<ChatMessage> messages) | ||
| { | ||
| Dictionary<string, ChatMessage>? allApprovalRequestsMessages = null; | ||
| List<FunctionApprovalResponseContent>? allApprovalResponses = null; | ||
| HashSet<string>? approvalRequestCallIds = null; | ||
| HashSet<string>? functionResultCallIds = null; | ||
| int firstApprovalRequestIndex = -1; | ||
|
|
||
| // 1st iteration, over all messages and content: | ||
| // - Build a list of all function call ids that are already executed. | ||
|
|
@@ -1330,6 +1380,13 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval | |
| // Validation: Capture each call id for each approval request to ensure later we have a matching response. | ||
| _ = (approvalRequestCallIds ??= []).Add(farc.FunctionCall.CallId); | ||
| (allApprovalRequestsMessages ??= []).Add(farc.Id, message); | ||
|
|
||
| // Track the first approval request index for later insertion | ||
| if (firstApprovalRequestIndex == -1) | ||
| { | ||
| firstApprovalRequestIndex = i; | ||
| } | ||
|
|
||
| break; | ||
|
|
||
| case FunctionApprovalResponseContent farc: | ||
|
|
@@ -1371,9 +1428,53 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval | |
| } | ||
|
|
||
| // Clean up any messages that were marked for removal during the iteration. | ||
| // Also adjust the insertion index to account for removed messages. | ||
| int insertionIndex = firstApprovalRequestIndex; | ||
| if (anyRemoved) | ||
| { | ||
| // Count how many messages before the first approval request were removed | ||
| int removedBeforeInsertionIndex = 0; | ||
| if (firstApprovalRequestIndex >= 0) | ||
| { | ||
| for (int idx = 0; idx < firstApprovalRequestIndex; idx++) | ||
| { | ||
| if (messages[idx] is null) | ||
| { | ||
| removedBeforeInsertionIndex++; | ||
| } | ||
| } | ||
| } | ||
stephentoub marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| _ = messages.RemoveAll(static m => m is null); | ||
|
|
||
| // Adjust the insertion index | ||
| if (insertionIndex >= 0) | ||
| { | ||
| insertionIndex -= removedBeforeInsertionIndex; | ||
| } | ||
| } | ||
|
|
||
| // If there are already-executed function results, insert new function calls at the end instead of at the insertion index | ||
| // to preserve the ordering of already-present function calls and results. This handles scenarios where: | ||
| // 1. Previous approval responses have been processed and their function calls/results are present in the message list | ||
| // 2. New approval responses are being processed | ||
| // In this case, we want the new function calls to come AFTER the existing ones, not at the position | ||
| // where the first (already-processed) approval request was originally located. | ||
stephentoub marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // | ||
| // Example: | ||
| // Before extraction (original user input with approval messages): | ||
| // [User, FunctionApprovalRequest(A), FunctionApprovalResponse(A), FunctionResult(A), FunctionApprovalRequest(B), FunctionApprovalResponse(B)] | ||
| // After extraction of approval requests/responses (state of 'messages' at this point): | ||
| // [User, FunctionResult(A)] | ||
| // After processing approval for B, if we inserted at the original index where B's approval request was, | ||
| // we'd incorrectly interleave new calls with existing results: | ||
| // [User, FunctionCall(B), FunctionResult(B), FunctionResult(A)] // Wrong order | ||
| // But if there are already function results present (e.g., for A), we instead append new function calls/results | ||
| // for B at the end to preserve chronological ordering: | ||
| // [User, FunctionResult(A), FunctionCall(B), FunctionResult(B)] // Correct order | ||
| if (functionResultCallIds is { Count: > 0 } && insertionIndex >= 0) | ||
| { | ||
| insertionIndex = messages.Count; | ||
| } | ||
|
||
|
|
||
| // Validation: If we got an approval for each request, we should have no call ids left. | ||
|
|
@@ -1408,7 +1509,7 @@ private static (List<ApprovalResultWithRequestMessage>? approvals, List<Approval | |
| } | ||
| } | ||
|
|
||
| return (approvedFunctionCalls, rejectedFunctionCalls); | ||
| return (approvedFunctionCalls, rejectedFunctionCalls, insertionIndex); | ||
| } | ||
|
|
||
| /// <summary> | ||
|
|
@@ -1658,6 +1759,7 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) => | |
| List<ChatMessage> originalMessages, | ||
| ChatOptions? options, | ||
| int consecutiveErrorCount, | ||
| int insertionIndex, | ||
| bool isStreaming, | ||
| CancellationToken cancellationToken) | ||
| { | ||
|
|
@@ -1666,7 +1768,7 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) => | |
| { | ||
| // The FRC that is generated here is already added to originalMessages by ProcessFunctionCallsAsync. | ||
| var modeAndMessages = await ProcessFunctionCallsAsync( | ||
| originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, isStreaming, cancellationToken); | ||
| originalMessages, options, toolMap, notInvokedApprovals.Select(x => x.Response.FunctionCall).ToList(), 0, consecutiveErrorCount, insertionIndex, isStreaming, cancellationToken); | ||
| consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; | ||
|
|
||
| return (modeAndMessages.MessagesAdded, modeAndMessages.ShouldTerminate, consecutiveErrorCount); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.