mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): fiddle w/ video stuff
This commit is contained in:
committed by
Mary Hipp Rogers
parent
9380d8901c
commit
f98bbc32dd
@@ -1,3 +1,6 @@
|
||||
import type { S } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
|
||||
// #region Field data schemas
|
||||
@@ -11,6 +14,13 @@ type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
|
||||
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
|
||||
zImageFieldCollection.safeParse(field).success;
|
||||
|
||||
export const zVideoField = z.object({
|
||||
video_id: z.string().trim().min(1),
|
||||
});
|
||||
export type VideoField = z.infer<typeof zVideoField>;
|
||||
export const isVideoField = (field: unknown): field is VideoField => zVideoField.safeParse(field).success;
|
||||
assert<Equals<VideoField, S['VideoField']>>();
|
||||
|
||||
export const zBoardField = z.object({
|
||||
board_id: z.string().trim().min(1),
|
||||
});
|
||||
|
||||
@@ -5,14 +5,15 @@ import type { SliceConfig } from 'app/store/types';
|
||||
import { isPlainObject } from 'es-toolkit';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { zImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { VideoOutput, zVideoOutput } from 'features/nodes/types/common';
|
||||
import type { VideoField } from 'features/nodes/types/common';
|
||||
import { zVideoField } from 'features/nodes/types/common';
|
||||
import { assert } from 'tsafe';
|
||||
import z from 'zod';
|
||||
|
||||
const zVideoState = z.object({
|
||||
_version: z.literal(1),
|
||||
startingFrameImage: zImageWithDims.nullable(),
|
||||
generatedVideo: zVideoOutput.nullable(),
|
||||
generatedVideo: zVideoField.nullable(),
|
||||
});
|
||||
|
||||
export type VideoState = z.infer<typeof zVideoState>;
|
||||
@@ -31,17 +32,14 @@ const slice = createSlice({
|
||||
state.startingFrameImage = action.payload;
|
||||
},
|
||||
|
||||
generatedVideoChanged: (state, action: PayloadAction<VideoOutput | null>) => {
|
||||
state.generatedVideo = action.payload;
|
||||
generatedVideoChanged: (state, action: PayloadAction<{ videoField: VideoField | null }>) => {
|
||||
const { videoField } = action.payload;
|
||||
state.generatedVideo = videoField;
|
||||
},
|
||||
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
startingFrameImageChanged,
|
||||
generatedVideoChanged,
|
||||
} = slice.actions;
|
||||
export const { startingFrameImageChanged, generatedVideoChanged } = slice.actions;
|
||||
|
||||
export const videoSliceConfig: SliceConfig<typeof slice> = {
|
||||
slice,
|
||||
@@ -62,4 +60,4 @@ export const selectVideoSlice = (state: RootState) => state.video;
|
||||
const createVideoSelector = <T>(selector: Selector<VideoState, T>) => createSelector(selectVideoSlice, selector);
|
||||
|
||||
export const selectStartingFrameImage = createVideoSelector((video) => video.startingFrameImage);
|
||||
export const selectGeneratedVideo = createVideoSelector((video) => video.generatedVideo);
|
||||
export const selectGeneratedVideo = createVideoSelector((video) => video.generatedVideo);
|
||||
|
||||
@@ -4,8 +4,7 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { buildRunwayVideoGraph } from 'features/nodes/util/graph/generation/buildRunwayVideoGraph';
|
||||
import { selectCanvasDestination } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { GraphBuilderArg } from 'features/nodes/util/graph/types';
|
||||
@@ -19,7 +18,7 @@ import { AssertionError } from 'tsafe';
|
||||
const log = logger('generation');
|
||||
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
|
||||
|
||||
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
|
||||
const enqueueVideo = async (store: AppStore, prepend: boolean) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
dispatch(enqueueRequestedCanvas());
|
||||
@@ -29,7 +28,6 @@ const enqueueVideo = async (store: AppStore, prepend: boolean) => {
|
||||
const destination = selectCanvasDestination(state);
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
|
||||
const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state, manager: null };
|
||||
|
||||
return await buildRunwayVideoGraph(graphBuilderArg);
|
||||
@@ -58,30 +56,44 @@ const enqueueVideo = async (store: AppStore, prepend: boolean) => {
|
||||
|
||||
const { g, seed, positivePrompt } = buildGraphResult.value;
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedNode: seed,
|
||||
positivePromptNode: positivePrompt,
|
||||
origin: 'canvas',
|
||||
// const prepareBatchResult = withResult(() =>
|
||||
// prepareLinearUIBatch({
|
||||
// state,
|
||||
// g,
|
||||
// prepend,
|
||||
// seedNode: seed,
|
||||
// positivePromptNode: positivePrompt,
|
||||
// origin: 'canvas',
|
||||
// destination,
|
||||
// })
|
||||
// );
|
||||
|
||||
// if (prepareBatchResult.isErr()) {
|
||||
// log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
// return;
|
||||
// }
|
||||
|
||||
// const batchConfig = prepareBatchResult.value;
|
||||
|
||||
|
||||
const batchConfig = {
|
||||
prepend,
|
||||
batch: {
|
||||
graph: g.getGraph(),
|
||||
runs: 1,
|
||||
origin,
|
||||
destination,
|
||||
})
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
return;
|
||||
}
|
||||
|
||||
const batchConfig = prepareBatchResult.value;
|
||||
},
|
||||
};
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
track: false,
|
||||
})
|
||||
queueApi.endpoints.enqueueBatch.initiate(
|
||||
batchConfig,
|
||||
{
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
track: false,
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
@@ -93,10 +105,9 @@ export const useEnqueueVideo = () => {
|
||||
const store = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
(prepend: boolean) => {
|
||||
|
||||
return enqueueVideo(store, prepend);
|
||||
},
|
||||
[ store]
|
||||
[store]
|
||||
);
|
||||
return enqueue;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFocusRegion } from 'common/hooks/focus';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ReactPlayer from 'react-player';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
|
||||
import { useGetVideoDTOQuery } from 'services/api/endpoints/videos';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
|
||||
|
||||
export const VideoPlayerPanel = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
const generatedVideo = useAppSelector(selectGeneratedVideo);
|
||||
const lastSelectedVideoId = useAppSelector(selectLastSelectedImage);
|
||||
const {data: videoDTO} = useGetVideoDTOQuery(lastSelectedVideoId ?? skipToken);
|
||||
const { data: videoDTO } = useGetVideoDTOQuery(generatedVideo?.video_id ?? skipToken);
|
||||
|
||||
useFocusRegion('video', ref);
|
||||
|
||||
@@ -30,28 +28,30 @@ export const VideoPlayerPanel = memo(() => {
|
||||
return videoDTO.video_url;
|
||||
}, [videoDTO]);
|
||||
|
||||
|
||||
return (
|
||||
<Flex ref={ref} w="full" h="full" flexDirection="column" gap={4}>
|
||||
|
||||
{videoUrl &&
|
||||
{videoUrl && (
|
||||
<>
|
||||
<Box flex={0.75} position="relative" >
|
||||
<ReactPlayer
|
||||
<Box flex={0.75} position="relative">
|
||||
<ReactPlayer
|
||||
src={videoUrl}
|
||||
width="75%"
|
||||
height="75%"
|
||||
controls={true}
|
||||
style={{ position: 'absolute', top: '50%', left: '50%', transform: 'translate(-50%, -50%)', maxWidth: '900px' }}
|
||||
/>
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: '50%',
|
||||
left: '50%',
|
||||
transform: 'translate(-50%, -50%)',
|
||||
maxWidth: '900px',
|
||||
}}
|
||||
/>
|
||||
</Box>
|
||||
|
||||
|
||||
</>}
|
||||
</>
|
||||
)}
|
||||
{!videoUrl && <Text>No video generated</Text>}
|
||||
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
VideoPlayerPanel.displayName = 'VideoPlayerPanel';
|
||||
VideoPlayerPanel.displayName = 'VideoPlayerPanel';
|
||||
|
||||
@@ -2,7 +2,6 @@ import type { DockviewApi, GridviewApi, IDockviewReactProps, IGridviewReactProps
|
||||
import { DockviewReact, GridviewReact, LayoutPriority, Orientation } from 'dockview';
|
||||
import { BoardsPanel } from 'features/gallery/components/BoardsListPanelContent';
|
||||
import { GalleryPanel } from 'features/gallery/components/Gallery';
|
||||
import { ImageViewerPanel } from 'features/gallery/components/ImageViewer/ImageViewerPanel';
|
||||
import { FloatingLeftPanelButtons } from 'features/ui/components/FloatingLeftPanelButtons';
|
||||
import { FloatingRightPanelButtons } from 'features/ui/components/FloatingRightPanelButtons';
|
||||
import type {
|
||||
@@ -20,6 +19,7 @@ import { memo, useCallback, useEffect } from 'react';
|
||||
import { DockviewTab } from './DockviewTab';
|
||||
import { DockviewTabLaunchpad } from './DockviewTabLaunchpad';
|
||||
import { DockviewTabProgress } from './DockviewTabProgress';
|
||||
import { GenerateLaunchpadPanel } from './GenerateLaunchpadPanel';
|
||||
import { navigationApi } from './navigation-api';
|
||||
import { PanelHotkeysLogical } from './PanelHotkeysLogical';
|
||||
import {
|
||||
@@ -41,9 +41,8 @@ import {
|
||||
SETTINGS_PANEL_ID,
|
||||
VIEWER_PANEL_ID,
|
||||
} from './shared';
|
||||
import { VideoTabLeftPanel } from './VideoTabLeftPanel';
|
||||
import { GenerateLaunchpadPanel } from './GenerateLaunchpadPanel';
|
||||
import { VideoPlayerPanel } from './VideoPlayerPanel';
|
||||
import { VideoTabLeftPanel } from './VideoTabLeftPanel';
|
||||
|
||||
const tabComponents = {
|
||||
[DOCKVIEW_TAB_ID]: DockviewTab,
|
||||
|
||||
@@ -22222,21 +22222,6 @@ export type components = {
|
||||
* @description The id of the video
|
||||
*/
|
||||
video_id: string;
|
||||
/**
|
||||
* Width
|
||||
* @description The width of the video in pixels
|
||||
*/
|
||||
width: number;
|
||||
/**
|
||||
* Height
|
||||
* @description The height of the video in pixels
|
||||
*/
|
||||
height: number;
|
||||
/**
|
||||
* Duration Seconds
|
||||
* @description The duration of the video in seconds
|
||||
*/
|
||||
duration_seconds: number;
|
||||
};
|
||||
/**
|
||||
* VideoIdsResult
|
||||
@@ -22266,6 +22251,21 @@ export type components = {
|
||||
VideoOutput: {
|
||||
/** @description The output video */
|
||||
video: components["schemas"]["VideoField"];
|
||||
/**
|
||||
* Width
|
||||
* @description The width of the video in pixels
|
||||
*/
|
||||
width: number;
|
||||
/**
|
||||
* Height
|
||||
* @description The height of the video in pixels
|
||||
*/
|
||||
height: number;
|
||||
/**
|
||||
* Duration Seconds
|
||||
* @description The duration of the video in seconds
|
||||
*/
|
||||
duration_seconds: number;
|
||||
/**
|
||||
* type
|
||||
* @default video_output
|
||||
|
||||
@@ -10,13 +10,14 @@ import {
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||
import type { VideoField } from 'features/nodes/types/common';
|
||||
import { isImageField, isImageFieldCollection, isVideoField } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { generatedVideoChanged } from 'features/parameters/store/videoSlice';
|
||||
import type { LRUCache } from 'lru-cache';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, S, VideoOutput } from 'services/api/types';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { getCategories } from 'services/api/util';
|
||||
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
|
||||
import { $lastProgressEvent } from 'services/events/stores';
|
||||
@@ -205,13 +206,17 @@ export const buildOnInvocationComplete = (
|
||||
return imageDTOs;
|
||||
};
|
||||
|
||||
const getResultVideoDTOs = async (data: S['InvocationCompleteEvent']): Promise<VideoOutput | null> => {
|
||||
// @ts-expect-error: This is a workaround to get the video name from the result
|
||||
if (data.invocation.type === 'runway_generate_video') {
|
||||
// @ts-expect-error: This is a workaround to get the video name from the result
|
||||
return {videoId: data.result.video_id};
|
||||
const getResultVideoFields = (data: S['InvocationCompleteEvent']): VideoField[] => {
|
||||
const { result } = data;
|
||||
const videoFields: VideoField[] = [];
|
||||
|
||||
for (const [_name, value] of objectEntries(result)) {
|
||||
if (isVideoField(value)) {
|
||||
videoFields.push(value);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
|
||||
return videoFields;
|
||||
};
|
||||
|
||||
return async (data: S['InvocationCompleteEvent']) => {
|
||||
@@ -235,9 +240,9 @@ export const buildOnInvocationComplete = (
|
||||
|
||||
await addImagesToGallery(data);
|
||||
|
||||
const videoResult = await getResultVideoDTOs(data);
|
||||
if (videoResult) {
|
||||
dispatch(generatedVideoChanged({ video_id: videoResult.video.video_id, type: 'video_output' }));
|
||||
const videoField = getResultVideoFields(data)[0];
|
||||
if (videoField) {
|
||||
dispatch(generatedVideoChanged({ videoField }));
|
||||
}
|
||||
|
||||
$lastProgressEvent.set(null);
|
||||
|
||||
Reference in New Issue
Block a user