feat(ui): fiddle w/ video stuff

This commit is contained in:
psychedelicious
2025-08-19 19:06:30 +10:00
committed by Mary Hipp Rogers
parent 9380d8901c
commit f98bbc32dd
7 changed files with 109 additions and 86 deletions

View File

@@ -1,3 +1,6 @@
import type { S } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { z } from 'zod';
// #region Field data schemas
@@ -11,6 +14,13 @@ type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
zImageFieldCollection.safeParse(field).success;
export const zVideoField = z.object({
video_id: z.string().trim().min(1),
});
export type VideoField = z.infer<typeof zVideoField>;
export const isVideoField = (field: unknown): field is VideoField => zVideoField.safeParse(field).success;
assert<Equals<VideoField, S['VideoField']>>();
export const zBoardField = z.object({
board_id: z.string().trim().min(1),
});

View File

@@ -5,14 +5,15 @@ import type { SliceConfig } from 'app/store/types';
import { isPlainObject } from 'es-toolkit';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zImageWithDims } from 'features/controlLayers/store/types';
import { VideoOutput, zVideoOutput } from 'features/nodes/types/common';
import type { VideoField } from 'features/nodes/types/common';
import { zVideoField } from 'features/nodes/types/common';
import { assert } from 'tsafe';
import z from 'zod';
const zVideoState = z.object({
_version: z.literal(1),
startingFrameImage: zImageWithDims.nullable(),
generatedVideo: zVideoOutput.nullable(),
generatedVideo: zVideoField.nullable(),
});
export type VideoState = z.infer<typeof zVideoState>;
@@ -31,17 +32,14 @@ const slice = createSlice({
state.startingFrameImage = action.payload;
},
generatedVideoChanged: (state, action: PayloadAction<VideoOutput | null>) => {
state.generatedVideo = action.payload;
generatedVideoChanged: (state, action: PayloadAction<{ videoField: VideoField | null }>) => {
const { videoField } = action.payload;
state.generatedVideo = videoField;
},
},
});
export const {
startingFrameImageChanged,
generatedVideoChanged,
} = slice.actions;
export const { startingFrameImageChanged, generatedVideoChanged } = slice.actions;
export const videoSliceConfig: SliceConfig<typeof slice> = {
slice,
@@ -62,4 +60,4 @@ export const selectVideoSlice = (state: RootState) => state.video;
const createVideoSelector = <T>(selector: Selector<VideoState, T>) => createSelector(selectVideoSlice, selector);
export const selectStartingFrameImage = createVideoSelector((video) => video.startingFrameImage);
export const selectGeneratedVideo = createVideoSelector((video) => video.generatedVideo);
export const selectGeneratedVideo = createVideoSelector((video) => video.generatedVideo);

View File

@@ -4,8 +4,7 @@ import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { withResultAsync } from 'common/util/result';
import { buildRunwayVideoGraph } from 'features/nodes/util/graph/generation/buildRunwayVideoGraph';
import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
@@ -19,7 +18,7 @@ import { AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedCanvas());
@@ -29,7 +28,6 @@ const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const destination = selectCanvasDestination(state);
const buildGraphResult = await withResultAsync(async () => {
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
return await buildRunwayVideoGraph(graphBuilderArg);
@@ -58,30 +56,44 @@ const enqueueVideo = async (store: AppStore, prepend: boolean) => {
const { g, seed, positivePrompt } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedNode: seed,
positivePromptNode: positivePrompt,
origin: 'canvas',
// const prepareBatchResult = withResult(() =>
// prepareLinearUIBatch({
// state,
// g,
// prepend,
// seedNode: seed,
// positivePromptNode: positivePrompt,
// origin: 'canvas',
// destination,
// })
// );
// if (prepareBatchResult.isErr()) {
// log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
// return;
// }
// const batchConfig = prepareBatchResult.value;
const batchConfig = {
prepend,
batch: {
graph: g.getGraph(),
runs: 1,
origin,
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
},
};
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
queueApi.endpoints.enqueueBatch.initiate(
batchConfig,
{
...enqueueMutationFixedCacheKeyOptions,
track: false,
}
)
);
const enqueueResult = await req.unwrap();
@@ -93,10 +105,9 @@ export const useEnqueueVideo = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueVideo(store, prepend);
},
[ store]
[store]
);
return enqueue;
};
};

View File

@@ -1,22 +1,20 @@
import { Box, Flex, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useFocusRegion } from 'common/hooks/focus';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
import { memo, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import ReactPlayer from 'react-player';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
import { useGetVideoDTOQuery } from 'services/api/endpoints/videos';
import { skipToken } from '@reduxjs/toolkit/query';
import { useImageDTO } from 'services/api/endpoints/images';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
export const VideoPlayerPanel = memo(() => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
const generatedVideo = useAppSelector(selectGeneratedVideo);
const lastSelectedVideoId = useAppSelector(selectLastSelectedImage);
const {data: videoDTO} = useGetVideoDTOQuery(lastSelectedVideoId ?? skipToken);
const { data: videoDTO } = useGetVideoDTOQuery(generatedVideo?.video_id ?? skipToken);
useFocusRegion('video', ref);
@@ -30,28 +28,30 @@ export const VideoPlayerPanel = memo(() => {
return videoDTO.video_url;
}, [videoDTO]);
return (
<Flex ref={ref} w="full" h="full" flexDirection="column" gap={4}>
{videoUrl &&
{videoUrl && (
<>
<Box flex={0.75} position="relative" >
<ReactPlayer
<Box flex={0.75} position="relative">
<ReactPlayer
src={videoUrl}
width="75%"
height="75%"
controls={true}
style={{ position: 'absolute', top: '50%', left: '50%', transform: 'translate(-50%, -50%)', maxWidth: '900px' }}
/>
style={{
position: 'absolute',
top: '50%',
left: '50%',
transform: 'translate(-50%, -50%)',
maxWidth: '900px',
}}
/>
</Box>
</>}
</>
)}
{!videoUrl && <Text>No video generated</Text>}
</Flex>
);
});
VideoPlayerPanel.displayName = 'VideoPlayerPanel';
VideoPlayerPanel.displayName = 'VideoPlayerPanel';

View File

@@ -2,7 +2,6 @@ import type { DockviewApi, GridviewApi, IDockviewReactProps, IGridviewReactProps
import { DockviewReact, GridviewReact, LayoutPriority, Orientation } from 'dockview';
import { BoardsPanel } from 'features/gallery/components/BoardsListPanelContent';
import { GalleryPanel } from 'features/gallery/components/Gallery';
import { ImageViewerPanel } from 'features/gallery/components/ImageViewer/ImageViewerPanel';
import { FloatingLeftPanelButtons } from 'features/ui/components/FloatingLeftPanelButtons';
import { FloatingRightPanelButtons } from 'features/ui/components/FloatingRightPanelButtons';
import type {
@@ -20,6 +19,7 @@ import { memo, useCallback, useEffect } from 'react';
import { DockviewTab } from './DockviewTab';
import { DockviewTabLaunchpad } from './DockviewTabLaunchpad';
import { DockviewTabProgress } from './DockviewTabProgress';
import { GenerateLaunchpadPanel } from './GenerateLaunchpadPanel';
import { navigationApi } from './navigation-api';
import { PanelHotkeysLogical } from './PanelHotkeysLogical';
import {
@@ -41,9 +41,8 @@ import {
SETTINGS_PANEL_ID,
VIEWER_PANEL_ID,
} from './shared';
import { VideoTabLeftPanel } from './VideoTabLeftPanel';
import { GenerateLaunchpadPanel } from './GenerateLaunchpadPanel';
import { VideoPlayerPanel } from './VideoPlayerPanel';
import { VideoTabLeftPanel } from './VideoTabLeftPanel';
const tabComponents = {
[DOCKVIEW_TAB_ID]: DockviewTab,

View File

@@ -22222,21 +22222,6 @@ export type components = {
* @description The id of the video
*/
video_id: string;
/**
* Width
* @description The width of the video in pixels
*/
width: number;
/**
* Height
* @description The height of the video in pixels
*/
height: number;
/**
* Duration Seconds
* @description The duration of the video in seconds
*/
duration_seconds: number;
};
/**
* VideoIdsResult
@@ -22266,6 +22251,21 @@ export type components = {
VideoOutput: {
/** @description The output video */
video: components["schemas"]["VideoField"];
/**
* Width
* @description The width of the video in pixels
*/
width: number;
/**
* Height
* @description The height of the video in pixels
*/
height: number;
/**
* Duration Seconds
* @description The duration of the video in seconds
*/
duration_seconds: number;
/**
* type
* @default video_output

View File

@@ -10,13 +10,14 @@ import {
} from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import type { VideoField } from 'features/nodes/types/common';
import { isImageField, isImageFieldCollection, isVideoField } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { generatedVideoChanged } from 'features/parameters/store/videoSlice';
import type { LRUCache } from 'lru-cache';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S, VideoOutput } from 'services/api/types';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories } from 'services/api/util';
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
import { $lastProgressEvent } from 'services/events/stores';
@@ -205,13 +206,17 @@ export const buildOnInvocationComplete = (
return imageDTOs;
};
const getResultVideoDTOs = async (data: S['InvocationCompleteEvent']): Promise<VideoOutput | null> => {
// @ts-expect-error: This is a workaround to get the video name from the result
if (data.invocation.type === 'runway_generate_video') {
// @ts-expect-error: This is a workaround to get the video name from the result
return {videoId: data.result.video_id};
const getResultVideoFields = (data: S['InvocationCompleteEvent']): VideoField[] => {
const { result } = data;
const videoFields: VideoField[] = [];
for (const [_name, value] of objectEntries(result)) {
if (isVideoField(value)) {
videoFields.push(value);
}
}
return null;
return videoFields;
};
return async (data: S['InvocationCompleteEvent']) => {
@@ -235,9 +240,9 @@ export const buildOnInvocationComplete = (
await addImagesToGallery(data);
const videoResult = await getResultVideoDTOs(data);
if (videoResult) {
dispatch(generatedVideoChanged({ video_id: videoResult.video.video_id, type: 'video_output' }));
const videoField = getResultVideoFields(data)[0];
if (videoField) {
dispatch(generatedVideoChanged({ videoField }));
}
$lastProgressEvent.set(null);