Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-15 06:18:03 -05:00
fix(ui): iterations works for video models
committed by Mary Hipp Rogers
parent b9e32e59a2
commit dbb9032648
@@ -4,22 +4,22 @@ import { range } from 'es-toolkit/compat';
 import type { SeedBehaviour } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
 import type { ModelIdentifierField } from 'features/nodes/types/common';
 import type { Graph } from 'features/nodes/util/graph/generation/Graph';
-import { API_BASE_MODELS } from 'features/parameters/types/constants';
+import { API_BASE_MODELS, VIDEO_BASE_MODELS } from 'features/parameters/types/constants';
 import type { components } from 'services/api/schema';
-import type { Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
+import type { AnyModelConfig, BaseModelType, Batch, EnqueueBatchArg, Invocation } from 'services/api/types';
 import { assert } from 'tsafe';
 
 const getExtendedPrompts = (arg: {
   seedBehaviour: SeedBehaviour;
   iterations: number;
   prompts: string[];
-  model: ModelIdentifierField;
+  base: BaseModelType;
 }): string[] => {
-  const { seedBehaviour, iterations, prompts, model } = arg;
+  const { seedBehaviour, iterations, prompts, base } = arg;
   // Normally, the seed behaviour implicitly determines the batch size. But when we use models without seeds (like
   // ChatGPT 4o) in conjunction with the per-prompt seed behaviour, we lose out on that implicit batch size. To rectify
   // this, we need to create a batch of the right size by repeating the prompts.
-  if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(model.base)) {
+  if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(base) || VIDEO_BASE_MODELS.includes(base)) {
     return range(iterations).flatMap(() => prompts);
   }
   return prompts;
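Note on the hunk above: video bases join the API bases as "seedless" from the queue's perspective, so the iterations count has to be realized by repeating the prompt list rather than by batching over seeds. Below is a self-contained TypeScript sketch of the same repetition logic; the base lists are illustrative placeholders, not the real constants from features/parameters/types/constants.

// Sketch: how iterations expand prompts for seedless bases. The base lists
// are placeholders for illustration only.
type SeedBehaviour = 'PER_PROMPT' | 'PER_ITERATION';
const API_BASE_MODELS: string[] = ['chatgpt-4o']; // placeholder
const VIDEO_BASE_MODELS: string[] = ['veo3', 'runway']; // placeholder

const extendPrompts = (seedBehaviour: SeedBehaviour, iterations: number, prompts: string[], base: string): string[] => {
  if (seedBehaviour === 'PER_PROMPT' || API_BASE_MODELS.includes(base) || VIDEO_BASE_MODELS.includes(base)) {
    // One copy of every prompt per iteration, e.g. 3 iterations x 2 prompts = 6 runs.
    return Array.from({ length: iterations }, () => prompts).flat();
  }
  return prompts;
};

console.log(extendPrompts('PER_ITERATION', 3, ['a cat', 'a dog'], 'veo3'));
// ['a cat', 'a dog', 'a cat', 'a dog', 'a cat', 'a dog']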
@@ -29,17 +29,16 @@ export const prepareLinearUIBatch = (arg: {
   state: RootState;
   g: Graph;
   prepend: boolean;
+  base: BaseModelType;
   positivePromptNode: Invocation<'string'>;
   seedNode?: Invocation<'integer'>;
   origin: string;
   destination: string;
 }): EnqueueBatchArg => {
-  const { state, g, prepend, positivePromptNode, seedNode, origin, destination } = arg;
-  const { iterations, model, shouldRandomizeSeed, seed } = state.params;
+  const { state, g, base, prepend, positivePromptNode, seedNode, origin, destination } = arg;
+  const { iterations, shouldRandomizeSeed, seed } = state.params;
   const { prompts, seedBehaviour } = state.dynamicPrompts;
 
-  assert(model, 'No model found in state when preparing batch');
-
   const data: Batch['data'] = [];
   const firstBatchDatumList: components['schemas']['BatchDatum'][] = [];
   const secondBatchDatumList: components['schemas']['BatchDatum'][] = [];

@@ -63,6 +62,7 @@ export const prepareLinearUIBatch = (arg: {
       start: shouldRandomizeSeed ? undefined : seed,
     });
 
+    console.log(seeds);
     secondBatchDatumList.push({
       node_path: seedNode.id,
       field_name: 'value',

@@ -71,7 +71,7 @@ export const prepareLinearUIBatch = (arg: {
     data.push(secondBatchDatumList);
   }
 
-  const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, model });
+  const extendedPrompts = getExtendedPrompts({ seedBehaviour, iterations, prompts, base });
 
   // zipped batch of prompts
   firstBatchDatumList.push({
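Note: the signature change from model to base matters because callers resolve the model from different state slices (state.params.model for canvas/generate/upscale, state.video.videoModel for video), so prepareLinearUIBatch can no longer look the model up itself. A rough sketch of the call-site difference, with an abbreviated store shape assumed for illustration:

// Rough sketch: why `base` is now an argument. Store shape abbreviated and
// hypothetical; only the slice names come from the diff.
type BaseModelType = string;
interface ModelLike { base: BaseModelType }
interface RootStateLike {
  params: { model: ModelLike | null };
  video: { videoModel: ModelLike | null };
}

const baseForImages = (state: RootStateLike): BaseModelType | undefined =>
  state.params.model?.base; // canvas / generate / upscale path

const baseForVideo = (state: RootStateLike): BaseModelType | undefined =>
  state.video.videoModel?.base; // video path

// Each caller guards for a missing model, then passes the resolved base:
// prepareLinearUIBatch({ state, g, base, prepend, ... })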
@@ -39,11 +39,15 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
 
   const destination = selectCanvasDestination(state);
 
-  const buildGraphResult = await withResultAsync(async () => {
-    const model = state.params.model;
-    assert(model, 'No model found in state');
-    const base = model.base;
+  const model = state.params.model;
+  if (!model) {
+    log.error('No model found in state');
+    return;
+  }
+
+  const base = model.base;
+
+  const buildGraphResult = await withResultAsync(async () => {
     const generationMode = await canvasManager.compositor.getGenerationMode();
     const graphBuilderArg: GraphBuilderArg = { generationMode, state, manager: canvasManager };

@@ -101,6 +105,7 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep
   prepareLinearUIBatch({
     state,
     g,
+    base,
     prepend,
     seedNode: seed,
     positivePromptNode: positivePrompt,
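Note: the guard change above swaps assert (which threw inside withResultAsync, surfacing a missing model as a failed graph build) for an early return that logs and skips the enqueue entirely. A distilled sketch of the pattern with simplified, assumed types; the check also narrows the type so model.base can be read safely afterwards:

// Distilled guard pattern used across the enqueue functions (simplified types).
// Before: assert(model, ...) inside withResultAsync meant a missing model
// became a rejected build result. After: check up front, log, and return
// without enqueueing anything.
interface Model { base: string }

const enqueueSketch = (model: Model | null, log: { error: (msg: string) => void }): string | undefined => {
  if (!model) {
    log.error('No model found in state');
    return; // nothing is enqueued
  }
  const base = model.base; // safe: model is narrowed to Model here
  return base;
};

console.log(enqueueSketch({ base: 'veo3' }, console)); // 'veo3'
console.log(enqueueSketch(null, console)); // logs, returns undefined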
@@ -37,11 +37,14 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
 
   const destination = 'generate';
 
-  const buildGraphResult = await withResultAsync(async () => {
-    const model = state.params.model;
-    assert(model, 'No model found in state');
-    const base = model.base;
+  const model = state.params.model;
+  if (!model) {
+    log.error('No model found in state');
+    return;
+  }
+  const base = model.base;
+
+  const buildGraphResult = await withResultAsync(async () => {
     const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
 
     switch (base) {
@@ -1,4 +1,5 @@
 import { createAction } from '@reduxjs/toolkit';
+import { logger } from 'app/logging/logger';
 import type { AppStore } from 'app/store/store';
 import { useAppStore } from 'app/store/storeHooks';
 import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';

@@ -8,6 +9,8 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
 
 export const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');
 
+const log = logger('generation');
+
 const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
   const { dispatch, getState } = store;

@@ -15,11 +18,19 @@ const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
 
   const state = getState();
 
+  const model = state.params.model;
+  if (!model) {
+    log.error('No model found in state');
+    return;
+  }
+  const base = model.base;
+
   const { g, seed, positivePrompt } = await buildMultidiffusionUpscaleGraph(state);
 
   const batchConfig = prepareLinearUIBatch({
     state,
     g,
+    base,
     prepend,
     seedNode: seed,
     positivePromptNode: positivePrompt,
@@ -4,7 +4,8 @@ import { logger } from 'app/logging/logger';
 import type { AppStore } from 'app/store/store';
 import { useAppStore } from 'app/store/storeHooks';
 import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
-import { withResultAsync } from 'common/util/result';
+import { withResult, withResultAsync } from 'common/util/result';
+import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
 import { buildRunwayVideoGraph } from 'features/nodes/util/graph/generation/buildRunwayVideoGraph';
 import { buildVeo3VideoGraph } from 'features/nodes/util/graph/generation/buildVeo3VideoGraph';
 import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';

@@ -17,21 +18,23 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
 import { assert, AssertionError } from 'tsafe';
 
 const log = logger('generation');
-export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
+export const enqueueRequestedVideos = createAction('app/enqueueRequestedVideos');
 
 const enqueueVideo = async (store: AppStore, prepend: boolean) => {
   const { dispatch, getState } = store;
 
-  dispatch(enqueueRequestedCanvas());
+  dispatch(enqueueRequestedVideos());
 
   const state = getState();
 
   const destination = selectCanvasDestination(state);
+  const model = state.video.videoModel;
+  if (!model) {
+    log.error('No model found in state');
+    return;
+  }
+  const base = model.base;
 
   const buildGraphResult = await withResultAsync(async () => {
-    const model = state.video.videoModel;
-    assert(model, 'No model found in state');
-    const base = model.base;
     const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
 
     switch (base) {

@@ -65,36 +68,37 @@ const enqueueVideo = async (store: AppStore, prepend: boolean) => {
       return;
     }
 
-  const { g } = buildGraphResult.value;
+  const { g, positivePrompt, seed } = buildGraphResult.value;
 
-  // const prepareBatchResult = withResult(() =>
-  //   prepareLinearUIBatch({
-  //     state,
-  //     g,
-  //     prepend,
-  //     seedNode: seed,
-  //     positivePromptNode: positivePrompt,
-  //     origin: 'canvas',
+  const prepareBatchResult = withResult(() =>
+    prepareLinearUIBatch({
+      state,
+      g,
+      base,
+      prepend,
+      seedNode: seed,
+      positivePromptNode: positivePrompt,
+      origin: 'videos',
+      destination: 'gallery',
+    })
+  );
 
+  if (prepareBatchResult.isErr()) {
+    log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
+    return;
+  }
 
+  const batchConfig = prepareBatchResult.value;
 
-  // const batchConfig = {
-  //   prepend,
-  //   batch: {
-  //     graph: g.getGraph(),
-  //     runs: 1,
-  //     origin,
-  //     destination,
-  //   })
-  // );
-
-  // if (prepareBatchResult.isErr()) {
-  //   log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
-  //   return;
-  // }
-
-  // const batchConfig = prepareBatchResult.value;
-
-  const batchConfig = {
-    prepend,
-    batch: {
-      graph: g.getGraph(),
-      runs: 1,
-      origin,
-      destination,
-    },
-  };
-  // },
-  // };
 
   const req = dispatch(
     queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
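Note: this hunk is the core of the fix. The video path previously enqueued a hand-built batch with runs: 1 and no batch data, so the iterations setting never reached the queue; it now goes through prepareLinearUIBatch like the other tabs, with origin/destination set to 'videos'/'gallery'. A hedged sketch of the before/after batch shape, using simplified stand-in types rather than the real EnqueueBatchArg:

// Simplified sketch of the before/after batch shape for video enqueues.
// Types are abbreviated stand-ins, not the real schema types.
interface BatchDatumLike { node_path: string; field_name: string; items: unknown[] }
interface BatchLike {
  graph: unknown;
  runs: number;
  data?: BatchDatumLike[][]; // per-run field values prepared by prepareLinearUIBatch
  origin: string;
  destination: string;
}

// Before: a fixed single run, so `iterations` was silently ignored.
const before: BatchLike = { graph: {}, runs: 1, origin: 'videos', destination: 'gallery' };

// After: prepareLinearUIBatch emits `data` whose collections carry the
// extended prompts (and seeds, when the graph has a seed node), so one
// enqueue expands to one queue item per repeated prompt.
const after: BatchLike = {
  graph: {},
  runs: 1,
  data: [[{ node_path: 'positive_prompt', field_name: 'value', items: ['a cat', 'a cat', 'a cat'] }]],
  origin: 'videos',
  destination: 'gallery',
};
console.log(before.data?.length ?? 0, after.data?.length); // 0 1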
@@ -261,7 +261,11 @@ const getReasonsWhyCannotEnqueueVideoTab = (arg: {
     reasons.push({ content: i18n.t('parameters.invoke.promptExpansionResultPending') });
   }
 
-  if (!video.startingFrameImage?.image_name) {
+  if (!video.videoModel) {
+    reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
+  }
+
+  if (video.videoModel?.base === 'runway' && !video.startingFrameImage?.image_name) {
     reasons.push({ content: i18n.t('parameters.invoke.noStartingFrameImage') });
   }
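Note: validation now reports a missing model as its own reason, and the starting-frame requirement is scoped to Runway models only (the image-to-video path in the graph builders above). A compact sketch of the rule, with a hypothetical VideoStateLike shape; the reason strings stand in for the i18n keys:

// Compact sketch of the new validation rule; VideoStateLike is hypothetical.
interface VideoStateLike {
  videoModel: { base: string } | null;
  startingFrameImage: { image_name: string } | null;
}

const getReasons = (video: VideoStateLike): string[] => {
  const reasons: string[] = [];
  if (!video.videoModel) {
    reasons.push('noModelSelected');
  }
  // Only the Runway flow requires a starting frame.
  if (video.videoModel?.base === 'runway' && !video.startingFrameImage?.image_name) {
    reasons.push('noStartingFrameImage');
  }
  return reasons;
};

console.log(getReasons({ videoModel: { base: 'veo3' }, startingFrameImage: null })); // []
console.log(getReasons({ videoModel: { base: 'runway' }, startingFrameImage: null })); // ['noStartingFrameImage']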