mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 16:18:06 -05:00

Compare commits: main...maryhipp/e (2 commits: b5d7471326, ae8bb9a9a7)
@@ -27,6 +27,7 @@ export const zLogNamespace = z.enum([
  'queue',
  'workflows',
  'video',
  'enqueue',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;
@@ -253,11 +253,11 @@ const PublishWorkflowButton = memo(() => {
),
duration: null,
});
assert(result.value.enqueueResult.batch.batch_id);
assert(result.value.batchConfig.validation_run_data);
assert(result.value?.enqueueResult.batch.batch_id);
assert(result.value?.batchConfig.validation_run_data);
$validationRunData.set({
batchId: result.value.enqueueResult.batch.batch_id,
workflowId: result.value.batchConfig.validation_run_data.workflow_id,
batchId: result.value?.enqueueResult.batch.batch_id,
workflowId: result.value?.batchConfig.validation_run_data.workflow_id,
});
log.debug(parseify(result.value), 'Enqueued batch');
}
invokeai/frontend/web/src/features/queue/README.md (new file, 51 lines)
@@ -0,0 +1,51 @@
# Queue Enqueue Patterns

This directory contains the hooks and utilities that translate UI actions into queue batches. The flow is intentionally modular so adding a new enqueue type (e.g. a new generation mode) follows a predictable recipe.

## Key building blocks

- `hooks/useEnqueue*.ts` – Feature-specific hooks (generate, canvas, upscaling, video, workflows). Each hook wires local state to the shared enqueue utilities.
- `hooks/utils/graphBuilders.ts` – Maps base models (sdxl, flux, etc.) to their graph builder functions and normalizes synchronous vs. asynchronous builders.
- `hooks/utils/executeEnqueue.ts` – Orchestrates the enqueue lifecycle (a minimal hook sketch follows this list):
  1. dispatch the `enqueueRequested*` action
  2. build the graph/batch data
  3. call `queueApi.endpoints.enqueueBatch`
  4. run success/error callbacks
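A minimal sketch of a hook driving that lifecycle, assuming the shared helper described above; `useEnqueueMyFeature`, `enqueueRequestedMyFeature`, and `buildMyFeatureBatch` are placeholder names for illustration, not part of this change:

```ts
import { createAction } from '@reduxjs/toolkit';
import { useAppStore } from 'app/store/storeHooks';
import { useCallback } from 'react';

import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';

// Placeholder: a real hook would build a graph and call prepareLinearUIBatch here.
declare function buildMyFeatureBatch(state: unknown, prepend: boolean): EnqueueBatchArg | null;

export const enqueueRequestedMyFeature = createAction('app/enqueueRequestedMyFeature');

export const useEnqueueMyFeature = () => {
  const store = useAppStore();

  return useCallback(
    (prepend: boolean) =>
      executeEnqueue({
        store,
        options: { prepend },
        requestedAction: enqueueRequestedMyFeature, // 1. dispatched first, for devtools visibility
        build: async ({ store: innerStore, options }) => {
          // 2. compose the graph/batch data; returning null skips the enqueue entirely
          const batchConfig = buildMyFeatureBatch(innerStore.getState(), options.prepend);
          return batchConfig ? { batchConfig } : null;
        },
        // 3. the batch handed to queueApi.endpoints.enqueueBatch
        prepareBatch: ({ buildResult }) => buildResult.batchConfig,
        // 4. success callback (errors are logged centrally; supply onError for custom handling)
        onSuccess: () => {},
      }),
    [store]
  );
};
```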
## Adding a new enqueue type

1. **Implement the graph builder (if needed).**
   - Create the graph construction logic in `features/nodes/util/graph/generation/...` so it returns a `GraphBuilderReturn`.
   - If the builder reuses existing primitives, consider wiring it into `graphBuilders.ts` by extending the `graphBuilderMap` (see the sketch after this list).

2. **Create the enqueue hook.**
   - Add `useEnqueue<Feature>.ts` mirroring the existing hooks. Import `executeEnqueue` and supply feature-specific `build`, `prepareBatch`, and `onSuccess` callbacks.
   - If the feature depends on a new base model, add it to `graphBuilders.ts`.

3. **Register the tab in `useInvoke`.**
   - `useInvoke.ts` looks up handlers based on the active tab. Import your new hook and call it inside the `switch` (or future registry) so the UI can enqueue from the feature.

4. **Add a Redux action (optional).**
   - Most enqueue hooks dispatch an `enqueueRequested*` action for devtools visibility. Create one with `createAction` if you want similar tracing.

5. **Cover with tests.**
   - Unit-test feature-specific behavior (graph selection, batch tweaks). The shared helpers already have coverage in `hooks/utils/`.
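If the new type brings a new base model, registering it is a one-line addition to the map in `hooks/utils/graphBuilders.ts`; in this sketch, `my-base` and `buildMyBaseGraph` are placeholder names:

```ts
// In hooks/utils/graphBuilders.ts – add the hypothetical builder to the existing map.
import { buildMyBaseGraph } from 'features/nodes/util/graph/generation/buildMyBaseGraph'; // placeholder path

const graphBuilderMap: Record<string, GraphBuilderFn> = {
  // ...existing entries (sdxl, 'sd-1', flux, ...)
  'my-base': buildMyBaseGraph, // buildGraphForBase('my-base', arg) now resolves and awaits it
};
```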
## Tips

- Keep `build` lean: fetch state, compose graph/batch data, and return `null` when prerequisites are missing. The shared helper will skip enqueueing and your `onError` will handle logging.
- Use the shared `prepareLinearUIBatch` for single-graph UI workflows. For advanced cases (multi-run batches, workflow validation runs), supply a custom `prepareBatch` function.
- Prefer updating `graphBuilders.ts` when adding a new base model so every image-based enqueue automatically benefits.

With this structure, the main task when introducing a new enqueue type is describing how to build its graph and how to massage the batch payload; everything else (dispatching, API calls, history updates) is handled by the utilities.
@@ -1,154 +1,114 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';

import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
import { buildGraphForBase } from './utils/graphBuilders';

const log = logger('generation');
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');

const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
const { dispatch, getState } = store;

dispatch(enqueueRequestedCanvas());

const state = getState();

const destination = selectCanvasDestination(state);

const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}

const base = model.base;

const buildGraphResult = await withResultAsync(async () => {
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };

switch (base) {
case 'sdxl':
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return buildFluxKontextGraph(graphBuilderArg);
case 'gemini-2.5':
return buildGemini2_5Graph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});

if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}

const { g, seed, positivePrompt } = buildGraphResult.value;

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})
);

if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}

const batchConfig = prepareBatchResult.value;

const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);

const enqueueResult = await req.unwrap();

// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));

return { batchConfig, enqueueResult };
type CanvasBuildResult = {
batchConfig: EnqueueBatchArg;
};

export const useEnqueueCanvas = () => {
const store = useAppStore();
const canvasManager = useCanvasManagerSafe();

const enqueue = useCallback(
(prepend: boolean) => {
if (!canvasManager) {
log.error('Canvas manager is not available');
return;
return null;
}
return enqueueCanvas(store, canvasManager, prepend);

return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedCanvas,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();

const destination = selectCanvasDestination(state);
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}

const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };

const buildGraphResult = await withResultAsync(
async () => await buildGraphForBase(model.base, graphBuilderArg)
);

if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}

const { g, seed, positivePrompt } = buildGraphResult.value;

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
destination,
})
);

if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}

return {
batchConfig: prepareBatchResult.value,
} satisfies CanvasBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[canvasManager, store]
);

return enqueue;
};
@@ -1,143 +1,103 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';

import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';
import { buildGraphForBase } from './utils/graphBuilders';

const log = logger('generation');

export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate');

const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;

dispatch(enqueueRequestedGenerate());

const state = getState();

const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;

const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };

switch (base) {
case 'sdxl':
return await buildSDXLGraph(graphBuilderArg);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(graphBuilderArg);
case `sd-3`:
return await buildSD3Graph(graphBuilderArg);
case `flux`:
return await buildFLUXGraph(graphBuilderArg);
case 'cogview4':
return await buildCogView4Graph(graphBuilderArg);
case 'imagen3':
return buildImagen3Graph(graphBuilderArg);
case 'imagen4':
return buildImagen4Graph(graphBuilderArg);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(graphBuilderArg);
case 'flux-kontext':
return buildFluxKontextGraph(graphBuilderArg);
case 'gemini-2.5':
return buildGemini2_5Graph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});

if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}

const { g, seed, positivePrompt } = buildGraphResult.value;

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'generate',
destination: 'generate',
})
);

if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}

const batchConfig = prepareBatchResult.value;

const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);

const enqueueResult = await req.unwrap();

// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));

return { batchConfig, enqueueResult };
type GenerateBuildResult = {
batchConfig: EnqueueBatchArg;
};

export const useEnqueueGenerate = () => {
const store = useAppStore();

const enqueue = useCallback(
(prepend: boolean) => {
return enqueueGenerate(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedGenerate,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}

const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
const buildGraphResult = await withResultAsync(
async () => await buildGraphForBase(model.base, graphBuilderArg)
);

if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}

const { g, seed, positivePrompt } = buildGraphResult.value;

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'generate',
destination: 'generate',
})
);

if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}

return {
batchConfig: prepareBatchResult.value,
} satisfies GenerateBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);

return enqueue;
};
@@ -1,62 +1,64 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { positivePromptAddedToHistory, selectPositivePrompt } from 'features/controlLayers/store/paramsSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';

import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';

export const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');

const log = logger('generation');

const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;

dispatch(enqueueRequestedUpscaling());

const state = getState();

const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;

const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);

const batchConfig = prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'upscaling',
destination: 'gallery',
});

const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();

// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));

return { batchConfig, enqueueResult };
type UpscaleBuildResult = {
batchConfig: EnqueueBatchArg;
};

export const useEnqueueUpscaling = () => {
const store = useAppStore();

const enqueue = useCallback(
(prepend: boolean) => {
return enqueueUpscaling(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedUpscaling,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return null;
}

const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);

const batchConfig = prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'upscaling',
destination: 'gallery',
});

return { batchConfig } satisfies UpscaleBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);

return enqueue;
};
@@ -1,7 +1,6 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
@@ -14,114 +13,107 @@ import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
import { AssertionError } from 'tsafe';

import type { EnqueueBatchArg } from './utils/executeEnqueue';
import { executeEnqueue } from './utils/executeEnqueue';

const log = logger('generation');
export const enqueueRequestedVideos = createAction('app/enqueueRequestedVideos');

const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
type VideoBuildResult = {
batchConfig: EnqueueBatchArg;
};

dispatch(enqueueRequestedVideos());

const state = getState();

const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return;
const getVideoGraphBuilder = (base: string) => {
switch (base) {
case 'veo3':
return buildVeo3VideoGraph;
case 'runway':
return buildRunwayVideoGraph;
default:
return null;
}
const base = model.base;

const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };

switch (base) {
case 'veo3':
return await buildVeo3VideoGraph(graphBuilderArg);
case 'runway':
return await buildRunwayVideoGraph(graphBuilderArg);
default:
assert(false, `No graph builders for base ${base}`);
}
});

if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}

const { g, positivePrompt, seed } = buildGraphResult.value;

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);

if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}

const batchConfig = prepareBatchResult.value;

// const batchConfig = {
// prepend,
// batch: {
// graph: g.getGraph(),
// runs: 1,
// origin,
// destination,
// },
// };

const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);

const enqueueResult = await req.unwrap();

// Push to prompt history on successful enqueue
dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));

return { batchConfig, enqueueResult };
};

export const useEnqueueVideo = () => {
const store = useAppStore();

const enqueue = useCallback(
(prepend: boolean) => {
return enqueueVideo(store, prepend);
return executeEnqueue({
store,
options: { prepend },
requestedAction: enqueueRequestedVideos,
log,
build: async ({ store: innerStore, options }) => {
const state = innerStore.getState();

const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return null;
}

const builder = getVideoGraphBuilder(model.base);
if (!builder) {
log.error({ base: model.base }, 'No graph builders for base');
return null;
}

const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
const buildGraphResult = await withResultAsync(async () => await builder(graphBuilderArg));

if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({ status, title, description });
return null;
}

const { g, positivePrompt, seed } = buildGraphResult.value;

const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base: model.base,
prepend: options.prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);

if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return null;
}

return {
batchConfig: prepareBatchResult.value,
} satisfies VideoBuildResult;
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
onSuccess: ({ store: innerStore }) => {
const state = innerStore.getState();
innerStore.dispatch(positivePromptAddedToHistory(selectPositivePrompt(state)));
},
});
},
[store]
);

return enqueue;
};
@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppDispatch, AppStore, RootState } from 'app/store/store';
import type { AppDispatch, RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { groupBy } from 'es-toolkit/compat';
import {
@@ -15,10 +15,11 @@ import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, EnqueueBatchArg, S } from 'services/api/types';
import { assert } from 'tsafe';

import { executeEnqueue } from './utils/executeEnqueue';

export const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');

const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise<Batch['data']> => {
@@ -119,60 +120,50 @@ const getValidationRunData = (state: RootState, templates: Templates): S['Valida
};
};

const enqueueWorkflows = async (
store: AppStore,
templates: Templates,
prepend: boolean,
isApiValidationRun: boolean
) => {
const { dispatch, getState } = store;

dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);

if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}

const runs = state.params.iterations;
const data = await getBatchDataForWorkflowGeneration(state, dispatch);

const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend,
};

if (isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);

// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;
}

const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);

const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};

export const useEnqueueWorkflows = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean, isApiValidationRun: boolean) => {
return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun);
return executeEnqueue({
store,
options: { prepend, isApiValidationRun },
requestedAction: enqueueRequestedWorkflows,
build: async ({ store: innerStore, options }) => {
const { dispatch, getState } = innerStore;
const state = getState();
const nodesState = selectNodesSlice(state);
const templates = $templates.get();
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);

if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}

const data = await getBatchDataForWorkflowGeneration(state, dispatch);

const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: options.prepend,
};

if (options.isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);
batchConfig.batch.runs = 1;
}

return { batchConfig } satisfies { batchConfig: EnqueueBatchArg };
},
prepareBatch: ({ buildResult }) => buildResult.batchConfig,
});
},
[store]
);
@@ -0,0 +1,107 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStore, RootState } from 'app/store/store';
import type { EnqueueBatchArg, EnqueueBatchResponse } from './executeEnqueue';
import { executeEnqueue } from './executeEnqueue';
import { describe, expect, it, vi } from 'vitest';

const createTestStore = () => {
  const state = {} as RootState;
  const dispatch = vi.fn<(action: unknown) => unknown>((action) => {
    if (typeof action === 'object' && action !== null && 'type' in action) {
      return undefined;
    }
    const unwrap = vi.fn<() => Promise<EnqueueBatchResponse>>().mockResolvedValue({
      batch_id: 'batch-1',
      item_ids: ['item-1'],
    } as EnqueueBatchResponse);
    return { unwrap };
  });
  const getState = vi.fn(() => state);
  return { dispatch, getState } as unknown as AppStore;
};

const createBatchArg = (prepend: boolean): EnqueueBatchArg => ({
  prepend,
  batch: {
    graph: {} as EnqueueBatchArg['batch']['graph'],
    runs: 1,
    data: [],
    origin: 'test',
    destination: 'test',
  },
});

describe('executeEnqueue', () => {
  it('dispatches enqueue flow and invokes success callback', async () => {
    const store = createTestStore();
    const requestedAction = createAction('test/enqueue');
    const options = { prepend: false } as const;
    const batchConfig = createBatchArg(options.prepend);
    const onSuccess = vi.fn();
    const build = vi.fn(async () => ({ batchConfig }));
    const prepareBatch = vi.fn(() => batchConfig);

    const result = await executeEnqueue({
      store,
      options,
      requestedAction,
      build,
      prepareBatch,
      onSuccess,
      log: { error: vi.fn() },
    });

    expect(store.dispatch).toHaveBeenCalledWith(requestedAction());
    expect(build).toHaveBeenCalledWith({ store, options });
    expect(prepareBatch).toHaveBeenCalledWith({ store, options, buildResult: { batchConfig } });
    expect(onSuccess).toHaveBeenCalled();
    expect(result?.batchConfig).toBe(batchConfig);
  });

  it('stops when build returns null', async () => {
    const store = createTestStore();
    const requestedAction = createAction('test/enqueue');
    const options = { prepend: true } as const;
    const build = vi.fn(async () => null);
    const prepareBatch = vi.fn();

    const result = await executeEnqueue({
      store,
      options,
      requestedAction,
      build,
      prepareBatch,
      log: { error: vi.fn() },
    });

    expect(result).toBeNull();
    expect(build).toHaveBeenCalled();
    expect(prepareBatch).not.toHaveBeenCalled();
  });

  it('invokes onError when build throws', async () => {
    const store = createTestStore();
    const requestedAction = createAction('test/enqueue');
    const options = { prepend: false } as const;
    const error = new Error('boom');
    const build = vi.fn(async () => {
      throw error;
    });
    const onError = vi.fn();
    const logError = vi.fn();

    const result = await executeEnqueue({
      store,
      options,
      requestedAction,
      build,
      prepareBatch: vi.fn(),
      onError,
      log: { error: logError },
    });

    expect(result).toBeNull();
    expect(onError).toHaveBeenCalledWith({ store, options, error });
    expect(logError).toHaveBeenCalled();
  });
});
@@ -0,0 +1,70 @@
import type { ActionCreatorWithoutPayload } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { paths } from 'services/api/schema';

export type EnqueueBatchArg =
  paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'];
export type EnqueueBatchResponse =
  paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['responses']['201']['content']['application/json'];

export type EnqueueOptionsBase = { prepend: boolean };

interface ExecuteEnqueueConfig<TOptions extends EnqueueOptionsBase, TBuildResult> {
  store: AppStore;
  options: TOptions;
  requestedAction: ActionCreatorWithoutPayload<string>;
  build: (context: { store: AppStore; options: TOptions }) => Promise<TBuildResult | null>;
  prepareBatch: (context: { store: AppStore; options: TOptions; buildResult: TBuildResult }) => EnqueueBatchArg;
  onSuccess?: (context: {
    store: AppStore;
    options: TOptions;
    buildResult: TBuildResult;
    batch: EnqueueBatchArg;
    response: EnqueueBatchResponse;
  }) => void;
  onError?: (context: { store: AppStore; options: TOptions; error: unknown }) => void;
  log?: ReturnType<typeof logger>;
}

export const executeEnqueue = async <TOptions extends EnqueueOptionsBase, TBuildResult>({
  store,
  options,
  requestedAction,
  build,
  prepareBatch,
  onSuccess,
  onError,
  log = logger('enqueue'),
}: ExecuteEnqueueConfig<TOptions, TBuildResult>) => {
  const { dispatch } = store;
  dispatch(requestedAction());

  try {
    const buildResult = await build({ store, options });
    if (!buildResult) {
      return null;
    }

    const batchConfig = prepareBatch({ store, options, buildResult });

    const req = dispatch(
      queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
        ...enqueueMutationFixedCacheKeyOptions,
        track: false,
      })
    );

    const enqueueResult = await req.unwrap();

    onSuccess?.({ store, options, buildResult, batch: batchConfig, response: enqueueResult });

    return { batchConfig, enqueueResult };
  } catch (error) {
    log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
    onError?.({ store, options, error });
    return null;
  }
};
@@ -0,0 +1,81 @@
import { describe, expect, it, vi } from 'vitest';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
import type { Invocation } from 'services/api/types';
import type { RootState } from 'app/store/store';

const mocks = vi.hoisted(() => {
  const mockGraph: Graph = {} as Graph;
  const mockPrompt = { id: 'prompt-node' } as Invocation<'string'>;
  const asyncReturnValue = { g: mockGraph, positivePrompt: mockPrompt };
  const syncReturnValue = { g: mockGraph, positivePrompt: mockPrompt };

  return {
    asyncReturnValue,
    syncReturnValue,
    buildSDXLGraphMock: vi.fn().mockResolvedValue(asyncReturnValue),
    buildImagen3GraphMock: vi.fn().mockReturnValue(syncReturnValue),
    createDefaultBuilder: () => vi.fn().mockResolvedValue(asyncReturnValue),
  };
});

vi.mock('features/nodes/util/graph/generation/buildSDXLGraph', () => ({
  buildSDXLGraph: mocks.buildSDXLGraphMock,
}));
vi.mock('features/nodes/util/graph/generation/buildSD1Graph', () => ({
  buildSD1Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildSD3Graph', () => ({
  buildSD3Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildFLUXGraph', () => ({
  buildFLUXGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildFluxKontextGraph', () => ({
  buildFluxKontextGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildCogView4Graph', () => ({
  buildCogView4Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildImagen3Graph', () => ({
  buildImagen3Graph: mocks.buildImagen3GraphMock,
}));
vi.mock('features/nodes/util/graph/generation/buildImagen4Graph', () => ({
  buildImagen4Graph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildChatGPT4oGraph', () => ({
  buildChatGPT4oGraph: mocks.createDefaultBuilder(),
}));
vi.mock('features/nodes/util/graph/generation/buildGemini2_5Graph', () => ({
  buildGemini2_5Graph: mocks.createDefaultBuilder(),
}));

import { buildGraphForBase } from './graphBuilders';

describe('buildGraphForBase', () => {
  const baseArg: GraphBuilderArg = {
    generationMode: 'txt2img',
    state: {} as RootState,
    manager: null,
  };

  it('awaits asynchronous graph builders', async () => {
    const result = await buildGraphForBase('sdxl', baseArg);

    expect(result).toBe(mocks.asyncReturnValue);
    expect(mocks.buildSDXLGraphMock).toHaveBeenCalledWith(baseArg);
  });

  it('supports synchronous graph builders', async () => {
    const result = await buildGraphForBase('imagen3', baseArg);

    expect(result).toBe(mocks.syncReturnValue);
    expect(mocks.buildImagen3GraphMock).toHaveBeenCalledWith(baseArg);
  });

  it('throws for unknown base models', async () => {
    await expect(buildGraphForBase('unknown-model', baseArg)).rejects.toThrow(
      'No graph builders for base unknown-model'
    );
  });
});
@@ -0,0 +1,34 @@
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildGemini2_5Graph } from 'features/nodes/util/graph/generation/buildGemini2_5Graph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { assert } from 'tsafe';

type GraphBuilderFn = (arg: GraphBuilderArg) => GraphBuilderReturn | Promise<GraphBuilderReturn>;

const graphBuilderMap: Record<string, GraphBuilderFn> = {
  sdxl: buildSDXLGraph,
  'sd-1': buildSD1Graph,
  'sd-2': buildSD1Graph,
  'sd-3': buildSD3Graph,
  flux: buildFLUXGraph,
  'flux-kontext': buildFluxKontextGraph,
  cogview4: buildCogView4Graph,
  imagen3: buildImagen3Graph,
  imagen4: buildImagen4Graph,
  'chatgpt-4o': buildChatGPT4oGraph,
  'gemini-2.5': buildGemini2_5Graph,
};

export const buildGraphForBase = async (base: string, arg: GraphBuilderArg) => {
  const builder = graphBuilderMap[base];
  assert(builder, `No graph builders for base ${base}`);
  return await builder(arg);
};