Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { describe, expect, it } from 'vitest';

import {
clearCompletedInvocationKeysForQueueItem,
hasCompletedInvocationKey,
markInvocationAsCompleted,
shouldIgnoreFinishedQueueItemInvocationEvent,
} from './invocationTracking';

describe(markInvocationAsCompleted.name, () => {
it('tracks completed invocations per queue item', () => {
const completedInvocationKeysByItemId = new Map<number, Set<string>>();

markInvocationAsCompleted(completedInvocationKeysByItemId, 1, 'prepared-node-1');
markInvocationAsCompleted(completedInvocationKeysByItemId, 2, 'prepared-node-1');

expect(hasCompletedInvocationKey(completedInvocationKeysByItemId, 1, 'prepared-node-1')).toBe(true);
expect(hasCompletedInvocationKey(completedInvocationKeysByItemId, 2, 'prepared-node-1')).toBe(true);
});

it('clears only the completed invocations for a finished queue item', () => {
const completedInvocationKeysByItemId = new Map<number, Set<string>>();

markInvocationAsCompleted(completedInvocationKeysByItemId, 1, 'prepared-node-1');
markInvocationAsCompleted(completedInvocationKeysByItemId, 2, 'prepared-node-1');

clearCompletedInvocationKeysForQueueItem(completedInvocationKeysByItemId, 1);

expect(hasCompletedInvocationKey(completedInvocationKeysByItemId, 1, 'prepared-node-1')).toBe(false);
expect(hasCompletedInvocationKey(completedInvocationKeysByItemId, 2, 'prepared-node-1')).toBe(true);
});
});

describe(shouldIgnoreFinishedQueueItemInvocationEvent.name, () => {
it('ignores late started and progress events for finished queue items', () => {
const finishedQueueItemIds = new Set<number>([1]);

expect(shouldIgnoreFinishedQueueItemInvocationEvent('invocation_started', finishedQueueItemIds, 1)).toBe(true);
expect(shouldIgnoreFinishedQueueItemInvocationEvent('invocation_progress', finishedQueueItemIds, 1)).toBe(true);
});

it('does not ignore late error events for finished queue items', () => {
const finishedQueueItemIds = new Set<number>([1]);

expect(shouldIgnoreFinishedQueueItemInvocationEvent('invocation_error', finishedQueueItemIds, 1)).toBe(false);
});
});
45 changes: 45 additions & 0 deletions invokeai/frontend/web/src/services/events/invocationTracking.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
type CompletedInvocationKeysByItemId = Map<number, Set<string>>;

type FinishedQueueItemIds = {
has: (itemId: number) => boolean;
};

type FinishedQueueItemInvocationEventName = 'invocation_error' | 'invocation_progress' | 'invocation_started';

export const hasCompletedInvocationKey = (
completedInvocationKeysByItemId: CompletedInvocationKeysByItemId,
itemId: number,
invocationId: string
) => completedInvocationKeysByItemId.get(itemId)?.has(invocationId) ?? false;

export const markInvocationAsCompleted = (
completedInvocationKeysByItemId: CompletedInvocationKeysByItemId,
itemId: number,
invocationId: string
) => {
let completedInvocationKeys = completedInvocationKeysByItemId.get(itemId);
if (!completedInvocationKeys) {
completedInvocationKeys = new Set<string>();
completedInvocationKeysByItemId.set(itemId, completedInvocationKeys);
}
completedInvocationKeys.add(invocationId);
};

export const clearCompletedInvocationKeysForQueueItem = (
completedInvocationKeysByItemId: CompletedInvocationKeysByItemId,
itemId: number
) => {
completedInvocationKeysByItemId.delete(itemId);
};

export const shouldIgnoreFinishedQueueItemInvocationEvent = (
eventName: FinishedQueueItemInvocationEventName,
finishedQueueItemIds: FinishedQueueItemIds,
itemId: number
) => {
if (eventName === 'invocation_error') {
return false;
}

return finishedQueueItemIds.has(itemId);
};
280 changes: 280 additions & 0 deletions invokeai/frontend/web/src/services/events/nodeExecutionState.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { S } from 'services/api/types';
import { describe, expect, it } from 'vitest';

import {
getUpdatedNodeExecutionStateOnInvocationComplete,
getUpdatedNodeExecutionStateOnInvocationError,
getUpdatedNodeExecutionStateOnInvocationProgress,
getUpdatedNodeExecutionStateOnInvocationStarted,
} from './nodeExecutionState';

const buildNodeExecutionState = (overrides: Partial<NodeExecutionState> = {}): NodeExecutionState => ({
nodeId: 'node-1',
status: zNodeStatus.enum.PENDING,
progress: null,
progressImage: null,
outputs: [],
error: null,
...overrides,
});

const buildInvocationStartedEvent = (
overrides: Partial<S['InvocationStartedEvent']> = {}
): S['InvocationStartedEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
...overrides,
}) as S['InvocationStartedEvent'];

const buildInvocationProgressEvent = (
overrides: Partial<S['InvocationProgressEvent']> = {}
): S['InvocationProgressEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
percentage: 0.42,
image: {
dataURL: 'data:image/png;base64,abc',
width: 64,
height: 64,
},
message: 'working',
...overrides,
}) as S['InvocationProgressEvent'];

const buildInvocationCompleteEvent = (
overrides: Partial<S['InvocationCompleteEvent']> = {}
): S['InvocationCompleteEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
result: {
type: 'integer_output',
value: 42,
},
...overrides,
}) as S['InvocationCompleteEvent'];

const buildInvocationErrorEvent = (overrides: Partial<S['InvocationErrorEvent']> = {}): S['InvocationErrorEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
error_type: 'TestError',
error_message: 'boom',
error_traceback: 'traceback',
...overrides,
}) as S['InvocationErrorEvent'];

describe(getUpdatedNodeExecutionStateOnInvocationStarted.name, () => {
it('creates an execution state when started arrives before initialization', () => {
const event = buildInvocationStartedEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(undefined, event, new Map<number, Set<string>>());

expect(updated?.nodeId).toBe(event.invocation_source_id);
expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
expect(updated?.outputs).toEqual([]);
});

it('marks the node in progress on invocation start', () => {
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(
buildNodeExecutionState(),
buildInvocationStartedEvent(),
new Map<number, Set<string>>()
);

expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
});

it('ignores a late started event after that invocation already completed', () => {
const event = buildInvocationStartedEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }),
event,
new Map([[event.item_id, new Set([event.invocation.id])]])
);

expect(updated).toBeUndefined();
});
});

describe(getUpdatedNodeExecutionStateOnInvocationProgress.name, () => {
it('creates an execution state when progress arrives before initialization', () => {
const event = buildInvocationProgressEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(undefined, event, new Map<number, Set<string>>());

expect(updated?.nodeId).toBe(event.invocation_source_id);
expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
expect(updated?.progress).toBe(event.percentage);
expect(updated?.progressImage).toEqual(event.image);
});

it('marks the node in progress and preserves progress updates', () => {
const event = buildInvocationProgressEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(
buildNodeExecutionState(),
event,
new Map<number, Set<string>>()
);

expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
expect(updated?.progress).toBe(event.percentage);
expect(updated?.progressImage).toEqual(event.image);
});

it('ignores a late progress event after that invocation already completed', () => {
const event = buildInvocationProgressEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }),
event,
new Map([[event.item_id, new Set([event.invocation.id])]])
);

expect(updated).toBeUndefined();
});
});

describe(getUpdatedNodeExecutionStateOnInvocationComplete.name, () => {
it('creates an execution state when completion arrives before initialization', () => {
const event = buildInvocationCompleteEvent();
const completedInvocationKeysByItemId = new Map<number, Set<string>>();
const updated = getUpdatedNodeExecutionStateOnInvocationComplete(undefined, event, completedInvocationKeysByItemId);

expect(updated?.nodeId).toBe(event.invocation_source_id);
expect(updated?.status).toBe(zNodeStatus.enum.COMPLETED);
expect(updated?.outputs).toEqual([event.result]);
expect(completedInvocationKeysByItemId).toEqual(new Map([[event.item_id, new Set([event.invocation.id])]]));
});

it('records a completed invocation result once', () => {
const event = buildInvocationCompleteEvent();
const completedInvocationKeysByItemId = new Map<number, Set<string>>();

const updated = getUpdatedNodeExecutionStateOnInvocationComplete(
buildNodeExecutionState({ status: zNodeStatus.enum.IN_PROGRESS, progress: 0.5 }),
event,
completedInvocationKeysByItemId
);

expect(updated?.status).toBe(zNodeStatus.enum.COMPLETED);
expect(updated?.progress).toBe(1);
expect(updated?.outputs).toEqual([event.result]);
expect(completedInvocationKeysByItemId).toEqual(new Map([[event.item_id, new Set([event.invocation.id])]]));
});

it('ignores duplicate completion events for the same invocation', () => {
const event = buildInvocationCompleteEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationComplete(
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1, outputs: [event.result] }),
event,
new Map([[event.item_id, new Set([event.invocation.id])]])
);

expect(updated).toBeUndefined();
});

it('allows the same prepared invocation id on a different queue item', () => {
const firstEvent = buildInvocationCompleteEvent({
item_id: 1,
result: { type: 'integer_output', value: 1 } as unknown as S['InvocationCompleteEvent']['result'],
});
const secondEvent = buildInvocationCompleteEvent({
item_id: 2,
result: { type: 'integer_output', value: 2 } as unknown as S['InvocationCompleteEvent']['result'],
});
const completedInvocationKeysByItemId = new Map<number, Set<string>>();

const firstUpdate = getUpdatedNodeExecutionStateOnInvocationComplete(
buildNodeExecutionState(),
firstEvent,
completedInvocationKeysByItemId
);
const secondUpdate = getUpdatedNodeExecutionStateOnInvocationComplete(
firstUpdate,
secondEvent,
completedInvocationKeysByItemId
);

expect(secondUpdate?.outputs).toEqual([firstEvent.result, secondEvent.result]);
});
});

describe(getUpdatedNodeExecutionStateOnInvocationError.name, () => {
it('creates an execution state when error arrives before initialization', () => {
const event = buildInvocationErrorEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationError(undefined, event);

expect(updated?.nodeId).toBe(event.invocation_source_id);
expect(updated?.status).toBe(zNodeStatus.enum.FAILED);
expect(updated?.progress).toBeNull();
expect(updated?.progressImage).toBeNull();
expect(updated?.error).toEqual({
error_type: event.error_type,
error_message: event.error_message,
error_traceback: event.error_traceback,
});
});

it('marks the node failed and records the error', () => {
const event = buildInvocationErrorEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationError(
buildNodeExecutionState({
status: zNodeStatus.enum.IN_PROGRESS,
progress: 0.5,
progressImage: { dataURL: 'data:image/png;base64,abc', width: 64, height: 64 },
}),
event
);

expect(updated?.status).toBe(zNodeStatus.enum.FAILED);
expect(updated?.progress).toBeNull();
expect(updated?.progressImage).toBeNull();
expect(updated?.error).toEqual({
error_type: event.error_type,
error_message: event.error_message,
error_traceback: event.error_traceback,
});
});
});
Loading
Loading