mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
combine nodes that generate and save videos
This commit is contained in:
committed by
psychedelicious
parent
d8fcf18b6c
commit
6afe995dbe
@@ -188,3 +188,11 @@ export const zImageOutput = z.object({
|
||||
});
|
||||
export type ImageOutput = z.infer<typeof zImageOutput>;
|
||||
// #endregion
|
||||
|
||||
// #region VideoOutput
/**
 * Schema for the output of a node that produces a video.
 *
 * - `video_id`: non-empty identifier of the video (surrounding whitespace is trimmed before the length check).
 * - `type`: discriminator, always the literal `'video_output'`.
 */
export const zVideoOutput = z.object({
  video_id: z.string().trim().min(1),
  type: z.literal('video_output'),
});
/** Static type inferred from {@link zVideoOutput}. */
export type VideoOutput = z.infer<typeof zVideoOutput>;
// #endregion
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectVideoFirstFrameImage, selectVideoLastFrameImage } from 'features/parameters/store/videoSlice';
|
||||
import { zImageField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { t } from 'i18next';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
// Fallback video settings. They are hard-coded for now; a dedicated video params slice could own them later.
const DEFAULT_VIDEO_DURATION = 5;
const DEFAULT_VIDEO_ASPECT_RATIO = '1280:768'; // Default landscape
const DEFAULT_ENHANCE_PROMPT = true;

/**
 * Resolves the video parameters used when building a video graph.
 *
 * The `state` argument is accepted but not read yet: every value currently comes from the
 * module-level defaults above.
 *
 * @param state The root application state (currently unused).
 * @returns The duration, aspect ratio and enhance-prompt flag to use.
 */
const getVideoParameters = (state: RootState) => ({
  duration: DEFAULT_VIDEO_DURATION,
  aspectRatio: DEFAULT_VIDEO_ASPECT_RATIO,
  enhancePrompt: DEFAULT_ENHANCE_PROMPT,
});
|
||||
|
||||
/**
 * Builds the graph that generates a video with the `runway_generate_video` node.
 *
 * The graph contains a `string` node holding the positive prompt, wired into the video node's
 * `prompt` input. Optional first/last frame images from the video slice are validated with
 * `zImageField` and set directly on the video node. Generation metadata is attached to the video node.
 *
 * @param arg The graph builder argument (`generationMode`, `state`, `manager`).
 * @returns The graph and the positive prompt node.
 * @throws UnsupportedGenerationModeError when `generationMode` is anything other than `'txt2img'`.
 */
export const buildRunwayVideoGraph = (arg: GraphBuilderArg): GraphBuilderReturn => {
  const { generationMode, state, manager } = arg;

  log.debug({ generationMode, manager: manager?.id }, 'Building Runway video graph');

  // Only the `txt2img` generation mode is accepted. Whether the request is text-to-video or
  // image-to-video is decided below by the presence of frame images, not by the generation mode.
  if (generationMode !== 'txt2img') {
    throw new UnsupportedGenerationModeError(t('toast.runwayIncompatibleGenerationMode'));
  }

  const params = selectParamsSlice(state);
  const prompts = selectPresetModifiedPrompts(state);
  const videoFirstFrameImage = selectVideoFirstFrameImage(state);
  const videoLastFrameImage = selectVideoLastFrameImage(state);
  const videoParams = getVideoParameters(state);

  // When the seed is randomized, no seed is sent to the node at all
  const { seed, shouldRandomizeSeed } = params;
  const finalSeed = shouldRandomizeSeed ? undefined : seed;

  // Any frame image makes this an image-to-video request; otherwise it is text-to-video
  const hasFrameImages = Boolean(videoFirstFrameImage || videoLastFrameImage);

  const g = new Graph(getPrefixedId('runway_video_graph'));

  const positivePrompt = g.addNode({
    id: getPrefixedId('positive_prompt'),
    type: 'string',
    value: prompts.positive,
  });

  // Create the runway video generation node
  const runwayVideoNode = g.addNode({
    id: getPrefixedId('runway_generate_video'),
    // @ts-expect-error: This node is not available in the OSS application
    type: 'runway_generate_video',
    duration: videoParams.duration,
    aspect_ratio: videoParams.aspectRatio,
    seed: finalSeed,
  });

  // @ts-expect-error: This node is not available in the OSS application
  g.addEdge(positivePrompt, 'value', runwayVideoNode, 'prompt');

  // Add first frame image if provided
  if (videoFirstFrameImage) {
    const firstFrameImageField = zImageField.parse(videoFirstFrameImage);
    // @ts-expect-error: This connection is specific to runway node
    runwayVideoNode.first_frame_image = firstFrameImageField;
  }

  // Add last frame image if provided
  if (videoLastFrameImage) {
    const lastFrameImageField = zImageField.parse(videoLastFrameImage);
    // @ts-expect-error: This connection is specific to runway node
    runwayVideoNode.last_frame_image = lastFrameImageField;
  }

  // Set up metadata
  g.upsertMetadata({
    positive_prompt: prompts.positive,
    negative_prompt: prompts.negative || '',
    video_duration: videoParams.duration,
    video_aspect_ratio: videoParams.aspectRatio,
    seed: finalSeed,
    enhance_prompt: videoParams.enhancePrompt,
    generation_type: hasFrameImages ? 'image-to-video' : 'text-to-video',
  });

  // Add video frame images to metadata if they exist
  if (hasFrameImages) {
    g.upsertMetadata(
      {
        first_frame_image: videoFirstFrameImage,
        last_frame_image: videoLastFrameImage,
      },
      'merge'
    );
  }

  g.setMetadataReceivingNode(runwayVideoNode);

  return {
    g,
    positivePrompt,
  };
};
|
||||
@@ -5,6 +5,7 @@ import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { zImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { VideoOutput, zVideoOutput } from 'features/nodes/types/common';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
@@ -12,10 +13,7 @@ const zVideoState = z.object({
|
||||
_version: z.literal(1),
|
||||
videoFirstFrameImage: zImageWithDims.nullable(),
|
||||
videoLastFrameImage: zImageWithDims.nullable(),
|
||||
generatedVideo: z.object({
|
||||
url: z.string(),
|
||||
taskId: z.number(),
|
||||
}).nullable(),
|
||||
generatedVideo: zVideoOutput.nullable(),
|
||||
});
|
||||
|
||||
export type VideoState = z.infer<typeof zVideoState>;
|
||||
@@ -39,7 +37,7 @@ const slice = createSlice({
|
||||
state.videoLastFrameImage = action.payload;
|
||||
},
|
||||
|
||||
generatedVideoChanged: (state, action: PayloadAction<{ url: string, taskId: number } | null>) => {
|
||||
generatedVideoChanged: (state, action: PayloadAction<VideoOutput | null>) => {
|
||||
state.generatedVideo = action.payload;
|
||||
},
|
||||
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useFocusRegion } from 'common/hooks/focus';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { memo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ReactPlayer from 'react-player';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
import { useDispatch } from 'react-redux';
|
||||
import { saveVideo } from 'features/video/saveVideo';
|
||||
|
||||
|
||||
export const VideoPlayerPanel = memo(() => {
|
||||
@@ -17,18 +14,6 @@ export const VideoPlayerPanel = memo(() => {
|
||||
|
||||
useFocusRegion('video', ref);
|
||||
|
||||
const { dispatch, getState } = useAppStore();
|
||||
|
||||
const handleSaveVideo = useCallback(() => {
|
||||
console.log('generatedVideo', generatedVideo);
|
||||
if (!generatedVideo?.taskId) {
|
||||
return
|
||||
}
|
||||
console.log('saving video', generatedVideo.taskId);
|
||||
saveVideo({ dispatch, getState, taskId: `${generatedVideo.taskId}` });
|
||||
}, [dispatch, getState, generatedVideo]);
|
||||
|
||||
|
||||
|
||||
return (
|
||||
<Flex ref={ref} w="full" h="full" flexDirection="column" gap={4}>
|
||||
@@ -36,16 +21,15 @@ export const VideoPlayerPanel = memo(() => {
|
||||
{generatedVideo &&
|
||||
<>
|
||||
<Box flex={0.75} position="relative" >
|
||||
<ReactPlayer
|
||||
{/* <ReactPlayer
|
||||
src={generatedVideo.url}
|
||||
width="75%"
|
||||
height="75%"
|
||||
controls={true}
|
||||
style={{ position: 'absolute', top: '50%', left: '50%', transform: 'translate(-50%, -50%)', maxWidth: '900px' }}
|
||||
/>
|
||||
/> */}
|
||||
</Box>
|
||||
|
||||
<Button leftIcon={<PiCheckBold />} colorScheme="invokeBlue" onClick={handleSaveVideo}>Keep</Button>
|
||||
|
||||
</>}
|
||||
{!generatedVideo && <Text>No video generated</Text>}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
 * Builds a single-node graph that saves the currently generated video.
 *
 * The task id is read from `state.video.generatedVideo` and passed to a `save_runway_video` node
 * as `runway_task_id`.
 *
 * @param arg.state The root application state.
 * @returns The graph and the id of its only node, which is also the output node.
 * @throws When there is no generated video (and therefore no task id) in state.
 */
export const buildSaveVideoGraph = ({
  state,
}: {
  state: RootState;
}): { graph: Graph; outputNodeId: string } => {
  const taskId = state.video.generatedVideo?.taskId;

  // The task id is numeric, so check for absence explicitly: a truthiness check would reject a valid id of 0
  assert(taskId !== undefined, 'No task ID found in state');

  const graph = new Graph(getPrefixedId('save-video-graph'));
  const outputNode = graph.addNode({
    // @ts-expect-error: These nodes are not available in the OSS application
    type: 'save_runway_video',
    id: getPrefixedId('save_runway_video'),
    runway_task_id: taskId,
  });

  return { graph, outputNodeId: outputNode.id };
};
|
||||
@@ -1,41 +0,0 @@
|
||||
import type { AppDispatch, AppGetState } from 'app/store/store';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { buildRunGraphDependencies, runGraph } from 'services/api/run-graph';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { assert } from 'tsafe';
|
||||
import { buildSaveVideoGraph } from './graph';
|
||||
import { saveVideoApi } from './state';
|
||||
|
||||
|
||||
|
||||
/**
 * Saves the currently generated video by running the save-video graph.
 *
 * Does nothing when no socket is available. Otherwise the graph is built from the current state and
 * run with `prepend: true`; its string output is stored via `saveVideoApi.setSuccess`. On any
 * failure, including failing to build the graph, `saveVideoApi` is reset and an error toast is shown.
 *
 * @param arg.dispatch The store dispatch function.
 * @param arg.getState The store state getter.
 * @param arg.taskId Accepted for compatibility with existing callers but not used: the graph builder
 * reads the task id from state.
 */
export const saveVideo = async (arg: { dispatch: AppDispatch; getState: AppGetState; taskId?: string }) => {
  const { dispatch, getState } = arg;
  const socket = $socket.get();
  if (!socket) {
    return;
  }

  try {
    // Build the graph inside the try block: the builder asserts on state, and callers do not await
    // this function, so a throw out here would otherwise surface as an unhandled promise rejection.
    const { graph, outputNodeId } = buildSaveVideoGraph({
      state: getState(),
    });
    const dependencies = buildRunGraphDependencies(dispatch, socket);

    const { output } = await runGraph({
      graph,
      outputNodeId,
      dependencies,
      options: {
        prepend: true,
      },
    });
    assert(output.type === 'string_output');
    saveVideoApi.setSuccess(output.value);
  } catch {
    saveVideoApi.reset();
    toast({
      id: 'SAVE_VIDEO_FAILED',
      title: t('toast.saveVideoFailed'),
      status: 'error',
    });
  }
};
|
||||
@@ -1,98 +0,0 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { atom } from 'nanostores';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
/** Field shared by every request state: an optional image associated with the request. */
type RequestStateBase = {
  imageDTO?: ImageDTO;
};

/** The request finished and produced a string result. */
type SuccessState = RequestStateBase & {
  isSuccess: true;
  isError: false;
  isPending: false;
  result: string;
  error: null;
};

/** The request failed with an error. */
type ErrorState = RequestStateBase & {
  isSuccess: false;
  isError: true;
  isPending: false;
  result: null;
  error: Error;
};

/** The request is in flight. */
type PendingState = RequestStateBase & {
  isSuccess: false;
  isError: false;
  isPending: true;
  result: null;
  error: null;
};

/** No request has been made, or the state was reset. */
type IdleState = RequestStateBase & {
  isSuccess: false;
  isError: false;
  isPending: false;
  result: null;
  error: null;
};

/** Union of every state the request can be in, discriminated by the `is*` flags. */
export type PromptExpansionRequestState = IdleState | PendingState | SuccessState | ErrorState;
|
||||
|
||||
/** Template for the idle state. It is always deep-cloned before being stored. */
const IDLE_STATE: IdleState = {
  imageDTO: undefined,
  result: null,
  error: null,
  isPending: false,
  isError: false,
  isSuccess: false,
};

/** Store holding the current request state; starts out as a clone of the idle template. */
const $state = atom<PromptExpansionRequestState>(deepClone(IDLE_STATE));
|
||||
|
||||
/** Puts the store back into a fresh copy of the idle state. */
const reset = () => {
  const idle = deepClone(IDLE_STATE);
  $state.set(idle);
};
|
||||
|
||||
/**
 * Marks the request as pending, clearing any previous result and error.
 *
 * @param imageDTO Optional image to associate with the pending request; replaces any stored one.
 */
const setPending = (imageDTO?: ImageDTO) => {
  const previous = $state.get();
  $state.set({
    ...previous,
    imageDTO,
    result: null,
    error: null,
    isPending: true,
    isError: false,
    isSuccess: false,
  });
};
|
||||
|
||||
/**
 * Marks the request as successful and stores its result. Fields not overridden here
 * (such as `imageDTO`) are carried over from the current state.
 *
 * @param result The string result of the request.
 */
const setSuccess = (result: string) => {
  const previous = $state.get();
  $state.set({
    ...previous,
    result,
    error: null,
    isPending: false,
    isError: false,
    isSuccess: true,
  });
};
|
||||
|
||||
/**
 * Marks the request as failed and stores the error. Fields not overridden here
 * (such as `imageDTO`) are carried over from the current state.
 *
 * @param error The error that caused the failure.
 */
const setError = (error: Error) => {
  const previous = $state.get();
  $state.set({
    ...previous,
    error,
    result: null,
    isPending: false,
    isError: true,
    isSuccess: false,
  });
};
|
||||
|
||||
/** Public handle for the save-video request state: the store itself plus its state transitions. */
export const saveVideoApi = {
  $state: $state,
  reset: reset,
  setPending: setPending,
  setSuccess: setSuccess,
  setError: setError,
};
|
||||
@@ -373,6 +373,7 @@ export type OutputFields<T extends AnyInvocation> = Extract<
|
||||
|
||||
// Node Outputs
|
||||
export type ImageOutput = S['ImageOutput'];
|
||||
export type VideoOutput = S['VideoOutput'];
|
||||
|
||||
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
|
||||
export type StarterModel = S['StarterModel'];
|
||||
|
||||
@@ -15,7 +15,7 @@ import { generatedVideoChanged } from 'features/parameters/store/videoSlice';
|
||||
import type { LRUCache } from 'lru-cache';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import type { ImageDTO, S, VideoOutput } from 'services/api/types';
|
||||
import { getCategories } from 'services/api/util';
|
||||
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
|
||||
import { $lastProgressEvent } from 'services/events/stores';
|
||||
@@ -196,11 +196,11 @@ export const buildOnInvocationComplete = (
|
||||
return imageDTOs;
|
||||
};
|
||||
|
||||
/**
 * Extracts the video output from an invocation-complete event.
 *
 * @param data The invocation-complete event.
 * @returns The video output when the invocation is a `runway_generate_video` node, otherwise `null`.
 */
const getResultVideoDTOs = async (data: S['InvocationCompleteEvent']): Promise<VideoOutput | null> => {
  // @ts-expect-error: This is a workaround to get the video name from the result
  if (data.invocation.type === 'runway_generate_video') {
    // Return the snake_case `video_id` field plus the `video_output` discriminator so the value matches
    // the video output schema; a camelCase `videoId` key would not.
    // NOTE(review): assumes `data.result.video_id` is present on this node's result - confirm against the backend schema.
    // @ts-expect-error: This is a workaround to get the video name from the result
    return { video_id: data.result.video_id, type: 'video_output' };
  }
  return null;
};
|
||||
|
||||
Reference in New Issue
Block a user