diff --git a/src/platform/assets/services/assetService.test.ts b/src/platform/assets/services/assetService.test.ts index 2a8c525b077..e718cb3d727 100644 --- a/src/platform/assets/services/assetService.test.ts +++ b/src/platform/assets/services/assetService.test.ts @@ -1,6 +1,13 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { assetService } from '@/platform/assets/services/assetService' +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { + MISSING_TAG, + assetService, + isBlake3AssetHash, + toBlake3AssetHash +} from '@/platform/assets/services/assetService' +import { api } from '@/scripts/api' const mockDistributionState = vi.hoisted(() => ({ isCloud: false })) const mockSettingStoreGet = vi.hoisted(() => vi.fn(() => false)) @@ -40,6 +47,32 @@ vi.mock('@/i18n', () => ({ st: vi.fn((_key: string, fallback: string) => fallback) })) +const fetchApiMock = vi.mocked(api.fetchApi) + +const validBlake3Hash = + '1111111111111111111111111111111111111111111111111111111111111111' +const validBlake3AssetHash = `blake3:${validBlake3Hash}` + +function buildResponse( + body: unknown, + init: { ok?: boolean; status?: number } = {} +): Response { + return { + ok: init.ok ?? true, + status: init.status ?? 200, + json: vi.fn().mockResolvedValue(body) + } as unknown as Response +} + +function validAsset(overrides: Partial = {}): AssetItem { + return { + id: 'asset-1', + name: 'model.safetensors', + tags: ['models'], + ...overrides + } +} + describe(assetService.shouldUseAssetBrowser, () => { beforeEach(() => { vi.clearAllMocks() @@ -104,3 +137,763 @@ describe(assetService.shouldUseAssetBrowser, () => { ).toBe(false) }) }) + +describe(assetService.getAssetMetadata, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('throws a localized message when the response is not ok', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse({ code: 'FILE_TOO_LARGE' }, { ok: false, status: 413 }) + ) + + await expect( + assetService.getAssetMetadata('https://example.com/model.safetensors') + ).rejects.toThrow('File too large') + }) + + it('throws a localized message when validation reports is_valid=false', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse({ + content_length: 100, + final_url: 'https://example.com/model.safetensors', + validation: { + is_valid: false, + errors: [{ code: 'UNSAFE_VIRUS_SCAN', message: 'bad', field: 'file' }] + } + }) + ) + + await expect( + assetService.getAssetMetadata('https://example.com/model.safetensors') + ).rejects.toThrow('Unsafe virus scan') + }) + + it('encodes the URL in the query string', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse({ + content_length: 1, + final_url: 'https://example.com/x' + }) + ) + + await assetService.getAssetMetadata('https://example.com/foo bar?x=1') + + expect(fetchApiMock).toHaveBeenCalledWith( + expect.stringContaining( + '/assets/remote-metadata?url=' + + encodeURIComponent('https://example.com/foo bar?x=1') + ) + ) + }) +}) + +describe(isBlake3AssetHash, () => { + it('accepts only prefixed 64-character blake3 hashes', () => { + expect(isBlake3AssetHash(validBlake3AssetHash)).toBe(true) + expect(isBlake3AssetHash('BLAKE3:' + validBlake3Hash.toUpperCase())).toBe( + true + ) + expect(isBlake3AssetHash('blake3:abc')).toBe(false) + expect(isBlake3AssetHash(validBlake3Hash)).toBe(false) + }) +}) + +describe(toBlake3AssetHash, () => { + it('normalizes 64-character blake3 hex values to asset hashes', () => { + expect(toBlake3AssetHash(validBlake3Hash)).toBe(validBlake3AssetHash) + expect(toBlake3AssetHash('abc')).toBeNull() + expect(toBlake3AssetHash(undefined)).toBeNull() + }) +}) + +describe(assetService.uploadAssetFromUrl, () => { + beforeEach(() => { + vi.clearAllMocks() + assetService.invalidateInputAssetsIncludingPublic() + }) + + it('does not invalidate cached input assets when the upload response is invalid', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce(buildResponse({ id: 'missing-name' })) + + await assetService.getInputAssetsIncludingPublic() + await expect( + assetService.uploadAssetFromUrl({ + url: 'https://example.com/input.png', + name: 'input.png', + tags: ['input'] + }) + ).rejects.toThrow('Failed to upload asset') + const cached = await assetService.getInputAssetsIncludingPublic() + + expect(cached).toEqual(staleAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + consoleSpy.mockRestore() + }) + + it('requires upload responses to include created_new', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce( + buildResponse(validAsset({ id: 'uploaded-input', tags: ['input'] })) + ) + + await assetService.getInputAssetsIncludingPublic() + await expect( + assetService.uploadAssetFromUrl({ + url: 'https://example.com/input.png', + name: 'input.png', + tags: ['input'] + }) + ).rejects.toThrow('Failed to upload asset') + const cached = await assetService.getInputAssetsIncludingPublic() + + expect(cached).toEqual(staleAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + consoleSpy.mockRestore() + }) + + it('returns validated upload responses with created_new', async () => { + const uploadedAsset = { + ...validAsset({ id: 'uploaded-input', tags: ['input'] }), + created_new: true + } + fetchApiMock.mockResolvedValueOnce(buildResponse(uploadedAsset)) + + await expect( + assetService.uploadAssetFromUrl({ + url: 'https://example.com/input.png', + name: 'input.png', + tags: ['input'] + }) + ).resolves.toEqual(uploadedAsset) + }) +}) + +describe(assetService.uploadAssetFromBase64, () => { + beforeEach(() => { + vi.clearAllMocks() + assetService.invalidateInputAssetsIncludingPublic() + }) + + it('throws before calling the network when data is not a data URL', async () => { + await expect( + assetService.uploadAssetFromBase64({ + data: 'not-a-data-url', + name: 'image.png' + }) + ).rejects.toThrow('Invalid data URL') + + expect(fetchApiMock).not.toHaveBeenCalled() + }) + + it('does not invalidate cached input assets when the upload response is invalid', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + const fetchSpy = vi + .spyOn(globalThis, 'fetch') + .mockResolvedValueOnce(new Response('hello')) + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce(buildResponse({ id: 'missing-name' })) + + await assetService.getInputAssetsIncludingPublic() + await expect( + assetService.uploadAssetFromBase64({ + data: 'data:text/plain;base64,aGVsbG8=', + name: 'input.txt', + tags: ['input'] + }) + ).rejects.toThrow('Failed to upload asset') + const cached = await assetService.getInputAssetsIncludingPublic() + + expect(cached).toEqual(staleAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + fetchSpy.mockRestore() + consoleSpy.mockRestore() + }) + + it('rejects upload responses with a non-boolean created_new', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + const fetchSpy = vi + .spyOn(globalThis, 'fetch') + .mockResolvedValueOnce(new Response('hello')) + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce( + buildResponse({ + ...validAsset({ id: 'uploaded-input', tags: ['input'] }), + created_new: 'true' + }) + ) + + await assetService.getInputAssetsIncludingPublic() + await expect( + assetService.uploadAssetFromBase64({ + data: 'data:text/plain;base64,aGVsbG8=', + name: 'input.txt', + tags: ['input'] + }) + ).rejects.toThrow('Failed to upload asset') + const cached = await assetService.getInputAssetsIncludingPublic() + + expect(cached).toEqual(staleAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + fetchSpy.mockRestore() + consoleSpy.mockRestore() + }) +}) + +describe(assetService.uploadAssetAsync, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('returns an async result when the server responds 202', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse( + { task_id: 'task-1', status: 'running' }, + { ok: true, status: 202 } + ) + ) + + const result = await assetService.uploadAssetAsync({ + source_url: 'https://example.com/model.safetensors' + }) + + expect(result).toEqual({ + type: 'async', + task: { task_id: 'task-1', status: 'running' } + }) + }) + + it('returns a sync result when the server responds 200', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse(validAsset({ id: 'asset-2', name: 'sync.safetensors' })) + ) + + const result = await assetService.uploadAssetAsync({ + source_url: 'https://example.com/model.safetensors' + }) + + expect(result).toEqual({ + type: 'sync', + asset: expect.objectContaining({ id: 'asset-2' }) + }) + }) +}) + +describe(assetService.deleteAsset, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('throws an error containing the status code when the response is not ok', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse(null, { ok: false, status: 503 }) + ) + + await expect(assetService.deleteAsset('asset-1')).rejects.toThrow(/503/) + }) + + it('issues a DELETE to the asset endpoint when the response is ok', async () => { + fetchApiMock.mockResolvedValueOnce(buildResponse(null)) + + await assetService.deleteAsset('asset-1') + + expect(fetchApiMock).toHaveBeenCalledWith( + '/assets/asset-1', + expect.objectContaining({ method: 'DELETE' }) + ) + }) +}) + +describe(assetService.getAssetModelFolders, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('filters out missing-tagged assets and blacklisted directories, returning alphabetical unique folders without include_public', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse({ + assets: [ + validAsset({ id: 'a', tags: ['models', 'loras'] }), + validAsset({ id: 'b', tags: ['models', 'checkpoints'] }), + validAsset({ id: 'c', tags: ['models', 'configs'] }), + validAsset({ id: 'd', tags: ['models', 'missing', 'controlnet'] }), + validAsset({ id: 'e', tags: ['models', 'loras'] }) + ] + }) + ) + + const folders = await assetService.getAssetModelFolders() + + expect(folders).toEqual([ + { name: 'checkpoints', folders: [] }, + { name: 'loras', folders: [] } + ]) + + const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string + const params = new URL(requestedUrl, 'http://localhost').searchParams + expect(params.has('include_public')).toBe(false) + }) +}) + +describe(assetService.updateAsset, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('throws when the response body fails schema validation', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse({ name: 'no-id-field.safetensors' }) + ) + + await expect( + assetService.updateAsset('asset-1', { name: 'renamed.safetensors' }) + ).rejects.toThrow(/Invalid response/) + }) + + it('PUTs the JSON payload and returns the parsed asset', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse(validAsset({ id: 'asset-1', name: 'renamed.safetensors' })) + ) + + const result = await assetService.updateAsset('asset-1', { + name: 'renamed.safetensors' + }) + + expect(result).toEqual( + expect.objectContaining({ id: 'asset-1', name: 'renamed.safetensors' }) + ) + expect(fetchApiMock).toHaveBeenCalledWith( + '/assets/asset-1', + expect.objectContaining({ + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ name: 'renamed.safetensors' }) + }) + ) + }) +}) + +describe(assetService.getAssetsByTag, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('forwards include_public=true by default and excludes missing-tagged assets', async () => { + fetchApiMock.mockResolvedValueOnce( + buildResponse({ + assets: [ + validAsset({ id: 'visible', tags: ['input'] }), + validAsset({ id: 'hidden', tags: ['input', 'missing'] }) + ] + }) + ) + + const assets = await assetService.getAssetsByTag('input') + + expect(assets.map((a) => a.id)).toEqual(['visible']) + + const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string + const params = new URL(requestedUrl, 'http://localhost').searchParams + expect(params.get('include_public')).toBe('true') + }) +}) + +describe(assetService.getAllAssetsByTag, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('paginates tagged asset requests with include_public=true', async () => { + fetchApiMock + .mockResolvedValueOnce( + buildResponse({ + assets: [ + validAsset({ id: 'a', tags: ['input'] }), + validAsset({ id: 'b', tags: ['input'] }) + ] + }) + ) + .mockResolvedValueOnce( + buildResponse({ + assets: [validAsset({ id: 'c', tags: ['input'] })] + }) + ) + + const assets = await assetService.getAllAssetsByTag('input', true, { + limit: 2 + }) + + expect(assets.map((a) => a.id)).toEqual(['a', 'b', 'c']) + + const firstUrl = fetchApiMock.mock.calls[0]?.[0] as string + const firstParams = new URL(firstUrl, 'http://localhost').searchParams + expect(firstParams.get('include_public')).toBe('true') + expect(firstParams.get('limit')).toBe('2') + expect(firstParams.has('offset')).toBe(false) + + const secondUrl = fetchApiMock.mock.calls[1]?.[0] as string + const secondParams = new URL(secondUrl, 'http://localhost').searchParams + expect(secondParams.get('include_public')).toBe('true') + expect(secondParams.get('limit')).toBe('2') + expect(secondParams.get('offset')).toBe('2') + }) + + it('paginates from raw response size before filtering missing-tagged assets', async () => { + fetchApiMock + .mockResolvedValueOnce( + buildResponse({ + assets: [ + validAsset({ id: 'visible', tags: ['input'] }), + validAsset({ id: 'hidden', tags: ['input', MISSING_TAG] }) + ] + }) + ) + .mockResolvedValueOnce( + buildResponse({ + assets: [validAsset({ id: 'later-public', tags: ['input'] })] + }) + ) + + const assets = await assetService.getAllAssetsByTag('input', true, { + limit: 2 + }) + + expect(assets.map((a) => a.id)).toEqual(['visible', 'later-public']) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + + const secondUrl = fetchApiMock.mock.calls[1]?.[0] + if (typeof secondUrl !== 'string') { + throw new Error('Expected a second asset request URL') + } + const secondParams = new URL(secondUrl, 'http://localhost').searchParams + expect(secondParams.get('offset')).toBe('2') + }) + + it('honors has_more when walking tagged asset pages', async () => { + fetchApiMock + .mockResolvedValueOnce( + buildResponse({ + assets: [ + validAsset({ id: 'first', tags: ['input'] }), + validAsset({ id: 'second', tags: ['input'] }) + ], + has_more: true + }) + ) + .mockResolvedValueOnce( + buildResponse({ + assets: [validAsset({ id: 'later-public', tags: ['input'] })], + has_more: false + }) + ) + + const assets = await assetService.getAllAssetsByTag('input', true, { + limit: 3 + }) + + expect(assets.map((a) => a.id)).toEqual(['first', 'second', 'later-public']) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + + const secondUrl = fetchApiMock.mock.calls[1]?.[0] + if (typeof secondUrl !== 'string') { + throw new Error('Expected a second asset request URL') + } + const secondParams = new URL(secondUrl, 'http://localhost').searchParams + expect(secondParams.get('offset')).toBe('2') + }) + + it('passes abort signals through paginated requests', async () => { + const controller = new AbortController() + fetchApiMock.mockResolvedValueOnce( + buildResponse({ + assets: [validAsset({ id: 'a', tags: ['input'] })] + }) + ) + + await assetService.getAllAssetsByTag('input', true, { + limit: 2, + signal: controller.signal + }) + + expect(fetchApiMock).toHaveBeenCalledWith(expect.any(String), { + signal: controller.signal + }) + }) + + it('stops pagination when aborted between pages', async () => { + const controller = new AbortController() + fetchApiMock.mockImplementationOnce(async () => { + controller.abort() + return buildResponse({ + assets: [ + validAsset({ id: 'a', tags: ['input'] }), + validAsset({ id: 'b', tags: ['input'] }) + ] + }) + }) + + await expect( + assetService.getAllAssetsByTag('input', true, { + limit: 2, + signal: controller.signal + }) + ).rejects.toMatchObject({ name: 'AbortError' }) + + expect(fetchApiMock).toHaveBeenCalledOnce() + }) +}) + +describe(assetService.getInputAssetsIncludingPublic, () => { + beforeEach(() => { + vi.clearAllMocks() + assetService.invalidateInputAssetsIncludingPublic() + }) + + it('loads input assets with public assets included and reuses the cache', async () => { + const assets = [ + validAsset({ id: 'user-input', tags: ['input'] }), + validAsset({ id: 'public-input', tags: ['input'], is_immutable: true }) + ] + fetchApiMock.mockResolvedValueOnce(buildResponse({ assets })) + + const first = await assetService.getInputAssetsIncludingPublic() + const second = await assetService.getInputAssetsIncludingPublic() + + expect(first).toEqual(assets) + expect(second).toBe(first) + expect(fetchApiMock).toHaveBeenCalledOnce() + + const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string + const params = new URL(requestedUrl, 'http://localhost').searchParams + expect(params.get('include_public')).toBe('true') + expect(params.get('limit')).toBe('500') + }) + + it('fetches fresh input assets after explicit invalidation', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })] + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce(buildResponse({ assets: freshAssets })) + + await assetService.getInputAssetsIncludingPublic() + assetService.invalidateInputAssetsIncludingPublic() + const refreshed = await assetService.getInputAssetsIncludingPublic() + + expect(refreshed).toEqual(freshAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + }) + + it('does not let one caller abort the shared input asset load for other callers', async () => { + const firstController = new AbortController() + const secondController = new AbortController() + const assets = [validAsset({ id: 'public-input', tags: ['input'] })] + let resolveResponse!: (response: Response) => void + let serviceSignal: AbortSignal | undefined + fetchApiMock.mockImplementationOnce(async (_url, options) => { + serviceSignal = options?.signal ?? undefined + return await new Promise((resolve) => { + resolveResponse = resolve + }) + }) + + const first = assetService.getInputAssetsIncludingPublic( + firstController.signal + ) + const second = assetService.getInputAssetsIncludingPublic( + secondController.signal + ) + firstController.abort() + + await expect(first).rejects.toMatchObject({ name: 'AbortError' }) + expect(serviceSignal).toBeUndefined() + + resolveResponse(buildResponse({ assets })) + + await expect(second).resolves.toEqual(assets) + expect(fetchApiMock).toHaveBeenCalledOnce() + }) + + it('keeps the shared input asset load alive after all callers abort', async () => { + const firstController = new AbortController() + const secondController = new AbortController() + const assets = [validAsset({ id: 'public-input', tags: ['input'] })] + let resolveResponse!: (response: Response) => void + fetchApiMock.mockImplementationOnce( + async () => + await new Promise((resolve) => { + resolveResponse = resolve + }) + ) + + const first = assetService.getInputAssetsIncludingPublic( + firstController.signal + ) + const second = assetService.getInputAssetsIncludingPublic( + secondController.signal + ) + firstController.abort() + secondController.abort() + + await expect(first).rejects.toMatchObject({ name: 'AbortError' }) + await expect(second).rejects.toMatchObject({ name: 'AbortError' }) + + resolveResponse(buildResponse({ assets })) + await Promise.resolve() + + await expect(assetService.getInputAssetsIncludingPublic()).resolves.toEqual( + assets + ) + expect(fetchApiMock).toHaveBeenCalledOnce() + }) + + it('does not abort in-flight input asset loads when invalidated', async () => { + const assets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })] + let resolveResponse!: (response: Response) => void + fetchApiMock + .mockImplementationOnce( + async () => + await new Promise((resolve) => { + resolveResponse = resolve + }) + ) + .mockResolvedValueOnce(buildResponse({ assets: freshAssets })) + + const inFlight = assetService.getInputAssetsIncludingPublic() + assetService.invalidateInputAssetsIncludingPublic() + + resolveResponse(buildResponse({ assets })) + + await expect(inFlight).resolves.toEqual(assets) + await expect(assetService.getInputAssetsIncludingPublic()).resolves.toEqual( + freshAssets + ) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + }) + + it('invalidates cached input assets after deleting an asset', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })] + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce(buildResponse(null)) + .mockResolvedValueOnce(buildResponse({ assets: freshAssets })) + + await assetService.getInputAssetsIncludingPublic() + await assetService.deleteAsset('stale-input') + const refreshed = await assetService.getInputAssetsIncludingPublic() + + expect(refreshed).toEqual(freshAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(3) + expect(fetchApiMock.mock.calls[1]).toEqual([ + '/assets/stale-input', + expect.objectContaining({ method: 'DELETE' }) + ]) + }) + + it('invalidates cached input assets after an input asset upload', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + const uploadedAsset = validAsset({ id: 'uploaded-input', tags: ['input'] }) + const freshAssets = [uploadedAsset] + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce(buildResponse(uploadedAsset)) + .mockResolvedValueOnce(buildResponse({ assets: freshAssets })) + + await assetService.getInputAssetsIncludingPublic() + await assetService.uploadAssetAsync({ + source_url: 'https://example.com/input.png', + tags: ['input'] + }) + const refreshed = await assetService.getInputAssetsIncludingPublic() + + expect(refreshed).toEqual(freshAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(3) + }) + + it('does not invalidate cached input assets for pending async input uploads', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce( + buildResponse( + { task_id: 'task-1', status: 'running' }, + { ok: true, status: 202 } + ) + ) + + await assetService.getInputAssetsIncludingPublic() + await assetService.uploadAssetAsync({ + source_url: 'https://example.com/input.png', + tags: ['input'] + }) + const cached = await assetService.getInputAssetsIncludingPublic() + + expect(cached).toEqual(staleAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + }) + + it('does not invalidate cached input assets for non-input uploads', async () => { + const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })] + fetchApiMock + .mockResolvedValueOnce(buildResponse({ assets: staleAssets })) + .mockResolvedValueOnce(buildResponse(validAsset({ tags: ['models'] }))) + + await assetService.getInputAssetsIncludingPublic() + await assetService.uploadAssetAsync({ + source_url: 'https://example.com/model.safetensors', + tags: ['models'] + }) + const cached = await assetService.getInputAssetsIncludingPublic() + + expect(cached).toEqual(staleAssets) + expect(fetchApiMock).toHaveBeenCalledTimes(2) + }) +}) + +describe(assetService.checkAssetHash, () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it.each([ + [200, 'exists'], + [404, 'missing'], + [400, 'invalid'] + ] as const)('maps %s responses to %s', async (status, expected) => { + const hash = + 'blake3:1111111111111111111111111111111111111111111111111111111111111111' + fetchApiMock.mockResolvedValueOnce(buildResponse(null, { status })) + + await expect(assetService.checkAssetHash(hash)).resolves.toBe(expected) + + expect(fetchApiMock).toHaveBeenCalledWith( + `/assets/hash/${encodeURIComponent(hash)}`, + { + method: 'HEAD', + signal: undefined + } + ) + }) + + it('throws for unexpected responses', async () => { + fetchApiMock.mockResolvedValueOnce(buildResponse(null, { status: 500 })) + + await expect(assetService.checkAssetHash('blake3:abc')).rejects.toThrow( + 'Unexpected asset hash check status: 500' + ) + }) +}) diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts index d85a44aec3f..7a2fbb256f0 100644 --- a/src/platform/assets/services/assetService.ts +++ b/src/platform/assets/services/assetService.ts @@ -1,4 +1,5 @@ import { fromZodError } from 'zod-validation-error' +import { z } from 'zod' import { st } from '@/i18n' @@ -28,9 +29,14 @@ export interface PaginationOptions { offset?: number } +interface AssetPaginationOptions extends PaginationOptions { + signal?: AbortSignal +} + interface AssetRequestOptions extends PaginationOptions { includeTags: string[] includePublic?: boolean + signal?: AbortSignal } interface AssetExportOptions { @@ -169,10 +175,61 @@ const ASSETS_DOWNLOAD_ENDPOINT = '/assets/download' const ASSETS_EXPORT_ENDPOINT = '/assets/export' const EXPERIMENTAL_WARNING = `EXPERIMENTAL: If you are seeing this please make sure "Comfy.Assets.UseAssetAPI" is set to "false" in your ComfyUI Settings.\n` const DEFAULT_LIMIT = 500 +const INPUT_ASSETS_WITH_PUBLIC_LIMIT = 500 export const MODELS_TAG = 'models' +/** Asset tag used by the backend for placeholder records that are not installed. */ export const MISSING_TAG = 'missing' +/** Result of a HEAD lookup against an exact asset hash. */ +export type AssetHashStatus = 'exists' | 'missing' | 'invalid' + +const BLAKE3_ASSET_HASH_PATTERN = /^blake3:[0-9a-f]{64}$/i +const BLAKE3_HEX_PATTERN = /^[0-9a-f]{64}$/i +const uploadedAssetResponseSchema = assetItemSchema.extend({ + created_new: z.boolean() +}) + +/** Returns true for a prefixed BLAKE3 asset hash: `blake3:<64 hex>`. */ +export function isBlake3AssetHash(value: string): boolean { + return BLAKE3_ASSET_HASH_PATTERN.test(value) +} + +/** Converts a raw 64-character BLAKE3 hex digest into an asset hash. */ +export function toBlake3AssetHash(hash: string | undefined): string | null { + if (!hash || !BLAKE3_HEX_PATTERN.test(hash)) return null + return `blake3:${hash}` +} + +function createAbortError(): DOMException { + return new DOMException('Aborted', 'AbortError') +} + +function throwIfAborted(signal?: AbortSignal): void { + if (signal?.aborted) throw createAbortError() +} + +async function withCallerAbort( + promise: Promise, + signal?: AbortSignal +): Promise { + throwIfAborted(signal) + if (!signal) return await promise + + let removeAbortListener = () => {} + const abortPromise = new Promise((_, reject) => { + const onAbort = () => reject(createAbortError()) + signal.addEventListener('abort', onAbort, { once: true }) + removeAbortListener = () => signal.removeEventListener('abort', onAbort) + }) + + try { + return await Promise.race([promise, abortPromise]) + } finally { + removeAbortListener() + } +} + /** * Validates asset response data using Zod schema */ @@ -186,11 +243,43 @@ function validateAssetResponse(data: unknown): AssetResponse { ) } +function validateUploadedAssetResponse( + data: unknown +): AssetItem & { created_new: boolean } { + const result = uploadedAssetResponseSchema.safeParse(data) + if (result.success) { + return result.data + } + + console.error('Invalid asset upload response:', fromZodError(result.error)) + throw new Error( + st( + 'assetBrowser.errorUploadFailed', + 'Failed to upload asset. Please try again.' + ) + ) +} + /** * Private service for asset-related network requests * Not exposed globally - used internally by ComfyApi */ function createAssetService() { + let inputAssetsIncludingPublic: AssetItem[] | null = null + let inputAssetsIncludingPublicRequestId = 0 + let pendingInputAssetsIncludingPublic: Promise | null = null + + /** Invalidates the cached public-inclusive input assets without aborting in-flight readers. */ + function invalidateInputAssetsIncludingPublic(): void { + inputAssetsIncludingPublicRequestId++ + pendingInputAssetsIncludingPublic = null + inputAssetsIncludingPublic = null + } + + function invalidateInputAssetsCacheIfNeeded(tags?: string[]): void { + if (tags?.includes('input')) invalidateInputAssetsIncludingPublic() + } + /** * Handles API response with consistent error handling and Zod validation */ @@ -202,7 +291,8 @@ function createAssetService() { includeTags, limit = DEFAULT_LIMIT, offset, - includePublic + includePublic, + signal } = options const queryParams = new URLSearchParams({ include_tags: includeTags.join(','), @@ -216,7 +306,9 @@ function createAssetService() { } const url = `${ASSETS_ENDPOINT}?${queryParams.toString()}` - const res = await api.fetchApi(url) + const res = signal + ? await api.fetchApi(url, { signal }) + : await api.fetchApi(url) if (!res.ok) { throw new Error( `${EXPERIMENTAL_WARNING}Unable to load ${context}: Server returned ${res.status}. Please try again.` @@ -402,15 +494,16 @@ function createAssetService() { * @param options - Pagination options * @param options.limit - Maximum number of assets to return (default: 500) * @param options.offset - Number of assets to skip (default: 0) + * @param options.signal - Optional abort signal for cancelling the request * @returns Promise - Full asset objects filtered by tag, excluding missing assets */ async function getAssetsByTag( tag: string, includePublic: boolean = true, - { limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {} + { limit = DEFAULT_LIMIT, offset = 0, signal }: AssetPaginationOptions = {} ): Promise { const data = await handleAssetRequest( - { includeTags: [tag], limit, offset, includePublic }, + { includeTags: [tag], limit, offset, includePublic, signal }, `assets for tag ${tag}` ) @@ -419,6 +512,116 @@ function createAssetService() { ) } + /** + * Gets every asset for a tag by walking paginated asset API responses. + * + * @param tag - The tag to filter by (e.g., 'models', 'input') + * @param includePublic - Whether to include public assets (default: true) + * @param options - Pagination options + * @param options.limit - Page size for each request (default: 500) + * @param options.signal - Optional abort signal for cancelling requests + * @returns Promise - Full asset objects filtered by tag + */ + async function getAllAssetsByTag( + tag: string, + includePublic: boolean = true, + { limit = DEFAULT_LIMIT, signal }: AssetPaginationOptions = {} + ): Promise { + const assets: AssetItem[] = [] + const pageSize = limit > 0 ? limit : DEFAULT_LIMIT + let offset = 0 + + while (true) { + if (signal?.aborted) throw createAbortError() + + const data = await handleAssetRequest( + { + includeTags: [tag], + limit: pageSize, + offset, + includePublic, + signal + }, + `assets for tag ${tag}` + ) + const batch = data.assets ?? [] + assets.push(...batch.filter((asset) => !asset.tags.includes(MISSING_TAG))) + + const noMoreFromServer = data.has_more === false + const inferredLastPage = + data.has_more === undefined && batch.length < pageSize + if (batch.length === 0 || noMoreFromServer || inferredLastPage) { + return assets + } + + offset += batch.length + } + } + + function startInputAssetsIncludingPublicRequest(): Promise { + const requestId = ++inputAssetsIncludingPublicRequestId + + pendingInputAssetsIncludingPublic = getAllAssetsByTag('input', true, { + limit: INPUT_ASSETS_WITH_PUBLIC_LIMIT + }) + .then((assets) => { + if (requestId === inputAssetsIncludingPublicRequestId) { + inputAssetsIncludingPublic = assets + } + return assets + }) + .finally(() => { + if (requestId === inputAssetsIncludingPublicRequestId) { + pendingInputAssetsIncludingPublic = null + } + }) + + void pendingInputAssetsIncludingPublic.catch(() => {}) + return pendingInputAssetsIncludingPublic + } + + /** + * Gets cached input assets including public assets for missing media checks. + * Caller aborts cancel only that caller; shared fetches are invalidated + * through invalidateInputAssetsIncludingPublic(). + */ + async function getInputAssetsIncludingPublic( + signal?: AbortSignal + ): Promise { + throwIfAborted(signal) + if (inputAssetsIncludingPublic) return inputAssetsIncludingPublic + + const request = + pendingInputAssetsIncludingPublic ?? + startInputAssetsIncludingPublicRequest() + return await withCallerAbort(request, signal) + } + + /** + * Checks whether an asset exists for an exact asset hash. + * + * Uses the HEAD /assets/hash/{hash} endpoint and maps status-only responses: + * 200 -> exists, 404 -> missing, and 400 -> invalid hash format. + */ + async function checkAssetHash( + assetHash: string, + signal?: AbortSignal + ): Promise { + const response = await api.fetchApi( + `${ASSETS_ENDPOINT}/hash/${encodeURIComponent(assetHash)}`, + { + method: 'HEAD', + signal + } + ) + + if (response.status === 200) return 'exists' + if (response.status === 404) return 'missing' + if (response.status === 400) return 'invalid' + + throw new Error(`Unexpected asset hash check status: ${response.status}`) + } + /** * Deletes an asset by ID * Only available in cloud environment @@ -437,6 +640,8 @@ function createAssetService() { `Unable to delete asset ${id}: Server returned ${res.status}` ) } + + invalidateInputAssetsIncludingPublic() } /** @@ -544,7 +749,9 @@ function createAssetService() { ) } - return await res.json() + const asset = validateUploadedAssetResponse(await res.json()) + invalidateInputAssetsCacheIfNeeded(params.tags) + return asset } /** @@ -597,7 +804,9 @@ function createAssetService() { ) } - return await res.json() + const asset = validateUploadedAssetResponse(await res.json()) + invalidateInputAssetsCacheIfNeeded(params.tags) + return asset } /** @@ -627,6 +836,7 @@ function createAssetService() { if (!parseResult.success) { throw fromZodError(parseResult.error) } + invalidateInputAssetsIncludingPublic() return parseResult.data } @@ -657,6 +867,7 @@ function createAssetService() { if (!parseResult.success) { throw fromZodError(parseResult.error) } + invalidateInputAssetsIncludingPublic() return parseResult.data } @@ -708,6 +919,13 @@ function createAssetService() { ) ) } + if ( + params.tags?.includes('input') && + result.data.type === 'async' && + result.data.task.status === 'completed' + ) { + invalidateInputAssetsIncludingPublic() + } return result.data } @@ -723,6 +941,7 @@ function createAssetService() { ) ) } + invalidateInputAssetsCacheIfNeeded(params.tags) return result.data } @@ -763,6 +982,10 @@ function createAssetService() { getAssetsForNodeType, getAssetDetails, getAssetsByTag, + getAllAssetsByTag, + getInputAssetsIncludingPublic, + invalidateInputAssetsIncludingPublic, + checkAssetHash, deleteAsset, updateAsset, addAssetTags, diff --git a/src/platform/missingMedia/missingMediaScan.test.ts b/src/platform/missingMedia/missingMediaScan.test.ts index 867bb4d3f7a..8e77aae88cd 100644 --- a/src/platform/missingMedia/missingMediaScan.test.ts +++ b/src/platform/missingMedia/missingMediaScan.test.ts @@ -1,9 +1,11 @@ import { fromAny } from '@total-typescript/shoehorn' -import { describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import type { LGraph } from '@/lib/litegraph/src/LGraph' import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode' import type { IComboWidget } from '@/lib/litegraph/src/types/widgets' +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import type * as AssetServiceModule from '@/platform/assets/services/assetService' import { scanAllMediaCandidates, scanNodeMediaCandidates, @@ -13,6 +15,13 @@ import { } from './missingMediaScan' import type { MissingMediaCandidate } from './types' +const { mockCheckAssetHash, mockGetInputAssetsIncludingPublic } = vi.hoisted( + () => ({ + mockCheckAssetHash: vi.fn(), + mockGetInputAssetsIncludingPublic: vi.fn() + }) +) + vi.mock('@/utils/graphTraversalUtil', () => ({ collectAllNodes: (graph: { _testNodes: LGraphNode[] }) => graph._testNodes, getExecutionIdByNode: ( @@ -21,6 +30,21 @@ vi.mock('@/utils/graphTraversalUtil', () => ({ ) => node._testExecutionId ?? String(node.id) })) +vi.mock('@/platform/assets/services/assetService', async () => { + const actual = await vi.importActual( + '@/platform/assets/services/assetService' + ) + + return { + ...actual, + assetService: { + ...actual.assetService, + checkAssetHash: mockCheckAssetHash, + getInputAssetsIncludingPublic: mockGetInputAssetsIncludingPublic + } + } +}) + function makeCandidate( nodeId: string, name: string, @@ -70,6 +94,16 @@ function makeGraph(nodes: LGraphNode[]): LGraph { return fromAny({ _testNodes: nodes }) } +function makeAsset(name: string, assetHash: string | null = null): AssetItem { + return { + id: name, + name, + asset_hash: assetHash, + mime_type: null, + tags: ['input'] + } +} + describe('scanNodeMediaCandidates', () => { it('returns candidate for a LoadImage node with missing image', () => { const graph = makeGraph([]) @@ -232,37 +266,43 @@ describe('groupCandidatesByMediaType', () => { }) describe('verifyCloudMediaCandidates', () => { - it('marks candidates missing when not in input assets', async () => { + const existingHash = + 'blake3:1111111111111111111111111111111111111111111111111111111111111111' + const missingHash = + 'blake3:2222222222222222222222222222222222222222222222222222222222222222' + + beforeEach(() => { + vi.clearAllMocks() + mockCheckAssetHash.mockResolvedValue('missing') + mockGetInputAssetsIncludingPublic.mockResolvedValue([]) + }) + + it('marks candidates missing when the asset hash is not found', async () => { const candidates = [ - makeCandidate('1', 'abc123.png', { isMissing: undefined }), - makeCandidate('2', 'def456.png', { isMissing: undefined }) + makeCandidate('1', missingHash, { isMissing: undefined }), + makeCandidate('2', existingHash, { isMissing: undefined }) ] - const mockStore = { - updateInputs: async () => {}, - inputAssets: [{ asset_hash: 'def456.png', name: 'my-photo.png' }] - } + const checkAssetHash = vi.fn(async (assetHash: string) => + assetHash === existingHash ? ('exists' as const) : ('missing' as const) + ) - await verifyCloudMediaCandidates(candidates, undefined, mockStore) + await verifyCloudMediaCandidates(candidates, undefined, checkAssetHash) expect(candidates[0].isMissing).toBe(true) expect(candidates[1].isMissing).toBe(false) }) - it('calls updateInputs before checking assets', async () => { - let updateCalled = false - const candidates = [makeCandidate('1', 'abc.png', { isMissing: undefined })] - - const mockStore = { - updateInputs: async () => { - updateCalled = true - }, - inputAssets: [] - } + it('uses assetService.checkAssetHash by default', async () => { + const candidates = [ + makeCandidate('1', existingHash, { isMissing: undefined }) + ] + mockCheckAssetHash.mockResolvedValue('exists') - await verifyCloudMediaCandidates(candidates, undefined, mockStore) + await verifyCloudMediaCandidates(candidates) - expect(updateCalled).toBe(true) + expect(candidates[0].isMissing).toBe(false) + expect(mockCheckAssetHash).toHaveBeenCalledWith(existingHash, undefined) }) it('respects abort signal before execution', async () => { @@ -270,69 +310,221 @@ describe('verifyCloudMediaCandidates', () => { controller.abort() const candidates = [ - makeCandidate('1', 'abc123.png', { isMissing: undefined }) + makeCandidate('1', missingHash, { isMissing: undefined }) ] await verifyCloudMediaCandidates(candidates, controller.signal) expect(candidates[0].isMissing).toBeUndefined() + expect(mockCheckAssetHash).not.toHaveBeenCalled() }) - it('respects abort signal after updateInputs', async () => { + it('respects abort signal after hash verification', async () => { const controller = new AbortController() - const candidates = [makeCandidate('1', 'abc.png', { isMissing: undefined })] - - const mockStore = { - updateInputs: async () => { - controller.abort() - }, - inputAssets: [{ asset_hash: 'abc.png', name: 'photo.png' }] - } + const candidates = [ + makeCandidate('1', existingHash, { isMissing: undefined }) + ] + const checkAssetHash = vi.fn(async () => { + controller.abort() + return 'exists' as const + }) - await verifyCloudMediaCandidates(candidates, controller.signal, mockStore) + await verifyCloudMediaCandidates( + candidates, + controller.signal, + checkAssetHash + ) expect(candidates[0].isMissing).toBeUndefined() }) it('skips candidates already resolved as true', async () => { - const candidates = [makeCandidate('1', 'abc.png', { isMissing: true })] - - const mockStore = { - updateInputs: async () => {}, - inputAssets: [] - } + const candidates = [makeCandidate('1', missingHash, { isMissing: true })] - await verifyCloudMediaCandidates(candidates, undefined, mockStore) + await verifyCloudMediaCandidates(candidates) expect(candidates[0].isMissing).toBe(true) + expect(mockCheckAssetHash).not.toHaveBeenCalled() }) it('skips candidates already resolved as false', async () => { - const candidates = [makeCandidate('1', 'abc.png', { isMissing: false })] + const candidates = [makeCandidate('1', existingHash, { isMissing: false })] - const mockStore = { - updateInputs: async () => {}, - inputAssets: [] - } - - await verifyCloudMediaCandidates(candidates, undefined, mockStore) + await verifyCloudMediaCandidates(candidates) expect(candidates[0].isMissing).toBe(false) + expect(mockCheckAssetHash).not.toHaveBeenCalled() }) it('skips entirely when no pending candidates', async () => { - let updateCalled = false - const candidates = [makeCandidate('1', 'abc.png', { isMissing: true })] - - const mockStore = { - updateInputs: async () => { - updateCalled = true - }, - inputAssets: [] - } + const candidates = [makeCandidate('1', missingHash, { isMissing: true })] + + await verifyCloudMediaCandidates(candidates) + + expect(mockCheckAssetHash).not.toHaveBeenCalled() + }) + + it('falls back to input assets for non-blake3 candidate names', async () => { + const candidates = [ + makeCandidate('1', 'photo.png', { isMissing: undefined }), + makeCandidate('2', 'missing.png', { isMissing: undefined }) + ] + const fetchInputAssets = vi.fn(async () => [ + makeAsset('stored-photo.png', 'photo.png') + ]) + + await verifyCloudMediaCandidates( + candidates, + undefined, + undefined, + fetchInputAssets + ) + + expect(mockCheckAssetHash).not.toHaveBeenCalled() + expect(fetchInputAssets).toHaveBeenCalledOnce() + expect(candidates[0].isMissing).toBe(false) + expect(candidates[1].isMissing).toBe(true) + }) + + it('uses public input assets for default legacy fallback', async () => { + const candidates = [ + makeCandidate('1', 'public-photo.png', { isMissing: undefined }) + ] + const inputAssets = Array.from({ length: 500 }, (_, index) => + makeAsset(`asset-${index}.png`) + ) + inputAssets[42] = makeAsset('public-asset-record', 'public-photo.png') + mockGetInputAssetsIncludingPublic.mockResolvedValue(inputAssets) + + await verifyCloudMediaCandidates(candidates) - await verifyCloudMediaCandidates(candidates, undefined, mockStore) + expect(mockGetInputAssetsIncludingPublic).toHaveBeenCalledWith(undefined) + expect(candidates[0].isMissing).toBe(false) + }) + + it('silences aborts while loading legacy fallback input assets', async () => { + const abortError = new Error('aborted') + abortError.name = 'AbortError' + const controller = new AbortController() + const candidates = [ + makeCandidate('1', 'photo.png', { isMissing: undefined }) + ] + const fetchInputAssets = vi.fn(async () => { + controller.abort() + throw abortError + }) + + await expect( + verifyCloudMediaCandidates( + candidates, + controller.signal, + undefined, + fetchInputAssets + ) + ).resolves.toBeUndefined() + + expect(candidates[0].isMissing).toBeUndefined() + }) + + it('silences aborts from the default legacy fallback input asset store path', async () => { + const abortError = new Error('aborted') + abortError.name = 'AbortError' + const controller = new AbortController() + const candidates = [ + makeCandidate('1', 'photo.png', { isMissing: undefined }) + ] + mockGetInputAssetsIncludingPublic.mockImplementationOnce(async () => { + controller.abort() + throw abortError + }) + + await expect( + verifyCloudMediaCandidates(candidates, controller.signal) + ).resolves.toBeUndefined() + + expect(mockGetInputAssetsIncludingPublic).toHaveBeenCalledWith( + controller.signal + ) + expect(candidates[0].isMissing).toBeUndefined() + }) + + it('falls back to input assets when the hash endpoint returns 400', async () => { + const candidates = [ + makeCandidate('1', existingHash, { isMissing: undefined }) + ] + mockCheckAssetHash.mockResolvedValue('invalid') + const fetchInputAssets = vi.fn(async () => [ + makeAsset('photo.png', existingHash) + ]) + + await verifyCloudMediaCandidates( + candidates, + undefined, + undefined, + fetchInputAssets + ) + + expect(mockCheckAssetHash).toHaveBeenCalledWith(existingHash, undefined) + expect(fetchInputAssets).toHaveBeenCalledOnce() + expect(candidates[0].isMissing).toBe(false) + }) + + it('falls back to input assets when hash verification fails', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const candidates = [ + makeCandidate('1', existingHash, { isMissing: undefined }) + ] + const checkAssetHash = vi.fn(async () => { + throw new Error('network failed') + }) + const fetchInputAssets = vi.fn(async () => [ + makeAsset('photo.png', existingHash) + ]) + + await verifyCloudMediaCandidates( + candidates, + undefined, + checkAssetHash, + fetchInputAssets + ) + + expect(fetchInputAssets).toHaveBeenCalledOnce() + expect(candidates[0].isMissing).toBe(false) + expect(warn).toHaveBeenCalledOnce() + warn.mockRestore() + }) - expect(updateCalled).toBe(false) + it('does not call the hash endpoint for malformed blake3-looking values', async () => { + const malformedHash = 'blake3:abc' + const candidates = [ + makeCandidate('1', malformedHash, { isMissing: undefined }) + ] + const fetchInputAssets = vi.fn(async () => [ + makeAsset('legacy.png', malformedHash) + ]) + + await verifyCloudMediaCandidates( + candidates, + undefined, + undefined, + fetchInputAssets + ) + + expect(mockCheckAssetHash).not.toHaveBeenCalled() + expect(fetchInputAssets).toHaveBeenCalledOnce() + expect(candidates[0].isMissing).toBe(false) + }) + + it('deduplicates checks for repeated candidate names', async () => { + const candidates = [ + makeCandidate('1', missingHash, { isMissing: undefined }), + makeCandidate('2', missingHash, { isMissing: undefined }) + ] + + await verifyCloudMediaCandidates(candidates) + + expect(mockCheckAssetHash).toHaveBeenCalledOnce() + expect(candidates[0].isMissing).toBe(true) + expect(candidates[1].isMissing).toBe(true) }) }) diff --git a/src/platform/missingMedia/missingMediaScan.ts b/src/platform/missingMedia/missingMediaScan.ts index 7b4592768bb..5050996e06f 100644 --- a/src/platform/missingMedia/missingMediaScan.ts +++ b/src/platform/missingMedia/missingMediaScan.ts @@ -18,6 +18,12 @@ import { } from '@/utils/graphTraversalUtil' import { LGraphEventMode } from '@/lib/litegraph/src/types/globalEnums' import { resolveComboValues } from '@/utils/litegraphUtil' +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import type { AssetHashStatus } from '@/platform/assets/services/assetService' +import { + assetService, + isBlake3AssetHash +} from '@/platform/assets/services/assetService' /** Map of node types to their media widget name and media type. */ const MEDIA_NODE_WIDGETS: Record< @@ -106,41 +112,130 @@ export function scanNodeMediaCandidates( return candidates } -interface InputVerifier { - updateInputs: () => Promise - inputAssets: Array<{ asset_hash?: string | null; name: string }> +type AssetHashVerifier = ( + assetHash: string, + signal?: AbortSignal +) => Promise + +type InputAssetFetcher = (signal?: AbortSignal) => Promise + +function groupCandidatesForHashLookup(candidates: MissingMediaCandidate[]): { + candidatesByHash: Map + legacyCandidates: MissingMediaCandidate[] +} { + const candidatesByHash = new Map() + const legacyCandidates: MissingMediaCandidate[] = [] + + for (const candidate of candidates) { + if (!isBlake3AssetHash(candidate.name)) { + legacyCandidates.push(candidate) + continue + } + + const hashCandidates = candidatesByHash.get(candidate.name) + if (hashCandidates) hashCandidates.push(candidate) + else candidatesByHash.set(candidate.name, [candidate]) + } + + return { candidatesByHash, legacyCandidates } +} + +async function verifyCandidatesByHash( + candidatesByHash: Map, + legacyCandidates: MissingMediaCandidate[], + signal: AbortSignal | undefined, + checkAssetHash: AssetHashVerifier +): Promise { + await Promise.all( + Array.from(candidatesByHash, async ([assetHash, hashCandidates]) => { + if (signal?.aborted) return + + let status: AssetHashStatus + try { + status = await checkAssetHash(assetHash, signal) + if (signal?.aborted) return + } catch (err) { + if (signal?.aborted || isAbortError(err)) return + console.warn( + '[Missing Media Pipeline] Failed to verify asset hash:', + err + ) + legacyCandidates.push(...hashCandidates) + return + } + + if (status === 'invalid') { + legacyCandidates.push(...hashCandidates) + return + } + + for (const candidate of hashCandidates) { + candidate.isMissing = status === 'missing' + } + }) + ) } /** - * Verify cloud media candidates against the input assets fetched from the - * assets store. Mutates candidates' `isMissing` in place. + * Verify cloud media candidates by probing the asset hash endpoint first. + * Invalid hash values fall back to the legacy input asset list check. */ export async function verifyCloudMediaCandidates( candidates: MissingMediaCandidate[], signal?: AbortSignal, - assetsStore?: InputVerifier + checkAssetHash: AssetHashVerifier = assetService.checkAssetHash, + fetchInputAssets: InputAssetFetcher = fetchMissingInputAssets ): Promise { if (signal?.aborted) return const pending = candidates.filter((c) => c.isMissing === undefined) if (pending.length === 0) return - const store = - assetsStore ?? (await import('@/stores/assetsStore')).useAssetsStore() + const { candidatesByHash, legacyCandidates } = + groupCandidatesForHashLookup(pending) + await verifyCandidatesByHash( + candidatesByHash, + legacyCandidates, + signal, + checkAssetHash + ) + + if (signal?.aborted || legacyCandidates.length === 0) return - await store.updateInputs() + let inputAssets: AssetItem[] + try { + inputAssets = await fetchInputAssets(signal) + } catch (err) { + if (signal?.aborted || isAbortError(err)) return + throw err + } if (signal?.aborted) return const assetHashes = new Set( - store.inputAssets.map((a) => a.asset_hash).filter((h): h is string => !!h) + inputAssets.map((a) => a.asset_hash).filter((h): h is string => !!h) ) - for (const c of pending) { - c.isMissing = !assetHashes.has(c.name) + for (const candidate of legacyCandidates) { + candidate.isMissing = !assetHashes.has(candidate.name) } } +async function fetchMissingInputAssets( + signal?: AbortSignal +): Promise { + return await assetService.getInputAssetsIncludingPublic(signal) +} + +function isAbortError(err: unknown): boolean { + return ( + typeof err === 'object' && + err !== null && + 'name' in err && + err.name === 'AbortError' + ) +} + /** Group confirmed-missing candidates by file name into view models. */ export function groupCandidatesByName( candidates: MissingMediaCandidate[] diff --git a/src/platform/missingModel/missingModelScan.test.ts b/src/platform/missingModel/missingModelScan.test.ts index cc26dbc609a..05326f8bb01 100644 --- a/src/platform/missingModel/missingModelScan.test.ts +++ b/src/platform/missingModel/missingModelScan.test.ts @@ -19,6 +19,11 @@ import activeSubgraphUnmatchedModel from '@/platform/missingModel/__fixtures__/a import bypassedSubgraphUnmatchedModel from '@/platform/missingModel/__fixtures__/bypassedSubgraphUnmatchedModel.json' with { type: 'json' } import type { MissingModelCandidate } from '@/platform/missingModel/types' import type { ComfyWorkflowJSON } from '@/platform/workflow/validation/schemas/workflowSchema' +import type * as AssetServiceModule from '@/platform/assets/services/assetService' + +const { mockCheckAssetHash } = vi.hoisted(() => ({ + mockCheckAssetHash: vi.fn() +})) vi.mock('@/utils/graphTraversalUtil', () => ({ collectAllNodes: (graph: { _testNodes: LGraphNode[] }) => graph._testNodes, @@ -28,6 +33,20 @@ vi.mock('@/utils/graphTraversalUtil', () => ({ ) => node._testExecutionId ?? String(node.id) })) +vi.mock('@/platform/assets/services/assetService', async () => { + const actual = await vi.importActual( + '@/platform/assets/services/assetService' + ) + + return { + ...actual, + assetService: { + ...actual.assetService, + checkAssetHash: mockCheckAssetHash + } + } +}) + /** Helper: create a combo widget mock */ function makeComboWidget( name: string, @@ -43,7 +62,7 @@ function makeComboWidget( } /** Helper: create an asset widget mock (Cloud combo replacement) */ -function makeAssetWidget(name: string, value: string): IBaseWidget { +function makeAssetWidget(name: string, value: unknown): IBaseWidget { return fromAny({ type: 'asset', name, @@ -551,6 +570,16 @@ describe('scanAllModelCandidates', () => { expect(result).toEqual([]) }) + it('should skip asset widgets with non-string values', () => { + const graph = makeGraph([ + makeNode(1, 'SomeNode', [makeAssetWidget('ckpt_name', 123)]) + ]) + + const result = scanAllModelCandidates(graph, noAssetSupport) + + expect(result).toEqual([]) + }) + it('should scan both combo and asset widgets on the same node', () => { const graph = makeGraph([ makeNode(1, 'DualLoaderNode', [ @@ -1411,6 +1440,7 @@ function makeAssetCandidate( describe('verifyAssetSupportedCandidates', () => { beforeEach(() => { vi.clearAllMocks() + mockCheckAssetHash.mockResolvedValue('missing') mockIsModelLoading.mockReturnValue(false) mockHasMore.mockReturnValue(false) mockGetAssets.mockReturnValue([]) @@ -1428,6 +1458,125 @@ describe('verifyAssetSupportedCandidates', () => { ) }) + it('should resolve isMissing=false when the blake3 hash endpoint finds the asset', async () => { + const hash = + '1111111111111111111111111111111111111111111111111111111111111111' + const candidates = [ + makeAssetCandidate('model.safetensors', { + hash, + hashType: 'blake3' + }) + ] + mockCheckAssetHash.mockResolvedValue('exists') + + await verifyAssetSupportedCandidates(candidates) + + expect(candidates[0].isMissing).toBe(false) + expect(mockCheckAssetHash).toHaveBeenCalledWith(`blake3:${hash}`, undefined) + expect(mockUpdateModelsForNodeType).not.toHaveBeenCalled() + }) + + it('should fall back to asset store matching when the blake3 hash is not found', async () => { + const hash = + '2222222222222222222222222222222222222222222222222222222222222222' + const candidates = [ + makeAssetCandidate('my_model.safetensors', { + hash, + hashType: 'blake3' + }) + ] + mockCheckAssetHash.mockResolvedValue('missing') + mockGetAssets.mockReturnValue([ + { + id: '1', + name: 'my_model.safetensors', + asset_hash: null, + metadata: { filename: 'my_model.safetensors' } + } + ]) + + await verifyAssetSupportedCandidates(candidates) + + expect(candidates[0].isMissing).toBe(false) + expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith( + 'CheckpointLoaderSimple' + ) + }) + + it('should fall back to asset store matching when hash verification fails', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const hash = + '3333333333333333333333333333333333333333333333333333333333333333' + const candidates = [ + makeAssetCandidate('my_model.safetensors', { + hash, + hashType: 'blake3' + }) + ] + mockCheckAssetHash.mockRejectedValue(new Error('network failed')) + mockGetAssets.mockReturnValue([ + { + id: '1', + name: 'my_model.safetensors', + asset_hash: null, + metadata: { filename: 'my_model.safetensors' } + } + ]) + + await verifyAssetSupportedCandidates(candidates) + + expect(candidates[0].isMissing).toBe(false) + expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith( + 'CheckpointLoaderSimple' + ) + expect(warn).toHaveBeenCalledOnce() + warn.mockRestore() + }) + + it('should skip malformed blake3 hashes and use asset store matching', async () => { + const candidates = [ + makeAssetCandidate('my_model.safetensors', { + hash: 'abc123', + hashType: 'blake3' + }) + ] + mockGetAssets.mockReturnValue([ + { + id: '1', + name: 'my_model.safetensors', + asset_hash: null, + metadata: { filename: 'my_model.safetensors' } + } + ]) + + await verifyAssetSupportedCandidates(candidates) + + expect(mockCheckAssetHash).not.toHaveBeenCalled() + expect(candidates[0].isMissing).toBe(false) + }) + + it('should not warn or fall back when hash verification is aborted', async () => { + const warn = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const abortError = new Error('aborted') + abortError.name = 'AbortError' + const hash = + '4444444444444444444444444444444444444444444444444444444444444444' + const candidates = [ + makeAssetCandidate('my_model.safetensors', { + hash, + hashType: 'blake3' + }) + ] + mockCheckAssetHash.mockRejectedValue(abortError) + + await verifyAssetSupportedCandidates(candidates) + + expect(candidates[0].isMissing).toBeUndefined() + expect(mockUpdateModelsForNodeType).not.toHaveBeenCalled() + expect(warn).not.toHaveBeenCalled() + warn.mockRestore() + }) + it('should resolve isMissing=false when asset with matching hash exists', async () => { const candidates = [ makeAssetCandidate('model.safetensors', { @@ -1442,6 +1591,7 @@ describe('verifyAssetSupportedCandidates', () => { await verifyAssetSupportedCandidates(candidates) expect(candidates[0].isMissing).toBe(false) + expect(mockCheckAssetHash).not.toHaveBeenCalled() }) it('should resolve isMissing=false when asset with matching filename exists', async () => { diff --git a/src/platform/missingModel/missingModelScan.ts b/src/platform/missingModel/missingModelScan.ts index 11302154bc9..bef803112a5 100644 --- a/src/platform/missingModel/missingModelScan.ts +++ b/src/platform/missingModel/missingModelScan.ts @@ -24,6 +24,11 @@ import { } from '@/utils/graphTraversalUtil' import { LGraphEventMode } from '@/lib/litegraph/src/types/globalEnums' import { resolveComboValues } from '@/utils/litegraphUtil' +import type { AssetHashStatus } from '@/platform/assets/services/assetService' +import { + assetService, + toBlake3AssetHash +} from '@/platform/assets/services/assetService' export type MissingModelWorkflowData = FlattenableWorkflowGraph & { models?: ModelFile[] @@ -177,7 +182,7 @@ function scanAssetWidget( getDirectory: ((nodeType: string) => string | undefined) | undefined ): MissingModelCandidate | null { const value = widget.value - if (!value.trim()) return null + if (typeof value !== 'string' || !value.trim()) return null if (!isModelFileName(value)) return null return { @@ -445,20 +450,68 @@ interface AssetVerifier { getAssets: (nodeType: string) => AssetItem[] | undefined } +type AssetHashVerifier = ( + assetHash: string, + signal?: AbortSignal +) => Promise + export async function verifyAssetSupportedCandidates( candidates: MissingModelCandidate[], signal?: AbortSignal, - assetsStore?: AssetVerifier + assetsStore?: AssetVerifier, + checkAssetHash: AssetHashVerifier = assetService.checkAssetHash ): Promise { if (signal?.aborted) return + const pendingCandidates = candidates.filter( + (c) => c.isAssetSupported && c.isMissing === undefined + ) + if (pendingCandidates.length === 0) return + const pendingNodeTypes = new Set() - for (const c of candidates) { - if (c.isAssetSupported && c.isMissing === undefined) { - pendingNodeTypes.add(c.nodeType) + const candidatesByHash = new Map() + + for (const candidate of pendingCandidates) { + const assetHash = getBlake3AssetHash(candidate) + if (!assetHash) { + pendingNodeTypes.add(candidate.nodeType) + continue } + + const hashCandidates = candidatesByHash.get(assetHash) + if (hashCandidates) hashCandidates.push(candidate) + else candidatesByHash.set(assetHash, [candidate]) } + await Promise.all( + Array.from(candidatesByHash, async ([assetHash, hashCandidates]) => { + if (signal?.aborted) return + + try { + const status = await checkAssetHash(assetHash, signal) + if (signal?.aborted) return + + if (status === 'exists') { + for (const candidate of hashCandidates) { + candidate.isMissing = false + } + return + } + } catch (err) { + if (signal?.aborted || isAbortError(err)) return + console.warn( + '[Missing Model Pipeline] Failed to verify asset hash:', + err + ) + } + + for (const candidate of hashCandidates) { + pendingNodeTypes.add(candidate.nodeType) + } + }) + ) + + if (signal?.aborted) return if (pendingNodeTypes.size === 0) return const store = @@ -491,6 +544,20 @@ export async function verifyAssetSupportedCandidates( } } +function getBlake3AssetHash(candidate: MissingModelCandidate): string | null { + if (candidate.hashType?.toLowerCase() !== 'blake3') return null + return toBlake3AssetHash(candidate.hash) +} + +function isAbortError(err: unknown): boolean { + return ( + typeof err === 'object' && + err !== null && + 'name' in err && + err.name === 'AbortError' + ) +} + function normalizePath(path: string): string { return path.replace(/\\/g, '/') } diff --git a/src/stores/assetsStore.test.ts b/src/stores/assetsStore.test.ts index 0a8cb0a75ff..6f606061084 100644 --- a/src/stores/assetsStore.test.ts +++ b/src/stores/assetsStore.test.ts @@ -24,7 +24,12 @@ vi.mock('@/scripts/api', () => ({ vi.mock('@/platform/assets/services/assetService', () => ({ assetService: { getAssetsByTag: vi.fn(), - getAssetsForNodeType: vi.fn() + getAllAssetsByTag: vi.fn(), + getAssetsForNodeType: vi.fn(), + invalidateInputAssetsIncludingPublic: vi.fn(), + updateAsset: vi.fn(), + addAssetTags: vi.fn(), + removeAssetTags: vi.fn() } })) @@ -1034,4 +1039,234 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => { ).not.toThrow() }) }) + + describe('updateAssetMetadata optimistic cache', () => { + it('reflects the server response in the cache after a successful update', async () => { + const store = useAssetsStore() + const original = { + ...createMockAsset('opt-1'), + user_metadata: { note: 'before' } as Record + } + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + original + ]) + await store.updateModelsForNodeType('CheckpointLoaderSimple') + + const serverResponse = { + ...original, + user_metadata: { note: 'server-confirmed' } + } + vi.mocked(assetService.updateAsset).mockResolvedValueOnce(serverResponse) + + await store.updateAssetMetadata( + original, + { note: 'optimistic' }, + 'CheckpointLoaderSimple' + ) + + const cached = store.getAssets('CheckpointLoaderSimple')[0] + expect(cached.user_metadata).toEqual({ note: 'server-confirmed' }) + }) + + it('rolls back to the original metadata when the server rejects', async () => { + const store = useAssetsStore() + const original = { + ...createMockAsset('opt-2'), + user_metadata: { note: 'before' } as Record + } + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + original + ]) + await store.updateModelsForNodeType('CheckpointLoaderSimple') + + vi.mocked(assetService.updateAsset).mockRejectedValueOnce( + new Error('500 Internal Error') + ) + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + await store.updateAssetMetadata( + original, + { note: 'will be reverted' }, + 'CheckpointLoaderSimple' + ) + + const cached = store.getAssets('CheckpointLoaderSimple')[0] + expect(cached.user_metadata).toEqual({ note: 'before' }) + consoleSpy.mockRestore() + }) + }) + + describe('updateAssetTags diff-based dispatch', () => { + it('skips both endpoints and does not mutate the cache when tags are unchanged', async () => { + const store = useAssetsStore() + const asset = createMockAsset('tags-noop', ['models', 'checkpoints']) + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + asset + ]) + await store.updateModelsForNodeType('CheckpointLoaderSimple') + + await store.updateAssetTags( + asset, + ['checkpoints', 'models'], + 'CheckpointLoaderSimple' + ) + + expect(vi.mocked(assetService.addAssetTags)).not.toHaveBeenCalled() + expect(vi.mocked(assetService.removeAssetTags)).not.toHaveBeenCalled() + }) + + it('calls only the add endpoint when there are no tags to remove', async () => { + const store = useAssetsStore() + const asset = createMockAsset('tags-add-only', ['models']) + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + asset + ]) + await store.updateModelsForNodeType('CheckpointLoaderSimple') + + vi.mocked(assetService.addAssetTags).mockResolvedValueOnce({ + added: ['featured'], + total_tags: ['models', 'featured'] + }) + + await store.updateAssetTags( + asset, + ['models', 'featured'], + 'CheckpointLoaderSimple' + ) + + expect(vi.mocked(assetService.addAssetTags)).toHaveBeenCalledWith( + 'tags-add-only', + ['featured'] + ) + expect(vi.mocked(assetService.removeAssetTags)).not.toHaveBeenCalled() + expect(store.getAssets('CheckpointLoaderSimple')[0].tags).toEqual([ + 'models', + 'featured' + ]) + }) + + it('rolls back the cache when removeAssetTags succeeds but addAssetTags rejects', async () => { + // Documents the known recovery gap on partial-failure during a + // "change category" mutation: remove succeeds server-side, add fails, + // and the cache is restored to the original tags. The server now has + // the old category tag removed, so the cache and backend diverge until + // the next refetch — surface that gap here rather than papering over it. + const store = useAssetsStore() + const asset = createMockAsset('tags-partial-fail', ['models', 'loras']) + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + asset + ]) + await store.updateModelsForNodeType('LoraLoader') + + vi.mocked(assetService.removeAssetTags).mockResolvedValueOnce({ + removed: ['loras'], + total_tags: ['models'] + }) + vi.mocked(assetService.addAssetTags).mockRejectedValueOnce( + new Error('500 add failed') + ) + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + await store.updateAssetTags( + asset, + ['models', 'checkpoints'], + 'LoraLoader' + ) + + expect(vi.mocked(assetService.removeAssetTags)).toHaveBeenCalledWith( + 'tags-partial-fail', + ['loras'] + ) + expect(vi.mocked(assetService.addAssetTags)).toHaveBeenCalledWith( + 'tags-partial-fail', + ['checkpoints'] + ) + // Cache restored to original tags even though the server has already + // removed 'loras'. This codifies a known divergence — fix the recovery + // semantics in updateAssetTags to address it (e.g. invalidate the + // category cache, or reconcile against the last confirmed total_tags). + expect(store.getAssets('LoraLoader')[0].tags).toEqual(['models', 'loras']) + consoleSpy.mockRestore() + }) + }) +}) + +describe('assetsStore - Deletion State and Input Mapping', () => { + beforeEach(() => { + setActivePinia(createTestingPinia({ stubActions: false })) + vi.clearAllMocks() + }) + + describe('setAssetDeleting / isAssetDeleting', () => { + it('tracks per-asset deletion state and clears it on flip', () => { + const store = useAssetsStore() + + expect(store.isAssetDeleting('asset-A')).toBe(false) + + store.setAssetDeleting('asset-A', true) + expect(store.isAssetDeleting('asset-A')).toBe(true) + expect(store.isAssetDeleting('asset-B')).toBe(false) + + store.setAssetDeleting('asset-A', false) + expect(store.isAssetDeleting('asset-A')).toBe(false) + }) + }) + + describe('getInputName', () => { + it('resolves a hashed filename to the human-readable name when the input asset is in the cache', async () => { + mockIsCloud.value = true + try { + setActivePinia(createTestingPinia({ stubActions: false })) + const store = useAssetsStore() + + vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([ + { + id: 'input-1', + name: 'cute-puppy.png', + asset_hash: 'abc123def.png', + tags: ['input'] + } + ]) + await store.updateInputs() + + expect(store.getInputName('abc123def.png')).toBe('cute-puppy.png') + } finally { + mockIsCloud.value = false + } + }) + + it('falls back to the original filename when the input asset is not cached', () => { + const store = useAssetsStore() + expect(store.getInputName('unknown.png')).toBe('unknown.png') + }) + }) + + describe('updateInputs cloud routing', () => { + it('reads from assetService.getAssetsByTag with limit 100 when isCloud is true', async () => { + mockIsCloud.value = true + try { + setActivePinia(createTestingPinia({ stubActions: false })) + const store = useAssetsStore() + + vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([]) + await store.updateInputs() + + expect(vi.mocked(assetService.getAssetsByTag)).toHaveBeenCalledWith( + 'input', + false, + { limit: 100 } + ) + expect( + assetService.invalidateInputAssetsIncludingPublic + ).toHaveBeenCalledOnce() + } finally { + mockIsCloud.value = false + } + }) + }) }) diff --git a/src/stores/assetsStore.ts b/src/stores/assetsStore.ts index df655a3b3bd..34f9910e94a 100644 --- a/src/stores/assetsStore.ts +++ b/src/stores/assetsStore.ts @@ -123,7 +123,7 @@ export const useAssetsStore = defineStore('assets', () => { state: inputAssets, isLoading: inputLoading, error: inputError, - execute: updateInputs + execute: executeUpdateInputs } = useAsyncState(fetchInputFiles, [], { immediate: false, resetOnExecute: false, @@ -132,6 +132,12 @@ export const useAssetsStore = defineStore('assets', () => { } }) + const updateInputs = async () => { + const result = await executeUpdateInputs() + assetService.invalidateInputAssetsIncludingPublic() + return result + } + /** * Fetch history assets with pagination support * @param loadMore - true for pagination (append), false for initial load (replace)