mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): use correct model config object in video graph builders
This commit is contained in:
committed by
Mary Hipp Rogers
parent
5cabc37a87
commit
9fcba3b876
@@ -6,7 +6,11 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { selectStartingFrameImage, selectVideoSlice } from 'features/parameters/store/videoSlice';
|
||||
import {
|
||||
selectStartingFrameImage,
|
||||
selectVideoModelConfig,
|
||||
selectVideoSlice,
|
||||
} from 'features/parameters/store/videoSlice';
|
||||
import { t } from 'i18next';
|
||||
import type { VideoApiModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -23,6 +27,10 @@ export const buildRunwayVideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn
|
||||
throw new UnsupportedGenerationModeError(t('toast.runwayIncompatibleGenerationMode'));
|
||||
}
|
||||
|
||||
const model = selectVideoModelConfig(state);
|
||||
assert(model, 'No model selected');
|
||||
assert(model.base === 'runway', 'Selected model is not a Runway model');
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
const videoParams = selectVideoSlice(state);
|
||||
const prompts = selectPresetModifiedPrompts(state);
|
||||
@@ -34,9 +42,7 @@ export const buildRunwayVideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn
|
||||
const firstFrameImageField = zImageField.parse(startingFrameImage);
|
||||
|
||||
const { seed, shouldRandomizeSeed } = params;
|
||||
const { videoModel, videoDuration, videoAspectRatio, videoResolution } = videoParams;
|
||||
|
||||
assert(videoModel, 'Runway video requires a model');
|
||||
const { videoDuration, videoAspectRatio, videoResolution } = videoParams;
|
||||
|
||||
const finalSeed = shouldRandomizeSeed ? undefined : seed;
|
||||
|
||||
@@ -64,9 +70,8 @@ export const buildRunwayVideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn
|
||||
|
||||
// Set up metadata
|
||||
g.upsertMetadata({
|
||||
model: Graph.getModelMetadataField(videoModel as VideoApiModelConfig),
|
||||
model: Graph.getModelMetadataField(model),
|
||||
positive_prompt: prompts.positive,
|
||||
negative_prompt: prompts.negative || '',
|
||||
duration: videoDuration,
|
||||
aspect_ratio: videoAspectRatio,
|
||||
resolution: videoResolution,
|
||||
|
||||
@@ -6,7 +6,11 @@ import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { selectStartingFrameImage, selectVideoSlice } from 'features/parameters/store/videoSlice';
|
||||
import {
|
||||
selectStartingFrameImage,
|
||||
selectVideoModelConfig,
|
||||
selectVideoSlice,
|
||||
} from 'features/parameters/store/videoSlice';
|
||||
import { t } from 'i18next';
|
||||
import type { VideoApiModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -23,17 +27,19 @@ export const buildVeo3VideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn =>
|
||||
throw new UnsupportedGenerationModeError(t('toast.veo3IncompatibleGenerationMode'));
|
||||
}
|
||||
|
||||
const model = selectVideoModelConfig(state);
|
||||
assert(model, 'No model selected');
|
||||
assert(model.base === 'runway', 'Selected model is not a Veo3 model');
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
const videoParams = selectVideoSlice(state);
|
||||
const prompts = selectPresetModifiedPrompts(state);
|
||||
assert(prompts.positive.length > 0, 'Veo3 video requires positive prompt to have at least one character');
|
||||
|
||||
const { seed, shouldRandomizeSeed } = params;
|
||||
const { videoModel, videoResolution, videoDuration, videoAspectRatio } = videoParams;
|
||||
const { videoResolution, videoDuration, videoAspectRatio } = videoParams;
|
||||
const finalSeed = shouldRandomizeSeed ? undefined : seed;
|
||||
|
||||
assert(videoModel, 'Veo3 video requires a model');
|
||||
|
||||
const g = new Graph(getPrefixedId('veo3_video_graph'));
|
||||
|
||||
const positivePrompt = g.addNode({
|
||||
@@ -47,7 +53,7 @@ export const buildVeo3VideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn =>
|
||||
id: getPrefixedId('google_veo_3_generate_video'),
|
||||
// @ts-expect-error: This node is not available in the OSS application
|
||||
type: 'google_veo_3_generate_video',
|
||||
model: videoModel,
|
||||
model: model,
|
||||
aspect_ratio: '16:9',
|
||||
resolution: videoResolution,
|
||||
seed: finalSeed,
|
||||
@@ -66,9 +72,8 @@ export const buildVeo3VideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn =>
|
||||
|
||||
// Set up metadata
|
||||
g.upsertMetadata({
|
||||
model: Graph.getModelMetadataField(videoModel as VideoApiModelConfig),
|
||||
model: Graph.getModelMetadataField(model),
|
||||
positive_prompt: prompts.positive,
|
||||
negative_prompt: prompts.negative || '',
|
||||
duration: videoDuration,
|
||||
aspect_ratio: videoAspectRatio,
|
||||
resolution: videoResolution,
|
||||
|
||||
Reference in New Issue
Block a user