Compare commits

...

2 Commits

Author SHA1 Message Date
Mary Hipp
b5d7471326 add tests and readme 2025-09-17 13:53:08 -04:00
Mary Hipp
ae8bb9a9a7 enqueue refactor 2025-09-17 11:40:11 -04:00
12 changed files with 673 additions and 424 deletions

View File

@@ -27,6 +27,7 @@ export const zLogNamespace = z.enum([
'queue',
'workflows',
'video',
'enqueue',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;

View File

@@ -253,11 +253,11 @@ const PublishWorkflowButton = memo(() => {
),
duration: null,
});
assert(result.value.enqueueResult.batch.batch_id);
assert(result.value.batchConfig.validation_run_data);
assert(result.value?.enqueueResult.batch.batch_id);
assert(result.value?.batchConfig.validation_run_data);
$validationRunData.set({
batchId: result.value.enqueueResult.batch.batch_id,
workflowId: result.value.batchConfig.validation_run_data.workflow_id,
batchId: result.value?.enqueueResult.batch.batch_id,
workflowId: result.value?.batchConfig.validation_run_data.workflow_id,
});
log.debug(parseify(result.value), 'Enqueued batch');
}

View File

@@ -0,0 +1,51 @@
# Queue Enqueue Patterns
This directory contains the hooks and utilities that translate UI actions into queue batches. The flow is intentionally
modular so adding a new enqueue type (e.g. a new generation mode) follows a predictable recipe.
## Key building blocks
- `hooks/useEnqueue*.ts` Feature-specific hooks (generate, canvas, upscaling, video, workflows). Each hook wires local
state to the shared enqueue utilities.
- `hooks/utils/graphBuilders.ts` Maps base models (sdxl, flux, etc.) to their graph builder functions and normalizes
synchronous vs. asynchronous builders.
- `hooks/utils/executeEnqueue.ts` Orchestrates the enqueue lifecycle:
1. dispatch the `enqueueRequested*` action
2. build the graph/batch data
3. call `queueApi.endpoints.enqueueBatch`
4. run success/error callbacks
## Adding a new enqueue type
1. **Implement the graph builder (if needed).**
- Create the graph construction logic in `features/nodes/util/graph/generation/...` so it returns a
`GraphBuilderReturn`.
- If the builder reuses existing primitives, consider wiring it into `graphBuilders.ts` by extending the `graphBuilderMap`.
2. **Create the enqueue hook.**
- Add `useEnqueue<Feature>.ts` mirroring the existing hooks. Import `executeEnqueue` and supply feature-specific
`build`, `prepareBatch`, and `onSuccess` callbacks.
- If the feature depends on a new base model, add it to `graphBuilders.ts`.
3. **Register the tab in `useInvoke`.**
- `useInvoke.ts` looks up handlers based on the active tab. Import your new hook and call it inside the `switch`
(or future registry) so the UI can enqueue from the feature.
4. **Add Redux action (optional).**
- Most enqueue hooks dispatch a `enqueueRequested*` action for devtools visibility. Create one with `createAction` if
you want similar tracing.
5. **Cover with tests.**
- Unit-test feature-specific behavior (graph selection, batch tweaks). The shared helpers already have coverage in
`hooks/utils/`.
## Tips
- Keep `build` lean: fetch state, compose graph/batch data, and return `null` when prerequisites are missing. The shared
helper will skip enqueueing and your `onError` will handle logging.
- Use the shared `prepareLinearUIBatch` for single-graph UI workflows. For advanced cases (multi-run batches, workflow
validation runs), supply a custom `prepareBatch` function.
- Prefer updating `graphBuilders.ts` when adding a new base model so every image-based enqueue automatically benefits.
With this structure, the main task when introducing a new enqueue type is describing how to build its graph and how to
massage the batch payload—everything else (dispatching, API calls, history updates) is handled by the utilities.

View File

@@ -1,154 +1,114 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
import { buildGraphForBase } from './utils/graphBuilders';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedCanvas());
const state = getState();
const destination = selectCanvasDestination(state);
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return buildFluxKontextGraph(graphBuilderArg);
case 'gemini-2.5':
return buildGemini2_5Graph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
type CanvasBuildResult = {
batchConfig: EnqueueBatchArg;
};
export const useEnqueueCanvas = () => {
const store = useAppStore();
const canvasManager = useCanvasManagerSafe();
const enqueue = useCallback(
(prepend: boolean) => {
if (!canvasManager) {
log.error('Canvas manager is not available');
return;
return null;
}
return enqueueCanvas(store, canvasManager, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedCanvas,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const destination = selectCanvasDestination(state);
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };
const buildGraphResult = await withResultAsync(
async () => await buildGraphForBase(model.base, graphBuilderArg)
);
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}
return {
batchConfig: prepareBatchResult.value,
} satisfies CanvasBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[canvasManager, store]
);
return enqueue;
};

View File

@@ -1,143 +1,103 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
import { buildGraphForBase } from './utils/graphBuilders';
const log = logger('generation');
export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate');
const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedGenerate());
const state = getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {
case 'sdxl':
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return buildFluxKontextGraph(graphBuilderArg);
case 'gemini-2.5':
return buildGemini2_5Graph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'generate',
destination: 'generate',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
type GenerateBuildResult = {
batchConfig: EnqueueBatchArg;
};
export const useEnqueueGenerate = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueGenerate(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedGenerate,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
const buildGraphResult = await withResultAsync(
async () => await buildGraphForBase(model.base, graphBuilderArg)
);
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'generate',
destination: 'generate',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}
return {
batchConfig: prepareBatchResult.value,
} satisfies GenerateBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);
return enqueue;
};

View File

@@ -1,62 +1,64 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
export const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');
const log = logger('generation');
const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedUpscaling());
const state = getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
type UpscaleBuildResult = {
batchConfig: EnqueueBatchArg;
};
export const useEnqueueUpscaling = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueUpscaling(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedUpscaling,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}
const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'upscaling',
destination: 'gallery',
});
return { batchConfig } satisfies UpscaleBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);
return enqueue;
};

View File

@@ -1,7 +1,6 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
@@ -14,114 +13,107 @@ import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';
import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
const log = logger('generation');
export const enqueueRequestedVideos = createAction('app/enqueueRequestedVideos');
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
type VideoBuildResult = {
batchConfig: EnqueueBatchArg;
};
dispatch(enqueueRequestedVideos());
const state = getState();
const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return;
const getVideoGraphBuilder = (base: string) => {
switch (base) {
case 'veo3':
return buildVeo3VideoGraph;
case 'runway':
return buildRunwayVideoGraph;
default:
return null;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {
case 'veo3':
return await buildVeo3VideoGraph(graphBuilderArg);
case 'runway':
return await buildRunwayVideoGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, positivePrompt, seed } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
// const batchConfig = {
// prepend,
// batch: {
// graph: g.getGraph(),
// runs: 1,
// origin,
// destination,
// },
// };
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
return { batchConfig, enqueueResult };
};
export const useEnqueueVideo = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueVideo(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedVideos,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return null;
}
const builder = getVideoGraphBuilder(model.base);
if (!builder) {
log.error({ base: model.base }, 'No graph builders for base');
return null;
}
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
const buildGraphResult = await withResultAsync(async () => await builder(graphBuilderArg));
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}
const { g, positivePrompt, seed } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}
return {
batchConfig: prepareBatchResult.value,
} satisfies VideoBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);
return enqueue;
};

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppDispatch, AppStore, RootState } from 'app/store/store';
import type { AppDispatch, RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { groupBy } from 'es-toolkit/compat';
import {
@@ -15,10 +15,11 @@ import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, EnqueueBatchArg, S } from 'services/api/types';
import { assert } from 'tsafe';
import { executeEnqueue } from './utils/executeEnqueue';
export const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise<Batch['data']> => {
@@ -119,60 +120,50 @@ const getValidationRunData = (state: RootState, templates: Templates): S['Valida
};
};
const enqueueWorkflows = async (
store: AppStore,
templates: Templates,
prepend: boolean,
isApiValidationRun: boolean
) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);
if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}
const runs = state.params.iterations;
const data = await getBatchDataForWorkflowGeneration(state, dispatch);
const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend,
};
if (isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);
// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};
export const useEnqueueWorkflows = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean, isApiValidationRun: boolean) => {
return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun);
return executeEnqueue({
store,
options: { prepend, isApiValidationRun },
requestedAction: enqueueRequestedWorkflows,
build: async ({ store: innerStore, options }) => {
const { dispatch, getState } = innerStore;
const state = getState();
const nodesState = selectNodesSlice(state);
const templates = $templates.get();
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);
if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}
const data = await getBatchDataForWorkflowGeneration(state, dispatch);
const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: options.prepend,
};
if (options.isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);
batchConfig.batch.runs = 1;
}
return { batchConfig } satisfies { batchConfig: EnqueueBatchArg };
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
});
},
[store]
);

View File

@@ -0,0 +1,107 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStore, RootState } from 'app/store/store';
import type { EnqueueBatchArg, EnqueueBatchResponse } from './executeEnqueue';
import { executeEnqueue } from './executeEnqueue';
import { describe, expect, it, vi } from 'vitest';
const createTestStore = () => {
const state = {} as RootState;
const dispatch = vi.fn<(action: unknown) => unknown>((action) => {
if (typeof action === 'object' && action !== null && 'type' in action) {
return undefined;
}
const unwrap = vi.fn<() => Promise<EnqueueBatchResponse>>().mockResolvedValue({
batch_id: 'batch-1',
item_ids: ['item-1'],
} as EnqueueBatchResponse);
return { unwrap };
});
const getState = vi.fn(() => state);
return { dispatch, getState } as unknown as AppStore;
};
const createBatchArg = (prepend: boolean): EnqueueBatchArg => ({
prepend,
batch: {
graph: {} as EnqueueBatchArg['batch']['graph'],
runs: 1,
data: [],
origin: 'test',
destination: 'test',
},
});
describe('executeEnqueue', () => {
it('dispatches enqueue flow and invokes success callback', async () => {
const store = createTestStore();
const requestedAction = createAction('test/enqueue');
const options = { prepend: false } as const;
const batchConfig = createBatchArg(options.prepend);
const onSuccess = vi.fn();
const build = vi.fn(async () => ({ batchConfig }));
const prepareBatch = vi.fn(() => batchConfig);
const result = await executeEnqueue({
store,
options,
requestedAction,
build,
prepareBatch,
onSuccess,
log: { error: vi.fn() },
});
expect(store.dispatch).toHaveBeenCalledWith(requestedAction());
expect(build).toHaveBeenCalledWith({ store, options });
expect(prepareBatch).toHaveBeenCalledWith({ store, options, buildResult: { batchConfig } });
expect(onSuccess).toHaveBeenCalled();
expect(result?.batchConfig).toBe(batchConfig);
});
it('stops when build returns null', async () => {
const store = createTestStore();
const requestedAction = createAction('test/enqueue');
const options = { prepend: true } as const;
const build = vi.fn(async () => null);
const prepareBatch = vi.fn();
const result = await executeEnqueue({
store,
options,
requestedAction,
build,
prepareBatch,
log: { error: vi.fn() },
});
expect(result).toBeNull();
expect(build).toHaveBeenCalled();
expect(prepareBatch).not.toHaveBeenCalled();
});
it('invokes onError when build throws', async () => {
const store = createTestStore();
const requestedAction = createAction('test/enqueue');
const options = { prepend: false } as const;
const error = new Error('boom');
const build = vi.fn(async () => {
throw error;
});
const onError = vi.fn();
const logError = vi.fn();
const result = await executeEnqueue({
store,
options,
requestedAction,
build,
prepareBatch: vi.fn(),
onError,
log: { error: logError },
});
expect(result).toBeNull();
expect(onError).toHaveBeenCalledWith({ store, options, error });
expect(logError).toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,70 @@
import type { ActionCreatorWithoutPayload } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { paths } from 'services/api/schema';
export type EnqueueBatchArg =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'];
export type EnqueueBatchResponse =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['responses']['201']['content']['application/json'];
export type EnqueueOptionsBase = { prepend: boolean };
interface ExecuteEnqueueConfig<TOptions extends EnqueueOptionsBase, TBuildResult> {
store: AppStore;
options: TOptions;
requestedAction: ActionCreatorWithoutPayload<string>;
build: (context: { store: AppStore; options: TOptions }) => Promise<TBuildResult | null>;
prepareBatch: (context: { store: AppStore; options: TOptions; buildResult: TBuildResult }) => EnqueueBatchArg;
onSuccess?: (context: {
store: AppStore;
options: TOptions;
buildResult: TBuildResult;
batch: EnqueueBatchArg;
response: EnqueueBatchResponse;
}) => void;
onError?: (context: { store: AppStore; options: TOptions; error: unknown }) => void;
log?: ReturnType<typeof logger>;
}
export const executeEnqueue = async <TOptions extends EnqueueOptionsBase, TBuildResult>({
store,
options,
requestedAction,
build,
prepareBatch,
onSuccess,
onError,
log = logger('enqueue'),
}: ExecuteEnqueueConfig<TOptions, TBuildResult>) => {
const { dispatch } = store;
dispatch(requestedAction());
try {
const buildResult = await build({ store, options });
if (!buildResult) {
return null;
}
const batchConfig = prepareBatch({ store, options, buildResult });
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
onSuccess?.({ store, options, buildResult, batch: batchConfig, response: enqueueResult });
return { batchConfig, enqueueResult };
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
onError?.({ store, options, error });
return null;
}
};

View File

@@ -0,0 +1,81 @@
import { describe, expect, it, vi } from 'vitest';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import type { RootState } from 'app/store/store';
const mocks = vi.hoisted(() => {
const mockGraph: Graph = {} as Graph;
const mockPrompt = { id: 'prompt-node' } as Invocation<'string'>;
const asyncReturnValue = { g: mockGraph, positivePrompt: mockPrompt };
const syncReturnValue = { g: mockGraph, positivePrompt: mockPrompt };
return {
asyncReturnValue,
syncReturnValue,
buildSDXLGraphMock: vi.fn().mockResolvedValue(asyncReturnValue),
buildImagen3GraphMock: vi.fn().mockReturnValue(syncReturnValue),
createDefaultBuilder: () => vi.fn().mockResolvedValue(asyncReturnValue),
};
});
vi.mock('features/nodes/util/graph/generation/buildSDXLGraph', () => ({
buildSDXLGraph: mocks.buildSDXLGraphMock,
}));
vi.mock('features/nodes/util/graph/generation/buildSD1Graph', () => ({
buildSD1Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildSD3Graph', () => ({
buildSD3Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildFLUXGraph', () => ({
buildFLUXGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildFluxKontextGraph', () => ({
buildFluxKontextGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildCogView4Graph', () => ({
buildCogView4Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildImagen3Graph', () => ({
buildImagen3Graph: mocks.buildImagen3GraphMock,
}));
vi.mock('features/nodes/util/graph/generation/buildImagen4Graph', () => ({
buildImagen4Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildChatGPT4oGraph', () => ({
buildChatGPT4oGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildGemini2_5Graph', () => ({
buildGemini2_5Graph: mocks.createDefaultBuilder(),
}));
import { buildGraphForBase } from './graphBuilders';
describe('buildGraphForBase', () => {
const baseArg: GraphBuilderArg = {
generationMode: 'txt2img',
state: {} as RootState,
manager: null,
};
it('awaits asynchronous graph builders', async () => {
const result = await buildGraphForBase('sdxl', baseArg);
expect(result).toBe(mocks.asyncReturnValue);
expect(mocks.buildSDXLGraphMock).toHaveBeenCalledWith(baseArg);
});
it('supports synchronous graph builders', async () => {
const result = await buildGraphForBase('imagen3', baseArg);
expect(result).toBe(mocks.syncReturnValue);
expect(mocks.buildImagen3GraphMock).toHaveBeenCalledWith(baseArg);
});
it('throws for unknown base models', async () => {
await expect(buildGraphForBase('unknown-model', baseArg)).rejects.toThrow(
'No graph builders for base unknown-model'
);
});
});

View File

@@ -0,0 +1,34 @@
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { assert } from 'tsafe';
type GraphBuilderFn = (arg: GraphBuilderArg) => GraphBuilderReturn | Promise<GraphBuilderReturn>;
const graphBuilderMap: Record<string, GraphBuilderFn> = {
sdxl: buildSDXLGraph,
'sd-1': buildSD1Graph,
'sd-2': buildSD1Graph,
'sd-3': buildSD3Graph,
flux: buildFLUXGraph,
'flux-kontext': buildFluxKontextGraph,
cogview4: buildCogView4Graph,
imagen3: buildImagen3Graph,
imagen4: buildImagen4Graph,
'chatgpt-4o': buildChatGPT4oGraph,
'gemini-2.5': buildGemini2_5Graph,
};
export const buildGraphForBase = async (base: string, arg: GraphBuilderArg) => {
const builder = graphBuilderMap[base];
assert(builder, `No graph builders for base ${base}`);
return await builder(arg);
};