combine nodes that generate and save videos

This commit is contained in:
Mary Hipp
2025-08-12 11:10:08 -04:00
committed by psychedelicious
parent d8fcf18b6c
commit 6afe995dbe
9 changed files with 19 additions and 310 deletions

View File

@@ -188,3 +188,11 @@ export const zImageOutput = z.object({
});
export type ImageOutput = z.infer<typeof zImageOutput>;
// #endregion
// #region VideoOutput
/**
 * Schema for a video output: a non-empty (after trimming) `video_id` string plus the
 * literal `type` tag `'video_output'`.
 */
export const zVideoOutput = z.object({
video_id: z.string().trim().min(1),
type: z.literal('video_output'),
});
/** Static type inferred from {@link zVideoOutput}. */
export type VideoOutput = z.infer<typeof zVideoOutput>;
// #endregion

View File

@@ -1,118 +0,0 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectVideoFirstFrameImage, selectVideoLastFrameImage } from 'features/parameters/store/videoSlice';
import { zImageField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { assert } from 'tsafe';
const log = logger('system');
// Fallback video settings. They live here as module constants until a dedicated
// video params slice exists.
const DEFAULT_VIDEO_DURATION = 5;
const DEFAULT_VIDEO_ASPECT_RATIO = '1280:768'; // Default landscape
const DEFAULT_ENHANCE_PROMPT = true;

/**
 * Resolves the video parameters used when building the graph.
 *
 * The `state` argument is accepted so that a future video parameters slice can be read
 * from it; at the moment it is not consulted and the module defaults are always returned.
 */
const getVideoParameters = (state: RootState) => ({
  duration: DEFAULT_VIDEO_DURATION,
  aspectRatio: DEFAULT_VIDEO_ASPECT_RATIO,
  enhancePrompt: DEFAULT_ENHANCE_PROMPT,
});
/**
 * Builds the graph for a Runway video generation request.
 *
 * The graph contains a string node holding the positive prompt, wired into the `prompt`
 * input of a `runway_generate_video` node. Optional first/last frame images from the video
 * slice are attached directly to that node, and the generation settings are recorded as
 * graph metadata with the runway node as the metadata receiver.
 *
 * @param arg Graph builder argument; `generationMode`, `state` and `manager` are read.
 * @returns The graph and the positive prompt node.
 * @throws UnsupportedGenerationModeError when `generationMode` is not `'txt2img'`.
 */
export const buildRunwayVideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn => {
  const { generationMode, state, manager } = arg;

  log.debug({ generationMode, manager: manager?.id }, 'Building Runway video graph');

  // 'txt2img' is the only accepted generation mode. Whether the request ends up being
  // text-to-video or image-to-video is decided below by the presence of frame images.
  if (generationMode !== 'txt2img') {
    throw new UnsupportedGenerationModeError(t('toast.runwayIncompatibleGenerationMode'));
  }

  const params = selectParamsSlice(state);
  const prompts = selectPresetModifiedPrompts(state);
  const videoFirstFrameImage = selectVideoFirstFrameImage(state);
  const videoLastFrameImage = selectVideoLastFrameImage(state);
  const videoParams = getVideoParameters(state);

  // When the seed is randomized, no seed is sent (undefined) instead of the stored one.
  const { seed, shouldRandomizeSeed } = params;
  const finalSeed = shouldRandomizeSeed ? undefined : seed;

  // Any provided frame image makes this an image-to-video request.
  const hasFrameImages = Boolean(videoFirstFrameImage || videoLastFrameImage);

  const g = new Graph(getPrefixedId('runway_video_graph'));

  const positivePrompt = g.addNode({
    id: getPrefixedId('positive_prompt'),
    type: 'string',
    value: prompts.positive,
  });

  // Create the runway video generation node.
  // NOTE(review): videoParams.enhancePrompt is written to metadata below but is not passed
  // to this node — confirm whether the node should receive it.
  const runwayVideoNode = g.addNode({
    id: getPrefixedId('runway_generate_video'),
    // @ts-expect-error: This node is not available in the OSS application
    type: 'runway_generate_video',
    duration: videoParams.duration,
    aspect_ratio: videoParams.aspectRatio,
    seed: finalSeed,
  });

  // @ts-expect-error: This node is not available in the OSS application
  g.addEdge(positivePrompt, 'value', runwayVideoNode, 'prompt');

  // Add first frame image if provided
  if (videoFirstFrameImage) {
    const firstFrameImageField = zImageField.parse(videoFirstFrameImage);
    // @ts-expect-error: This connection is specific to runway node
    runwayVideoNode.first_frame_image = firstFrameImageField;
  }

  // Add last frame image if provided
  if (videoLastFrameImage) {
    const lastFrameImageField = zImageField.parse(videoLastFrameImage);
    // @ts-expect-error: This connection is specific to runway node
    runwayVideoNode.last_frame_image = lastFrameImageField;
  }

  // Set up metadata
  g.upsertMetadata({
    positive_prompt: prompts.positive,
    negative_prompt: prompts.negative || '',
    video_duration: videoParams.duration,
    video_aspect_ratio: videoParams.aspectRatio,
    seed: finalSeed,
    enhance_prompt: videoParams.enhancePrompt,
    generation_type: hasFrameImages ? 'image-to-video' : 'text-to-video',
  });

  // Add video frame images to metadata if they exist
  if (hasFrameImages) {
    g.upsertMetadata(
      {
        first_frame_image: videoFirstFrameImage,
        last_frame_image: videoLastFrameImage,
      },
      'merge'
    );
  }

  g.setMetadataReceivingNode(runwayVideoNode);

  return {
    g,
    positivePrompt,
  };
};

View File

@@ -5,6 +5,7 @@ import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { VideoOutput, zVideoOutput } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
@@ -12,10 +13,7 @@ const zVideoState = z.object({
_version: z.literal(1),
videoFirstFrameImage: zImageWithDims.nullable(),
videoLastFrameImage: zImageWithDims.nullable(),
generatedVideo: z.object({
url: z.string(),
taskId: z.number(),
}).nullable(),
generatedVideo: zVideoOutput.nullable(),
});
export type VideoState = z.infer<typeof zVideoState>;
@@ -39,7 +37,7 @@ const slice = createSlice({
state.videoLastFrameImage = action.payload;
},
generatedVideoChanged: (state, action: PayloadAction<{ url: string, taskId: number } | null>) => {
generatedVideoChanged: (state, action: PayloadAction<VideoOutput | null>) => {
state.generatedVideo = action.payload;
},

View File

@@ -1,13 +1,10 @@
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { Box, Flex, Text } from '@invoke-ai/ui-library';
import { useFocusRegion } from 'common/hooks/focus';
import { memo, useCallback, useRef } from 'react';
import { memo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import ReactPlayer from 'react-player';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
import { PiCheckBold } from 'react-icons/pi';
import { useDispatch } from 'react-redux';
import { saveVideo } from 'features/video/saveVideo';
export const VideoPlayerPanel = memo(() => {
@@ -17,18 +14,6 @@ export const VideoPlayerPanel = memo(() => {
useFocusRegion('video', ref);
const { dispatch, getState } = useAppStore();
const handleSaveVideo = useCallback(() => {
console.log('generatedVideo', generatedVideo);
if (!generatedVideo?.taskId) {
return
}
console.log('saving video', generatedVideo.taskId);
saveVideo({ dispatch, getState, taskId: `${generatedVideo.taskId}` });
}, [dispatch, getState, generatedVideo]);
return (
<Flex ref={ref} w="full" h="full" flexDirection="column" gap={4}>
@@ -36,16 +21,15 @@ export const VideoPlayerPanel = memo(() => {
{generatedVideo &&
<>
<Box flex={0.75} position="relative" >
<ReactPlayer
{/* <ReactPlayer
src={generatedVideo.url}
width="75%"
height="75%"
controls={true}
style={{ position: 'absolute', top: '50%', left: '50%', transform: 'translate(-50%, -50%)', maxWidth: '900px' }}
/>
/> */}
</Box>
<Button leftIcon={<PiCheckBold />} colorScheme="invokeBlue" onClick={handleSaveVideo}>Keep</Button>
</>}
{!generatedVideo && <Text>No video generated</Text>}

View File

@@ -1,25 +0,0 @@
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import { assert } from 'tsafe';
/**
 * Builds a one-node graph that saves the currently generated runway video.
 *
 * The runway task ID is read from `state.video.generatedVideo`; the function asserts
 * that it is present.
 *
 * @returns The graph together with the ID of its single `save_runway_video` node.
 */
export const buildSaveVideoGraph = (arg: { state: RootState }): { graph: Graph; outputNodeId: string } => {
  const runwayTaskId = arg.state.video.generatedVideo?.taskId;
  assert(runwayTaskId, 'No task ID found in state');

  const saveGraph = new Graph(getPrefixedId('save-video-graph'));
  const saveNode = saveGraph.addNode({
    // @ts-expect-error: These nodes are not available in the OSS application
    type: 'save_runway_video',
    id: getPrefixedId('save_runway_video'),
    runway_task_id: runwayTaskId,
  });

  const outputNodeId = saveNode.id;
  return { graph: saveGraph, outputNodeId };
};

View File

@@ -1,41 +0,0 @@
import type { AppDispatch, AppGetState } from 'app/store/store';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph';
import { $socket } from 'services/events/stores';
import { assert } from 'tsafe';
import { buildSaveVideoGraph } from './graph';
import { saveVideoApi } from './state';
/**
 * Runs the save-video graph and records the outcome in `saveVideoApi`.
 *
 * Does nothing when no socket is available. On success the string output of the graph is
 * stored via `saveVideoApi.setSuccess`. On any failure — including a failure to build the
 * graph — the request state is reset and an error toast is shown.
 *
 * @param arg `dispatch` and `getState` of the app store. `taskId` is accepted for
 *   backward compatibility but is not read here; the graph builder is given only the state.
 */
export const saveVideo = async (arg: { dispatch: AppDispatch; getState: AppGetState; taskId?: string }) => {
  const { dispatch, getState } = arg;

  const socket = $socket.get();
  if (!socket) {
    return;
  }

  try {
    // Building the graph asserts and may throw, so it happens inside the try block:
    // the failure is then surfaced as a toast instead of an unhandled promise rejection.
    const { graph, outputNodeId } = buildSaveVideoGraph({
      state: getState(),
    });
    const dependencies = buildRunGraphDependencies(dispatch, socket);

    const { output } = await runGraph({
      graph,
      outputNodeId,
      dependencies,
      options: {
        prepend: true,
      },
    });
    assert(output.type === 'string_output');
    saveVideoApi.setSuccess(output.value);
  } catch {
    saveVideoApi.reset();
    toast({
      id: 'SAVE_VIDEO_FAILED',
      title: t('toast.saveVideoFailed'),
      status: 'error',
    });
  }
};

View File

@@ -1,98 +0,0 @@
import { deepClone } from 'common/util/deepClone';
import { atom } from 'nanostores';
import type { ImageDTO } from 'services/api/types';
// The four request states below share the same keys; the three boolean flags
// distinguish them, and `result` / `error` are non-null only in the matching state.

/** A finished request: `result` holds the string result and `error` is null. */
type SuccessState = {
isSuccess: true;
isError: false;
isPending: false;
result: string;
error: null;
imageDTO?: ImageDTO;
};
/** A failed request: `error` holds the Error and `result` is null. */
type ErrorState = {
isSuccess: false;
isError: true;
isPending: false;
result: null;
error: Error;
imageDTO?: ImageDTO;
};
/** A request in flight: neither a result nor an error is available. */
type PendingState = {
isSuccess: false;
isError: false;
isPending: true;
result: null;
error: null;
imageDTO?: ImageDTO;
};
/** No request has been made (or the state was reset): all flags are false. */
type IdleState = {
isSuccess: false;
isError: false;
isPending: false;
result: null;
error: null;
imageDTO?: ImageDTO;
};
/**
 * Union of all request states.
 *
 * NOTE(review): the name refers to prompt expansion although this type is declared next to
 * `saveVideoApi` — looks like a copy-paste leftover; confirm usages before renaming.
 */
export type PromptExpansionRequestState = IdleState | PendingState | SuccessState | ErrorState;
/** Template for the idle state; it is always deep-cloned before being stored. */
const IDLE_STATE: IdleState = {
  isSuccess: false,
  isError: false,
  isPending: false,
  result: null,
  error: null,
  imageDTO: undefined,
};

/** Store holding the current request state, initialised to a clone of the idle state. */
const $state = atom<PromptExpansionRequestState>(deepClone(IDLE_STATE));

/** Puts the store back into a fresh idle state. */
const reset = () => {
  const freshIdleState = deepClone(IDLE_STATE);
  $state.set(freshIdleState);
};
/**
 * Marks the request as pending, clearing any previous result and error and storing the
 * given `imageDTO` (which may be undefined).
 */
const setPending = (imageDTO?: ImageDTO) => {
  const previous = $state.get();
  const pending = {
    ...previous,
    isSuccess: false,
    isError: false,
    isPending: true,
    result: null,
    error: null,
    imageDTO,
  };
  $state.set(pending);
};
/**
 * Marks the request as succeeded with the given `result`. Keys not overridden here
 * (i.e. `imageDTO`) are carried over from the previous state.
 */
const setSuccess = (result: string) => {
  const previous = $state.get();
  const succeeded = {
    ...previous,
    isSuccess: true,
    isError: false,
    isPending: false,
    result,
    error: null,
  };
  $state.set(succeeded);
};
/**
 * Marks the request as failed with the given `error`. Keys not overridden here
 * (i.e. `imageDTO`) are carried over from the previous state.
 */
const setError = (error: Error) => {
  const previous = $state.get();
  const failed = {
    ...previous,
    isSuccess: false,
    isError: true,
    isPending: false,
    result: null,
    error,
  };
  $state.set(failed);
};
/** Public handle bundling the request-state store with the functions that update it. */
export const saveVideoApi = { $state, reset, setPending, setSuccess, setError };

View File

@@ -373,6 +373,7 @@ export type OutputFields<T extends AnyInvocation> = Extract<
// Node Outputs
export type ImageOutput = S['ImageOutput'];
export type VideoOutput = S['VideoOutput'];
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
export type StarterModel = S['StarterModel'];

View File

@@ -15,7 +15,7 @@ import { generatedVideoChanged } from 'features/parameters/store/videoSlice';
import type { LRUCache } from 'lru-cache';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import type { ImageDTO, S, VideoOutput } from 'services/api/types';
import { getCategories } from 'services/api/util';
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
import { $lastProgressEvent } from 'services/events/stores';
@@ -196,11 +196,11 @@ export const buildOnInvocationComplete = (
return imageDTOs;
};
/**
 * Extracts the video output from a completed invocation event.
 *
 * Only events produced by a `runway_generate_video` invocation yield a value; every other
 * invocation type resolves to `null`. The returned object carries the `video_id` key and
 * the `'video_output'` type tag so that it matches the `VideoOutput` shape, rather than
 * an ad-hoc `videoId` key.
 */
const getResultVideoDTOs = async (data: S['InvocationCompleteEvent']): Promise<VideoOutput | null> => {
  // @ts-expect-error: This is a workaround to get the video name from the result
  if (data.invocation.type === 'runway_generate_video') {
    // @ts-expect-error: This is a workaround to get the video name from the result
    return { type: 'video_output', video_id: data.result.video_id };
  }
  return null;
};