From 8f81b1f8052af6e38f31d16657d863d76a8106fc Mon Sep 17 00:00:00 2001 From: Gohsuke Shimada Date: Fri, 20 Mar 2026 04:55:07 +0900 Subject: [PATCH 1/5] feat(ui): add canvas snapshot save/restore functionality Add ability to save and restore canvas state snapshots, allowing users to preserve their canvas layout at any point and restore it later. This is useful when the canvas freezes or resets unexpectedly. Backend: - Add get_keys_by_prefix and delete_by_key to client_state persistence - Add corresponding API endpoints Frontend: - Add canvasSnapshotRestored reducer to canvasSlice - Add useCanvasSnapshots hook for snapshot CRUD operations - Add CanvasToolbarSnapshotMenuButton with save/restore UI - Add i18n keys for snapshot feature - Regenerate API schema types Tests: - Add tests for new client_state endpoints (prefix search, key deletion) Co-Authored-By: Claude Opus 4.6 (1M context) --- invokeai/app/api/routers/client_state.py | 38 ++++ .../client_state_persistence_base.py | 25 +++ .../client_state_persistence_sqlite.py | 22 +++ invokeai/frontend/web/public/locales/en.json | 14 ++ .../components/Toolbar/CanvasToolbar.tsx | 2 + .../CanvasToolbarSnapshotMenuButton.tsx | 173 ++++++++++++++++++ .../controlLayers/hooks/useCanvasSnapshots.ts | 135 ++++++++++++++ .../controlLayers/store/canvasSlice.ts | 12 ++ .../frontend/web/src/services/api/schema.ts | 117 ++++++++++++ .../routers/test_client_state_multiuser.py | 145 +++++++++++++++ 10 files changed, 683 insertions(+) create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts diff --git a/invokeai/app/api/routers/client_state.py b/invokeai/app/api/routers/client_state.py index 2e34ea9fe6b..cd92263f97c 100644 --- a/invokeai/app/api/routers/client_state.py +++ b/invokeai/app/api/routers/client_state.py @@ -45,6 +45,44 @@ async def set_client_state( raise HTTPException(status_code=500, detail="Error setting client state") +@client_state_router.get( + "/{queue_id}/get_keys_by_prefix", + operation_id="get_client_state_keys_by_prefix", + response_model=list[str], +) +async def get_client_state_keys_by_prefix( + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), + prefix: str = Query(..., description="Prefix to filter keys by"), +) -> list[str]: + """Gets client state keys matching a prefix for the current user""" + try: + return ApiDependencies.invoker.services.client_state_persistence.get_keys_by_prefix( + current_user.user_id, prefix + ) + except Exception as e: + logging.error(f"Error getting client state keys: {e}") + raise HTTPException(status_code=500, detail="Error getting client state keys") + + +@client_state_router.post( + "/{queue_id}/delete_by_key", + operation_id="delete_client_state_by_key", + responses={204: {"description": "Client state key deleted"}}, +) +async def delete_client_state_by_key( + current_user: CurrentUserOrDefault, + queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"), + key: str = Query(..., description="Key to delete"), +) -> None: + """Deletes a specific client state key for the current user""" + try: + ApiDependencies.invoker.services.client_state_persistence.delete_by_key(current_user.user_id, key) + except Exception as e: + logging.error(f"Error deleting client state key: {e}") + raise HTTPException(status_code=500, detail="Error deleting client state key") + + @client_state_router.post( "/{queue_id}/delete", operation_id="delete_client_state", diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py index 99ad71bc8b7..7be6841a790 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_base.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_base.py @@ -36,6 +36,31 @@ def get_by_key(self, user_id: str, key: str) -> str | None: """ pass + @abstractmethod + def get_keys_by_prefix(self, user_id: str, prefix: str) -> list[str]: + """ + Get all keys matching a prefix for a user. + + Args: + user_id (str): The user ID to get keys for. + prefix (str): The prefix to filter keys by. + + Returns: + list[str]: A list of keys matching the prefix. + """ + pass + + @abstractmethod + def delete_by_key(self, user_id: str, key: str) -> None: + """ + Delete a specific key-value pair for a user. + + Args: + user_id (str): The user ID to delete state for. + key (str): The key to delete. + """ + pass + @abstractmethod def delete(self, user_id: str) -> None: """ diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 643db306857..8f5bf828572 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -44,6 +44,28 @@ def get_by_key(self, user_id: str, key: str) -> str | None: return None return row[0] + def get_keys_by_prefix(self, user_id: str, prefix: str) -> list[str]: + with self._db.transaction() as cursor: + cursor.execute( + """ + SELECT key FROM client_state + WHERE user_id = ? AND key LIKE ? + ORDER BY rowid DESC + """, + (user_id, f"{prefix}%"), + ) + return [row[0] for row in cursor.fetchall()] + + def delete_by_key(self, user_id: str, key: str) -> None: + with self._db.transaction() as cursor: + cursor.execute( + """ + DELETE FROM client_state + WHERE user_id = ? AND key = ? + """, + (user_id, key), + ) + def delete(self, user_id: str) -> None: with self._db.transaction() as cursor: cursor.execute( diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 58be5430a26..1e2e4dc4e09 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2898,6 +2898,20 @@ "off": "Off", "switchOnStart": "On Start", "switchOnFinish": "On Finish" + }, + "snapshot": { + "snapshots": "Save or Load Canvas Snapshot", + "saveSnapshot": "Save Snapshot", + "restoreSnapshot": "Restore Snapshot", + "snapshotNamePlaceholder": "Snapshot name", + "save": "Save", + "delete": "Delete", + "snapshotSaved": "Snapshot \"{{name}}\" saved", + "snapshotRestored": "Snapshot \"{{name}}\" restored", + "snapshotDeleted": "Snapshot \"{{name}}\" deleted", + "snapshotSaveFailed": "Failed to save snapshot", + "snapshotRestoreFailed": "Failed to restore snapshot", + "snapshotDeleteFailed": "Failed to delete snapshot" } }, "upscaling": { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index bf186ed6300..76533605965 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -14,6 +14,7 @@ import { CanvasToolbarRedoButton } from 'features/controlLayers/components/Toolb import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton'; import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton'; import { CanvasToolbarScale } from 'features/controlLayers/components/Toolbar/CanvasToolbarScale'; +import { CanvasToolbarSnapshotMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton'; import { CanvasToolbarUndoButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarUndoButton'; import { useCanvasDeleteLayerHotkey } from 'features/controlLayers/hooks/useCanvasDeleteLayerHotkey'; import { useCanvasEntityQuickSwitchHotkey } from 'features/controlLayers/hooks/useCanvasEntityQuickSwitchHotkey'; @@ -68,6 +69,7 @@ export const CanvasToolbar = memo(() => { + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx new file mode 100644 index 00000000000..a55d9ae913a --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx @@ -0,0 +1,173 @@ +import { + Flex, + IconButton, + Input, + Menu, + MenuButton, + MenuDivider, + MenuGroup, + MenuItem, + MenuList, + Text, +} from '@invoke-ai/ui-library'; +import type { SnapshotInfo } from 'features/controlLayers/hooks/useCanvasSnapshots'; +import { useCanvasSnapshots } from 'features/controlLayers/hooks/useCanvasSnapshots'; +import { toast } from 'features/toast/toast'; +import type { ChangeEvent, KeyboardEvent, MouseEvent } from 'react'; +import { memo, useCallback, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCameraBold, PiFloppyDiskBold, PiTrashBold } from 'react-icons/pi'; + +const SnapshotItem = memo( + ({ + snapshot, + onRestore, + onDelete, + }: { + snapshot: SnapshotInfo; + onRestore: (key: string, name: string) => void; + onDelete: (e: MouseEvent, key: string, name: string) => void; + }) => { + const handleClick = useCallback(() => { + onRestore(snapshot.key, snapshot.name); + }, [onRestore, snapshot.key, snapshot.name]); + + const handleDelete = useCallback( + (e: MouseEvent) => { + onDelete(e, snapshot.key, snapshot.name); + }, + [onDelete, snapshot.key, snapshot.name] + ); + + return ( + + + + {snapshot.name} + + } + size="xs" + variant="ghost" + colorScheme="error" + onClick={handleDelete} + /> + + + ); + } +); + +SnapshotItem.displayName = 'SnapshotItem'; + +const getDefaultSnapshotName = (): string => { + const now = new Date(); + const y = now.getFullYear(); + const mo = String(now.getMonth() + 1).padStart(2, '0'); + const d = String(now.getDate()).padStart(2, '0'); + const h = String(now.getHours()).padStart(2, '0'); + const mi = String(now.getMinutes()).padStart(2, '0'); + return `${y}/${mo}/${d} ${h}:${mi}`; +}; + +export const CanvasToolbarSnapshotMenuButton = memo(() => { + const { t } = useTranslation(); + const { snapshots, saveSnapshot, restoreSnapshot, deleteSnapshot } = useCanvasSnapshots(); + const [snapshotName, setSnapshotName] = useState(''); + + const onNameChange = useCallback((e: ChangeEvent) => { + setSnapshotName(e.target.value); + }, []); + + const onSave = useCallback(async () => { + const name = snapshotName.trim() || getDefaultSnapshotName(); + const success = await saveSnapshot(name); + if (success) { + toast({ title: t('controlLayers.snapshot.snapshotSaved', { name }), status: 'info' }); + setSnapshotName(''); + } else { + toast({ title: t('controlLayers.snapshot.snapshotSaveFailed'), status: 'error' }); + } + }, [snapshotName, saveSnapshot, t]); + + const onKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === 'Enter') { + e.preventDefault(); + e.stopPropagation(); + onSave(); + } + }, + [onSave] + ); + + const onRestore = useCallback( + async (key: string, name: string) => { + const success = await restoreSnapshot(key); + if (success) { + toast({ title: t('controlLayers.snapshot.snapshotRestored', { name }), status: 'info' }); + } else { + toast({ title: t('controlLayers.snapshot.snapshotRestoreFailed'), status: 'error' }); + } + }, + [restoreSnapshot, t] + ); + + const onDelete = useCallback( + async (e: MouseEvent, key: string, name: string) => { + e.stopPropagation(); + const success = await deleteSnapshot(key); + if (success) { + toast({ title: t('controlLayers.snapshot.snapshotDeleted', { name }), status: 'info' }); + } else { + toast({ title: t('controlLayers.snapshot.snapshotDeleteFailed'), status: 'error' }); + } + }, + [deleteSnapshot, t] + ); + + return ( + + } + variant="link" + alignSelf="stretch" + /> + + + + + } + size="sm" + onClick={onSave} + /> + + + {snapshots.length > 0 && ( + <> + + + {snapshots.map((snapshot) => ( + + ))} + + + )} + + + ); +}); + +CanvasToolbarSnapshotMenuButton.displayName = 'CanvasToolbarSnapshotMenuButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts new file mode 100644 index 00000000000..f078ef672a0 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts @@ -0,0 +1,135 @@ +import { logger } from 'app/logging/logger'; +import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; +import { canvasSnapshotRestored } from 'features/controlLayers/store/canvasSlice'; +import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; +import { zCanvasState } from 'features/controlLayers/store/types'; +import { useCallback, useEffect, useState } from 'react'; +import { serializeError } from 'serialize-error'; +import { buildV1Url, getBaseUrl } from 'services/api'; +import type { JsonObject } from 'type-fest'; + +const log = logger('canvas'); + +const SNAPSHOT_PREFIX = 'canvas_snapshot:'; + +const getAuthHeaders = (): Record => { + const headers: Record = {}; + if (typeof window !== 'undefined' && window.localStorage) { + const token = localStorage.getItem('auth_token'); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } + } + return headers; +}; + +const getUrl = (endpoint: string, query?: Record) => { + const baseUrl = getBaseUrl(); + const path = buildV1Url(`client_state/default/${endpoint}`, query); + return `${baseUrl}/${path}`; +}; + +export type SnapshotInfo = { + key: string; + name: string; +}; + +export const useCanvasSnapshots = () => { + const dispatch = useAppDispatch(); + const store = useAppStore(); + const [snapshots, setSnapshots] = useState([]); + + const fetchSnapshots = useCallback(async () => { + try { + const url = getUrl('get_keys_by_prefix', { prefix: SNAPSHOT_PREFIX }); + const res = await fetch(url, { method: 'GET', headers: getAuthHeaders() }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } + const keys: string[] = await res.json(); + setSnapshots( + keys.map((key) => ({ + key, + name: key.slice(SNAPSHOT_PREFIX.length), + })) + ); + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to fetch snapshots'); + } + }, []); + + const saveSnapshot = useCallback( + async (name: string) => { + try { + const state = selectCanvasSlice(store.getState()); + const value = JSON.stringify(state); + const key = `${SNAPSHOT_PREFIX}${name}`; + const url = getUrl('set_by_key', { key }); + const res = await fetch(url, { + method: 'POST', + body: value, + headers: getAuthHeaders(), + }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } + await fetchSnapshots(); + return true; + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to save snapshot'); + return false; + } + }, + [store, fetchSnapshots] + ); + + const restoreSnapshot = useCallback( + async (key: string) => { + try { + const url = getUrl('get_by_key', { key }); + const res = await fetch(url, { method: 'GET', headers: getAuthHeaders() }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } + const raw = await res.json(); + const parsed = JSON.parse(raw); + const canvasState = zCanvasState.parse(parsed); + dispatch(canvasSnapshotRestored(canvasState)); + return true; + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to restore snapshot'); + return false; + } + }, + [dispatch] + ); + + const deleteSnapshot = useCallback( + async (key: string) => { + try { + const url = getUrl('delete_by_key', { key }); + const res = await fetch(url, { method: 'POST', headers: getAuthHeaders() }); + if (!res.ok) { + throw new Error(`Response status: ${res.status}`); + } + await fetchSnapshots(); + return true; + } catch (e) { + log.error({ error: serializeError(e) } as JsonObject, 'Failed to delete snapshot'); + return false; + } + }, + [fetchSnapshots] + ); + + useEffect(() => { + fetchSnapshots(); + }, [fetchSnapshots]); + + return { + snapshots, + saveSnapshot, + restoreSnapshot, + deleteSnapshot, + }; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 79d3963d122..385cd45b198 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1710,6 +1710,17 @@ const slice = createSlice({ state.regionalGuidance.entities = regionalGuidance; return state; }, + canvasSnapshotRestored: (state, action: PayloadAction) => { + const snapshot = action.payload; + state.controlLayers = snapshot.controlLayers; + state.inpaintMasks = snapshot.inpaintMasks; + state.rasterLayers = snapshot.rasterLayers; + state.regionalGuidance = snapshot.regionalGuidance; + state.bbox = snapshot.bbox; + state.selectedEntityIdentifier = snapshot.selectedEntityIdentifier; + state.bookmarkedEntityIdentifier = snapshot.bookmarkedEntityIdentifier; + return state; + }, canvasUndo: () => {}, canvasRedo: () => {}, canvasClearHistory: () => {}, @@ -1768,6 +1779,7 @@ const resetState = (state: CanvasState) => { export const { canvasMetadataRecalled, + canvasSnapshotRestored, canvasUndo, canvasRedo, canvasClearHistory, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index fc6506ce22b..cd5687e50ae 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2386,6 +2386,46 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/client_state/{queue_id}/get_keys_by_prefix": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Client State Keys By Prefix + * @description Gets client state keys matching a prefix for the current user + */ + get: operations["get_client_state_keys_by_prefix"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/client_state/{queue_id}/delete_by_key": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Delete Client State By Key + * @description Deletes a specific client state key for the current user + */ + post: operations["delete_client_state_by_key"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/client_state/{queue_id}/delete": { parameters: { query?: never; @@ -33511,6 +33551,83 @@ export interface operations { }; }; }; + get_client_state_keys_by_prefix: { + parameters: { + query: { + /** @description Prefix to filter keys by */ + prefix: string; + }; + header?: never; + path: { + /** @description The queue id (ignored, kept for backwards compatibility) */ + queue_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": string[]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + delete_client_state_by_key: { + parameters: { + query: { + /** @description Key to delete */ + key: string; + }; + header?: never; + path: { + /** @description The queue id (ignored, kept for backwards compatibility) */ + queue_id: string; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": unknown; + }; + }; + /** @description Client state key deleted */ + 204: { + headers: { + [name: string]: unknown; + }; + content?: never; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; delete_client_state: { parameters: { query?: never; diff --git a/tests/app/routers/test_client_state_multiuser.py b/tests/app/routers/test_client_state_multiuser.py index 814c9182fec..4ca1de3bf49 100644 --- a/tests/app/routers/test_client_state_multiuser.py +++ b/tests/app/routers/test_client_state_multiuser.py @@ -297,3 +297,148 @@ def test_complex_json_values(client: TestClient, admin_token: str): ) assert get_response.status_code == status.HTTP_200_OK assert get_response.json() == complex_value + + +def test_get_keys_by_prefix_without_auth(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that keys can be retrieved by prefix without authentication.""" + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set several keys with a common prefix directly + for i in range(3): + mock_invoker.services.client_state_persistence.set_by_key("system", f"canvas_snapshot:snap{i}", f"value{i}") + mock_invoker.services.client_state_persistence.set_by_key("system", "other_key", "other_value") + + # Get keys by prefix + response = client.get("/api/v1/client_state/default/get_keys_by_prefix?prefix=canvas_snapshot:") + assert response.status_code == status.HTTP_200_OK + keys = response.json() + assert len(keys) == 3 + assert "canvas_snapshot:snap0" in keys + assert "canvas_snapshot:snap1" in keys + assert "canvas_snapshot:snap2" in keys + assert "other_key" not in keys + + +def test_get_keys_by_prefix_empty_without_auth(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that an empty list is returned when no keys match the prefix.""" + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + response = client.get("/api/v1/client_state/default/get_keys_by_prefix?prefix=nonexistent_prefix:") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [] + + +def test_delete_by_key_without_auth(client: TestClient, monkeypatch, mock_invoker: Invoker): + """Test that a specific key can be deleted without affecting other keys.""" + monkeypatch.setattr("invokeai.app.api.auth_dependencies.ApiDependencies", MockApiDependencies(mock_invoker)) + monkeypatch.setattr("invokeai.app.api.routers.client_state.ApiDependencies", MockApiDependencies(mock_invoker)) + + # Set two keys directly + mock_invoker.services.client_state_persistence.set_by_key("system", "keep_key", "keep_value") + mock_invoker.services.client_state_persistence.set_by_key("system", "delete_key", "delete_value") + + # Delete only one key via endpoint + delete_response = client.post("/api/v1/client_state/default/delete_by_key?key=delete_key") + assert delete_response.status_code == status.HTTP_200_OK + + # Verify deleted key is gone + value = mock_invoker.services.client_state_persistence.get_by_key("system", "delete_key") + assert value is None + + # Verify other key still exists + value = mock_invoker.services.client_state_persistence.get_by_key("system", "keep_key") + assert value == "keep_value" + + +def test_get_keys_by_prefix(client: TestClient, admin_token: str): + """Test that keys can be retrieved by prefix with authentication.""" + # Set several keys with a common prefix + for i in range(3): + client.post( + f"/api/v1/client_state/default/set_by_key?key=canvas_snapshot:snap{i}", + json=f"value{i}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + # Set a key without the prefix + client.post( + "/api/v1/client_state/default/set_by_key?key=other_key", + json="other_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Get keys by prefix + response = client.get( + "/api/v1/client_state/default/get_keys_by_prefix?prefix=canvas_snapshot:", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == status.HTTP_200_OK + keys = response.json() + assert len(keys) == 3 + assert "canvas_snapshot:snap0" in keys + assert "canvas_snapshot:snap1" in keys + assert "canvas_snapshot:snap2" in keys + assert "other_key" not in keys + + +def test_delete_by_key(client: TestClient, admin_token: str): + """Test that a specific key can be deleted without affecting other keys.""" + # Set two keys + client.post( + "/api/v1/client_state/default/set_by_key?key=keep_key", + json="keep_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + client.post( + "/api/v1/client_state/default/set_by_key?key=delete_key", + json="delete_value", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + # Delete only one key + delete_response = client.post( + "/api/v1/client_state/default/delete_by_key?key=delete_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert delete_response.status_code == status.HTTP_200_OK + + # Verify deleted key is gone + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=delete_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() is None + + # Verify other key still exists + get_response = client.get( + "/api/v1/client_state/default/get_by_key?key=keep_key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert get_response.json() == "keep_value" + + +def test_get_keys_by_prefix_isolation_between_users(client: TestClient, user1_token: str, user2_token: str): + """Test that get_keys_by_prefix is isolated between users.""" + # User 1 sets keys + client.post( + "/api/v1/client_state/default/set_by_key?key=snapshot:u1", + json="user1_data", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + + # User 2 sets keys + client.post( + "/api/v1/client_state/default/set_by_key?key=snapshot:u2", + json="user2_data", + headers={"Authorization": f"Bearer {user2_token}"}, + ) + + # User 1 should only see their own keys + response = client.get( + "/api/v1/client_state/default/get_keys_by_prefix?prefix=snapshot:", + headers={"Authorization": f"Bearer {user1_token}"}, + ) + keys = response.json() + assert "snapshot:u1" in keys + assert "snapshot:u2" not in keys From 368fe749c717aca979b30c31697ed3c4f27c541f Mon Sep 17 00:00:00 2001 From: Gohsuke Shimada Date: Fri, 10 Apr 2026 17:09:16 +0900 Subject: [PATCH 2/5] fix(ui): address review feedback for canvas snapshot feature - Preserve current modelBase on snapshot restore to prevent bbox desync with the active model (mirrors resetState pattern) - Exclude snapshot restore from undo history so it cannot be accidentally undone - Migrate manual fetch calls to RTKQ endpoints (clientState.ts) so snapshots go through the shared API transport layer with proper auth, session-expiry handling and sliding-window token refresh - Validate referenced images on restore and warn when some are missing - Detect incompatible (schema-changed) snapshots and show a specific error message instead of a generic failure toast - Disable snapshot restore while the canvas is staging to prevent entity ID conflicts with in-progress generations - Sort snapshot list by updated_at instead of rowid so re-saved snapshots appear at the top - Add pre-flight backend reachability check before image validation to avoid false "missing images" warnings when offline Co-Authored-By: Claude Opus 4.6 (1M context) --- .../client_state_persistence_sqlite.py | 2 +- invokeai/frontend/web/public/locales/en.json | 5 +- .../CanvasToolbarSnapshotMenuButton.tsx | 35 +++- .../controlLayers/hooks/useCanvasSnapshots.ts | 187 ++++++++++++------ .../controlLayers/store/canvasSlice.ts | 9 + .../src/services/api/endpoints/clientState.ts | 48 +++++ 6 files changed, 216 insertions(+), 70 deletions(-) create mode 100644 invokeai/frontend/web/src/services/api/endpoints/clientState.ts diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 8f5bf828572..7a0c0f9f4c9 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -50,7 +50,7 @@ def get_keys_by_prefix(self, user_id: str, prefix: str) -> list[str]: """ SELECT key FROM client_state WHERE user_id = ? AND key LIKE ? - ORDER BY rowid DESC + ORDER BY updated_at DESC """, (user_id, f"{prefix}%"), ) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index d0d037cc88b..525fcd4f4c9 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -3010,7 +3010,10 @@ "snapshotDeleted": "Snapshot \"{{name}}\" deleted", "snapshotSaveFailed": "Failed to save snapshot", "snapshotRestoreFailed": "Failed to restore snapshot", - "snapshotDeleteFailed": "Failed to delete snapshot" + "snapshotDeleteFailed": "Failed to delete snapshot", + "snapshotMissingImages_one": "{{count}} image referenced by this snapshot no longer exists and will appear as a placeholder", + "snapshotMissingImages_other": "{{count}} images referenced by this snapshot no longer exist and will appear as placeholders", + "snapshotIncompatible": "This snapshot was created with a different version and is no longer compatible" } }, "upscaling": { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx index a55d9ae913a..9d21041c6e1 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx @@ -12,6 +12,7 @@ import { } from '@invoke-ai/ui-library'; import type { SnapshotInfo } from 'features/controlLayers/hooks/useCanvasSnapshots'; import { useCanvasSnapshots } from 'features/controlLayers/hooks/useCanvasSnapshots'; +import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { toast } from 'features/toast/toast'; import type { ChangeEvent, KeyboardEvent, MouseEvent } from 'react'; import { memo, useCallback, useState } from 'react'; @@ -23,10 +24,12 @@ const SnapshotItem = memo( snapshot, onRestore, onDelete, + isRestoreDisabled, }: { snapshot: SnapshotInfo; onRestore: (key: string, name: string) => void; onDelete: (e: MouseEvent, key: string, name: string) => void; + isRestoreDisabled: boolean; }) => { const handleClick = useCallback(() => { onRestore(snapshot.key, snapshot.name); @@ -40,7 +43,7 @@ const SnapshotItem = memo( ); return ( - + {snapshot.name} @@ -52,6 +55,7 @@ const SnapshotItem = memo( variant="ghost" colorScheme="error" onClick={handleDelete} + isDisabled={isRestoreDisabled} /> @@ -74,6 +78,7 @@ const getDefaultSnapshotName = (): string => { export const CanvasToolbarSnapshotMenuButton = memo(() => { const { t } = useTranslation(); const { snapshots, saveSnapshot, restoreSnapshot, deleteSnapshot } = useCanvasSnapshots(); + const isStaging = useCanvasIsStaging(); const [snapshotName, setSnapshotName] = useState(''); const onNameChange = useCallback((e: ChangeEvent) => { @@ -104,9 +109,23 @@ export const CanvasToolbarSnapshotMenuButton = memo(() => { const onRestore = useCallback( async (key: string, name: string) => { - const success = await restoreSnapshot(key); - if (success) { - toast({ title: t('controlLayers.snapshot.snapshotRestored', { name }), status: 'info' }); + const result = await restoreSnapshot(key); + if (result.success) { + if (result.missingImageCount && result.missingImageCount > 0) { + toast({ + title: t('controlLayers.snapshot.snapshotRestored', { name }), + description: t('controlLayers.snapshot.snapshotMissingImages', { count: result.missingImageCount }), + status: 'warning', + }); + } else { + toast({ title: t('controlLayers.snapshot.snapshotRestored', { name }), status: 'info' }); + } + } else if (result.error === 'incompatible') { + toast({ + title: t('controlLayers.snapshot.snapshotRestoreFailed'), + description: t('controlLayers.snapshot.snapshotIncompatible'), + status: 'error', + }); } else { toast({ title: t('controlLayers.snapshot.snapshotRestoreFailed'), status: 'error' }); } @@ -160,7 +179,13 @@ export const CanvasToolbarSnapshotMenuButton = memo(() => { {snapshots.map((snapshot) => ( - + ))} diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts index f078ef672a0..2c6cf0af9bd 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts @@ -2,31 +2,92 @@ import { logger } from 'app/logging/logger'; import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; import { canvasSnapshotRestored } from 'features/controlLayers/store/canvasSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; +import type { CanvasState } from 'features/controlLayers/store/types'; import { zCanvasState } from 'features/controlLayers/store/types'; -import { useCallback, useEffect, useState } from 'react'; +import { useCallback, useMemo } from 'react'; import { serializeError } from 'serialize-error'; -import { buildV1Url, getBaseUrl } from 'services/api'; +import { appInfoApi } from 'services/api/endpoints/appInfo'; +import { + clientStateApi, + useDeleteClientStateByKeyMutation, + useGetClientStateKeysByPrefixQuery, + useSetClientStateByKeyMutation, +} from 'services/api/endpoints/clientState'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; import type { JsonObject } from 'type-fest'; +import { z } from 'zod'; const log = logger('canvas'); const SNAPSHOT_PREFIX = 'canvas_snapshot:'; -const getAuthHeaders = (): Record => { - const headers: Record = {}; - if (typeof window !== 'undefined' && window.localStorage) { - const token = localStorage.getItem('auth_token'); - if (token) { - headers['Authorization'] = `Bearer ${token}`; +/** + * Collect all unique image_name references from a canvas state. + */ +const collectImageNames = (state: CanvasState): string[] => { + const names = new Set(); + + const entityGroups = [state.rasterLayers, state.controlLayers, state.inpaintMasks, state.regionalGuidance]; + for (const group of entityGroups) { + for (const entity of group.entities) { + for (const obj of entity.objects) { + if (obj.type === 'image' && 'image_name' in obj.image) { + names.add(obj.image.image_name); + } + } + } + } + + // Regional guidance reference images (IP Adapter / FLUX Redux) + for (const entity of state.regionalGuidance.entities) { + for (const ref of entity.referenceImages) { + if (ref.config.image && 'image_name' in ref.config.image) { + names.add(ref.config.image.image_name); + } } } - return headers; + + return [...names]; +}; + +/** + * Quick health check to determine if the backend is reachable. + * Uses the existing appInfoApi RTKQ endpoint for consistency. + */ +const isBackendReachable = async (dispatch: ReturnType): Promise => { + const req = dispatch(appInfoApi.endpoints.getAppVersion.initiate(undefined, { subscribe: false })); + try { + await req.unwrap(); + return true; + } catch { + return false; + } finally { + req.unsubscribe(); + } }; -const getUrl = (endpoint: string, query?: Record) => { - const baseUrl = getBaseUrl(); - const path = buildV1Url(`client_state/default/${endpoint}`, query); - return `${baseUrl}/${path}`; +/** + * Check which image_names still exist on the server. + * Returns the list of missing image names. If the backend is unreachable, + * skips all checks and returns an empty array to avoid false warnings. + */ +const findMissingImages = async ( + imageNames: string[], + dispatch: ReturnType +): Promise => { + // Pre-flight: verify backend is reachable before checking individual images + if (!(await isBackendReachable(dispatch))) { + log.warn('Backend unreachable — skipping missing image check'); + return []; + } + + const results = await Promise.all( + imageNames.map(async (name) => { + const dto = await getImageDTOSafe(name); + return dto === null ? name : null; + }) + ); + return results.filter((name): name is string => name !== null); }; export type SnapshotInfo = { @@ -34,29 +95,28 @@ export type SnapshotInfo = { name: string; }; +type RestoreResult = { + success: boolean; + missingImageCount?: number; + error?: 'incompatible' | 'not_found' | 'unknown'; +}; + export const useCanvasSnapshots = () => { const dispatch = useAppDispatch(); const store = useAppStore(); - const [snapshots, setSnapshots] = useState([]); - - const fetchSnapshots = useCallback(async () => { - try { - const url = getUrl('get_keys_by_prefix', { prefix: SNAPSHOT_PREFIX }); - const res = await fetch(url, { method: 'GET', headers: getAuthHeaders() }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); - } - const keys: string[] = await res.json(); - setSnapshots( - keys.map((key) => ({ - key, - name: key.slice(SNAPSHOT_PREFIX.length), - })) - ); - } catch (e) { - log.error({ error: serializeError(e) } as JsonObject, 'Failed to fetch snapshots'); - } - }, []); + + const { data: keys } = useGetClientStateKeysByPrefixQuery(SNAPSHOT_PREFIX); + const [setClientState] = useSetClientStateByKeyMutation(); + const [deleteClientState] = useDeleteClientStateByKeyMutation(); + + const snapshots: SnapshotInfo[] = useMemo( + () => + (keys ?? []).map((key) => ({ + key, + name: key.slice(SNAPSHOT_PREFIX.length), + })), + [keys] + ); const saveSnapshot = useCallback( async (name: string) => { @@ -64,41 +124,51 @@ export const useCanvasSnapshots = () => { const state = selectCanvasSlice(store.getState()); const value = JSON.stringify(state); const key = `${SNAPSHOT_PREFIX}${name}`; - const url = getUrl('set_by_key', { key }); - const res = await fetch(url, { - method: 'POST', - body: value, - headers: getAuthHeaders(), - }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); - } - await fetchSnapshots(); + await setClientState({ key, value }).unwrap(); return true; } catch (e) { log.error({ error: serializeError(e) } as JsonObject, 'Failed to save snapshot'); return false; } }, - [store, fetchSnapshots] + [store, setClientState] ); const restoreSnapshot = useCallback( - async (key: string) => { + async (key: string): Promise => { + const req = dispatch(clientStateApi.endpoints.getClientStateByKey.initiate(key, { subscribe: false })); try { - const url = getUrl('get_by_key', { key }); - const res = await fetch(url, { method: 'GET', headers: getAuthHeaders() }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); + const raw = await req.unwrap(); + if (raw === null) { + throw new Error('Snapshot data not found'); } - const raw = await res.json(); const parsed = JSON.parse(raw); const canvasState = zCanvasState.parse(parsed); + + // Check for missing images before restoring + const imageNames = collectImageNames(canvasState); + const missingImages = imageNames.length > 0 ? await findMissingImages(imageNames, dispatch) : []; + + if (missingImages.length > 0) { + log.warn( + { missingCount: missingImages.length, total: imageNames.length } as unknown as JsonObject, + 'Snapshot references images that no longer exist' + ); + } + dispatch(canvasSnapshotRestored(canvasState)); - return true; + return { success: true, missingImageCount: missingImages.length }; } catch (e) { log.error({ error: serializeError(e) } as JsonObject, 'Failed to restore snapshot'); - return false; + // Distinguish Zod validation errors (incompatible snapshot) from other failures + const isZodError = e instanceof z.ZodError; + const isNotFound = e instanceof Error && e.message === 'Snapshot data not found'; + return { + success: false, + error: isZodError ? 'incompatible' : isNotFound ? 'not_found' : 'unknown', + }; + } finally { + req.unsubscribe(); } }, [dispatch] @@ -107,25 +177,16 @@ export const useCanvasSnapshots = () => { const deleteSnapshot = useCallback( async (key: string) => { try { - const url = getUrl('delete_by_key', { key }); - const res = await fetch(url, { method: 'POST', headers: getAuthHeaders() }); - if (!res.ok) { - throw new Error(`Response status: ${res.status}`); - } - await fetchSnapshots(); + await deleteClientState(key).unwrap(); return true; } catch (e) { log.error({ error: serializeError(e) } as JsonObject, 'Failed to delete snapshot'); return false; } }, - [fetchSnapshots] + [deleteClientState] ); - useEffect(() => { - fetchSnapshots(); - }, [fetchSnapshots]); - return { snapshots, saveSnapshot, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 385cd45b198..bfdea7b1de5 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1716,7 +1716,12 @@ const slice = createSlice({ state.inpaintMasks = snapshot.inpaintMasks; state.rasterLayers = snapshot.rasterLayers; state.regionalGuidance = snapshot.regionalGuidance; + // Restore bbox from snapshot but preserve the current modelBase to avoid desync + // with the currently selected model (same pattern as resetState). + const currentModelBase = state.bbox.modelBase; state.bbox = snapshot.bbox; + state.bbox.modelBase = currentModelBase; + syncScaledSize(state); state.selectedEntityIdentifier = snapshot.selectedEntityIdentifier; state.bookmarkedEntityIdentifier = snapshot.bookmarkedEntityIdentifier; return state; @@ -1905,6 +1910,10 @@ const canvasUndoableConfig: UndoableOptions = { if (!action.type.startsWith(slice.name)) { return false; } + // Snapshot restore replaces the canvas state and should not be undoable + if (action.type === canvasSnapshotRestored.type) { + return false; + } // Throttle rapid actions of the same type filter = actionsThrottlingFilter(action); return filter; diff --git a/invokeai/frontend/web/src/services/api/endpoints/clientState.ts b/invokeai/frontend/web/src/services/api/endpoints/clientState.ts new file mode 100644 index 00000000000..5d3cc96d226 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/endpoints/clientState.ts @@ -0,0 +1,48 @@ +import { api, buildV1Url, LIST_TAG } from '..'; + +/** + * Builds an endpoint URL for the client_state router. + * The queue_id path parameter is kept as 'default' for backwards compatibility. + */ +const buildClientStateUrl = (path: string, query?: Record) => + buildV1Url(`client_state/default/${path}`, query); + +export const clientStateApi = api.injectEndpoints({ + endpoints: (build) => ({ + getClientStateKeysByPrefix: build.query({ + query: (prefix) => ({ + url: buildClientStateUrl('get_keys_by_prefix', { prefix }), + method: 'GET', + }), + providesTags: [{ type: 'ClientState', id: LIST_TAG }, 'FetchOnReconnect'], + }), + getClientStateByKey: build.query({ + query: (key) => ({ + url: buildClientStateUrl('get_by_key', { key }), + method: 'GET', + }), + }), + setClientStateByKey: build.mutation({ + query: ({ key, value }) => ({ + url: buildClientStateUrl('set_by_key', { key }), + method: 'POST', + // Send raw string body — the backend expects Body(...) as a plain string, + // not JSON-encoded. Setting Content-Type to text/plain prevents fetchBaseQuery + // from JSON-stringifying the body. + headers: { 'Content-Type': 'text/plain' }, + body: value, + }), + invalidatesTags: [{ type: 'ClientState', id: LIST_TAG }], + }), + deleteClientStateByKey: build.mutation({ + query: (key) => ({ + url: buildClientStateUrl('delete_by_key', { key }), + method: 'POST', + }), + invalidatesTags: [{ type: 'ClientState', id: LIST_TAG }], + }), + }), +}); + +export const { useGetClientStateKeysByPrefixQuery, useSetClientStateByKeyMutation, useDeleteClientStateByKeyMutation } = + clientStateApi; From ad6cb17dc5ef6de35834ece4a39956cf002f76da Mon Sep 17 00:00:00 2001 From: Gohsuke Shimada Date: Wed, 15 Apr 2026 15:26:02 +0900 Subject: [PATCH 3/5] refactor(ui): consolidate collectImageNames to shared canvasProjectFile utility Remove the local collectImageNames from useCanvasSnapshots and reuse the shared, more comprehensive version from canvasProjectFile.ts that was introduced by the canvas project save/load feature (#8917). Snapshots don't include global ref images, so an empty array is passed for that parameter. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../controlLayers/hooks/useCanvasSnapshots.ts | 51 +++++++------------ 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts index 2c6cf0af9bd..6352e23f4f0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasSnapshots.ts @@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger'; import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; import { canvasSnapshotRestored } from 'features/controlLayers/store/canvasSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; -import type { CanvasState } from 'features/controlLayers/store/types'; import { zCanvasState } from 'features/controlLayers/store/types'; +import { collectImageNames } from 'features/controlLayers/util/canvasProjectFile'; import { useCallback, useMemo } from 'react'; import { serializeError } from 'serialize-error'; import { appInfoApi } from 'services/api/endpoints/appInfo'; @@ -21,35 +21,6 @@ const log = logger('canvas'); const SNAPSHOT_PREFIX = 'canvas_snapshot:'; -/** - * Collect all unique image_name references from a canvas state. - */ -const collectImageNames = (state: CanvasState): string[] => { - const names = new Set(); - - const entityGroups = [state.rasterLayers, state.controlLayers, state.inpaintMasks, state.regionalGuidance]; - for (const group of entityGroups) { - for (const entity of group.entities) { - for (const obj of entity.objects) { - if (obj.type === 'image' && 'image_name' in obj.image) { - names.add(obj.image.image_name); - } - } - } - } - - // Regional guidance reference images (IP Adapter / FLUX Redux) - for (const entity of state.regionalGuidance.entities) { - for (const ref of entity.referenceImages) { - if (ref.config.image && 'image_name' in ref.config.image) { - names.add(ref.config.image.image_name); - } - } - } - - return [...names]; -}; - /** * Quick health check to determine if the backend is reachable. * Uses the existing appInfoApi RTKQ endpoint for consistency. @@ -146,12 +117,26 @@ export const useCanvasSnapshots = () => { const canvasState = zCanvasState.parse(parsed); // Check for missing images before restoring - const imageNames = collectImageNames(canvasState); - const missingImages = imageNames.length > 0 ? await findMissingImages(imageNames, dispatch) : []; + // Reuse the shared collectImageNames from canvasProjectFile — snapshots only + // contain canvas entities (no global ref images), so we pass an empty array. + const imageNames = collectImageNames( + { + rasterLayers: canvasState.rasterLayers.entities, + controlLayers: canvasState.controlLayers.entities, + inpaintMasks: canvasState.inpaintMasks.entities, + regionalGuidance: canvasState.regionalGuidance.entities, + bbox: canvasState.bbox, + selectedEntityIdentifier: canvasState.selectedEntityIdentifier, + bookmarkedEntityIdentifier: canvasState.bookmarkedEntityIdentifier, + }, + [] + ); + const imageNamesList = [...imageNames]; + const missingImages = imageNamesList.length > 0 ? await findMissingImages(imageNamesList, dispatch) : []; if (missingImages.length > 0) { log.warn( - { missingCount: missingImages.length, total: imageNames.length } as unknown as JsonObject, + { missingCount: missingImages.length, total: imageNamesList.length } as unknown as JsonObject, 'Snapshot references images that no longer exist' ); } From 5e70ad85b287fecc3f3ea4c694e66d38f4da9695 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Thu, 16 Apr 2026 06:25:51 +0200 Subject: [PATCH 4/5] fix(canvas-snapshots): escape LIKE wildcards, warn on overwrite, fix default name chars - Escape %, _, \ in client_state prefix query to prevent accidental wildcard matching - Confirm before overwriting an existing snapshot instead of silently replacing it - Use - instead of / and : in the default snapshot name to avoid key separator clashes --- .../client_state_persistence_sqlite.py | 7 +- invokeai/frontend/web/public/locales/en.json | 5 +- .../CanvasToolbarSnapshotMenuButton.tsx | 146 +++++++++++------- 3 files changed, 102 insertions(+), 56 deletions(-) diff --git a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py index 7a0c0f9f4c9..7605de829d9 100644 --- a/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py +++ b/invokeai/app/services/client_state_persistence/client_state_persistence_sqlite.py @@ -45,14 +45,17 @@ def get_by_key(self, user_id: str, key: str) -> str | None: return row[0] def get_keys_by_prefix(self, user_id: str, prefix: str) -> list[str]: + # Escape LIKE wildcards (%, _) and the escape char itself so callers can pass + # arbitrary strings as a literal prefix without accidental pattern matching. + escaped_prefix = prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") with self._db.transaction() as cursor: cursor.execute( """ SELECT key FROM client_state - WHERE user_id = ? AND key LIKE ? + WHERE user_id = ? AND key LIKE ? ESCAPE '\\' ORDER BY updated_at DESC """, - (user_id, f"{prefix}%"), + (user_id, f"{escaped_prefix}%"), ) return [row[0] for row in cursor.fetchall()] diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 4188367cffc..46d97820bf7 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -3067,7 +3067,10 @@ "snapshotDeleteFailed": "Failed to delete snapshot", "snapshotMissingImages_one": "{{count}} image referenced by this snapshot no longer exists and will appear as a placeholder", "snapshotMissingImages_other": "{{count}} images referenced by this snapshot no longer exist and will appear as placeholders", - "snapshotIncompatible": "This snapshot was created with a different version and is no longer compatible" + "snapshotIncompatible": "This snapshot was created with a different version and is no longer compatible", + "overwriteSnapshotTitle": "Overwrite snapshot?", + "overwriteSnapshotMessage": "A snapshot named \"{{name}}\" already exists. Overwrite it?", + "overwrite": "Overwrite" } }, "upscaling": { diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx index 9d21041c6e1..6fe3b23a82f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarSnapshotMenuButton.tsx @@ -1,4 +1,5 @@ import { + ConfirmationAlertDialog, Flex, IconButton, Input, @@ -9,6 +10,7 @@ import { MenuItem, MenuList, Text, + useDisclosure, } from '@invoke-ai/ui-library'; import type { SnapshotInfo } from 'features/controlLayers/hooks/useCanvasSnapshots'; import { useCanvasSnapshots } from 'features/controlLayers/hooks/useCanvasSnapshots'; @@ -72,7 +74,7 @@ const getDefaultSnapshotName = (): string => { const d = String(now.getDate()).padStart(2, '0'); const h = String(now.getHours()).padStart(2, '0'); const mi = String(now.getMinutes()).padStart(2, '0'); - return `${y}/${mo}/${d} ${h}:${mi}`; + return `${y}-${mo}-${d} ${h}-${mi}`; }; export const CanvasToolbarSnapshotMenuButton = memo(() => { @@ -80,21 +82,47 @@ export const CanvasToolbarSnapshotMenuButton = memo(() => { const { snapshots, saveSnapshot, restoreSnapshot, deleteSnapshot } = useCanvasSnapshots(); const isStaging = useCanvasIsStaging(); const [snapshotName, setSnapshotName] = useState(''); + const overwriteDialog = useDisclosure(); + const [pendingOverwriteName, setPendingOverwriteName] = useState(null); const onNameChange = useCallback((e: ChangeEvent) => { setSnapshotName(e.target.value); }, []); + const doSave = useCallback( + async (name: string) => { + const success = await saveSnapshot(name); + if (success) { + toast({ title: t('controlLayers.snapshot.snapshotSaved', { name }), status: 'info' }); + setSnapshotName(''); + } else { + toast({ title: t('controlLayers.snapshot.snapshotSaveFailed'), status: 'error' }); + } + }, + [saveSnapshot, t] + ); + const onSave = useCallback(async () => { const name = snapshotName.trim() || getDefaultSnapshotName(); - const success = await saveSnapshot(name); - if (success) { - toast({ title: t('controlLayers.snapshot.snapshotSaved', { name }), status: 'info' }); - setSnapshotName(''); - } else { - toast({ title: t('controlLayers.snapshot.snapshotSaveFailed'), status: 'error' }); + if (snapshots.some((s) => s.name === name)) { + setPendingOverwriteName(name); + overwriteDialog.onOpen(); + return; + } + await doSave(name); + }, [snapshotName, snapshots, doSave, overwriteDialog]); + + const onConfirmOverwrite = useCallback(() => { + if (pendingOverwriteName) { + doSave(pendingOverwriteName); + setPendingOverwriteName(null); } - }, [snapshotName, saveSnapshot, t]); + }, [pendingOverwriteName, doSave]); + + const onCloseOverwriteDialog = useCallback(() => { + setPendingOverwriteName(null); + overwriteDialog.onClose(); + }, [overwriteDialog]); const onKeyDown = useCallback( (e: KeyboardEvent) => { @@ -147,51 +175,63 @@ export const CanvasToolbarSnapshotMenuButton = memo(() => { ); return ( - - } - variant="link" - alignSelf="stretch" - /> - - - - - } - size="sm" - onClick={onSave} - /> - - - {snapshots.length > 0 && ( - <> - - - {snapshots.map((snapshot) => ( - - ))} - - - )} - - + <> + + } + variant="link" + alignSelf="stretch" + /> + + + + + } + size="sm" + onClick={onSave} + /> + + + {snapshots.length > 0 && ( + <> + + + {snapshots.map((snapshot) => ( + + ))} + + + )} + + + + {t('controlLayers.snapshot.overwriteSnapshotMessage', { name: pendingOverwriteName ?? '' })} + + ); }); From 0a71e64b700454e7cc4f0ccfe144bd10bfc8580a Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Thu, 16 Apr 2026 06:32:43 +0200 Subject: [PATCH 5/5] fix(canvas): align canvasProjectRecalled with snapshot restore pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Preserve modelBase, call syncScaledSize, and exclude from undo history to avoid bbox/model desync on project load — same pattern already used by canvasSnapshotRestored. --- .../web/src/features/controlLayers/store/canvasSlice.ts | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 610a3b159ef..1ddec3f2033 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1761,7 +1761,12 @@ const slice = createSlice({ state.controlLayers.entities = controlLayers; state.inpaintMasks.entities = inpaintMasks; state.regionalGuidance.entities = regionalGuidance; + // Preserve the current modelBase to avoid desync with the currently selected model + // (same pattern as canvasSnapshotRestored and resetState). + const currentModelBase = state.bbox.modelBase; state.bbox = bbox; + state.bbox.modelBase = currentModelBase; + syncScaledSize(state); state.selectedEntityIdentifier = selectedEntityIdentifier; state.bookmarkedEntityIdentifier = bookmarkedEntityIdentifier; return state; @@ -1968,8 +1973,8 @@ const canvasUndoableConfig: UndoableOptions = { if (!action.type.startsWith(slice.name)) { return false; } - // Snapshot restore replaces the canvas state and should not be undoable - if (action.type === canvasSnapshotRestored.type) { + // Snapshot restore and project load replace the canvas state and should not be undoable + if (action.type === canvasSnapshotRestored.type || action.type === canvasProjectRecalled.type) { return false; } // Throttle rapid actions of the same type