fix(ui): iterations works for video models

This commit is contained in:
psychedelicious
2025-08-22 19:01:39 +10:00
committed by Mary Hipp Rogers
parent b9e32e59a2
commit dbb9032648
6 changed files with 81 additions and 54 deletions

View File

@@ -4,22 +4,22 @@ import { range } from 'es-toolkit/compat';
import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { API_BASE_MODELS, VIDEO_BASE_MODELS } from 'features/parameters/types/constants';
import type { components } from 'services/api/schema';
import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import type { AnyModelConfig, BaseModelType, Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const getExtendedPrompts = (arg: {
seedBehaviour: SeedBehaviour;
iterations: number;
prompts: string[];
model: ModelIdentifierField;
base: BaseModelType;
}): string[] => {
const { seedBehaviour, iterations, prompts, model } = arg;
const { seedBehaviour, iterations, prompts, base } = arg;
// Normally, the seed behaviour implicity determines the batch size. But when we use models without seeds (like
// ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
// this, we need to create a batch of the right size by repeating the prompts.
if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(model.base)) {
if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(base) || VIDEO_BASE_MODELS.includes(base)) {
return range(iterations).flatMap(() => prompts);
}
return prompts;
@@ -29,17 +29,16 @@ export const prepareLinearUIBatch = (arg: {
state: RootState;
g: Graph;
prepend: boolean;
base: BaseModelType;
positivePromptNode: Invocation<'string'>;
seedNode?: Invocation<'integer'>;
origin: string;
destination: string;
}): EnqueueBatchArg => {
const { state, g, prepend, positivePromptNode, seedNode, origin, destination } = arg;
const { iterations, model, shouldRandomizeSeed, seed } = state.params;
const { state, g, base, prepend, positivePromptNode, seedNode, origin, destination } = arg;
const { iterations, shouldRandomizeSeed, seed } = state.params;
const { prompts, seedBehaviour } = state.dynamicPrompts;
assert(model, 'No model found in state when preparing batch');
const data: Batch['data'] = [];
const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];
@@ -63,6 +62,7 @@ export const prepareLinearUIBatch = (arg: {
start: shouldRandomizeSeed ? undefined : seed,
});
console.log(seeds);
secondBatchDatumList.push({
node_path: seedNode.id,
field_name: 'value',
@@ -71,7 +71,7 @@ export const prepareLinearUIBatch = (arg: {
data.push(secondBatchDatumList);
}
const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, model });
const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, base });
// zipped batch of prompts
firstBatchDatumList.push({

View File

@@ -39,11 +39,15 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
const destination = selectCanvasDestination(state);
const buildGraphResult = await withResultAsync(async () => {
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const generationMode = await canvasManager.compositor.getGenerationMode();
const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };
@@ -101,6 +105,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,

View File

@@ -37,11 +37,14 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
const destination = 'generate';
const buildGraphResult = await withResultAsync(async () => {
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
@@ -8,6 +9,8 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
export const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');
const log = logger('generation');
const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
@@ -15,11 +18,19 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const state = getState();
const model = state.params.model;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,

View File

@@ -4,7 +4,8 @@ import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResultAsync } from 'common/util/result';
import { withResult, withResultAsync } from 'common/util/result';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildRunwayVideoGraph } from 'features/nodes/util/graph/generation/buildRunwayVideoGraph';
import { buildVeo3VideoGraph } from 'features/nodes/util/graph/generation/buildVeo3VideoGraph';
import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';
@@ -17,21 +18,23 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
export const enqueueRequestedVideos = createAction('app/enqueueRequestedVideos');
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedCanvas());
dispatch(enqueueRequestedVideos());
const state = getState();
const destination = selectCanvasDestination(state);
const model = state.video.videoModel;
if (!model) {
log.error('No model found in state');
return;
}
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
const model = state.video.videoModel;
assert(model, 'No model found in state');
const base = model.base;
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
switch (base) {
@@ -65,36 +68,37 @@ const enqueueVideo = async (store: AppStore, prepend: boolean) => {
return;
}
const { g } = buildGraphResult.value;
const { g, positivePrompt, seed } = buildGraphResult.value;
// const prepareBatchResult = withResult(() =>
// prepareLinearUIBatch({
// state,
// g,
// prepend,
// seedNode: seed,
// positivePromptNode: positivePrompt,
// origin: 'canvas',
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
base,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'videos',
destination: 'gallery',
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
// const batchConfig = {
// prepend,
// batch: {
// graph: g.getGraph(),
// runs: 1,
// origin,
// destination,
// })
// );
// if (prepareBatchResult.isErr()) {
// log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
// return;
// }
// const batchConfig = prepareBatchResult.value;
const batchConfig = {
prepend,
batch: {
graph: g.getGraph(),
runs: 1,
origin,
destination,
},
};
// },
// };
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@@ -261,7 +261,11 @@ const getReasonsWhyCannotEnqueueVideoTab = (arg: {
reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
}
if (!video.startingFrameImage?.image_name) {
if (!video.videoModel) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (video.videoModel?.base === 'runway' && !video.startingFrameImage?.image_name) {
reasons.push({ content: i18n.t('parameters.invoke.noStartingFrameImage') });
}