mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-17 15:53:05 -05:00
refactor(ui): canvas flow events (wip)
This commit is contained in:
@@ -29,23 +29,17 @@ import {
|
|||||||
selectStagedImageIndex,
|
selectStagedImageIndex,
|
||||||
selectStagedImages,
|
selectStagedImages,
|
||||||
stagingAreaImageSelected,
|
stagingAreaImageSelected,
|
||||||
stagingAreaImageStaged,
|
|
||||||
stagingAreaNextStagedImageSelected,
|
stagingAreaNextStagedImageSelected,
|
||||||
stagingAreaPrevStagedImageSelected,
|
stagingAreaPrevStagedImageSelected,
|
||||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||||
import { isImageField, type ProgressImage } from 'features/nodes/types/common';
|
import type { ProgressImage } from 'features/nodes/types/common';
|
||||||
import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { memo, useCallback, useEffect } from 'react';
|
||||||
import type { Atom } from 'nanostores';
|
|
||||||
import { atom } from 'nanostores';
|
|
||||||
import { memo, useCallback, useEffect, useState } from 'react';
|
|
||||||
import { flushSync } from 'react-dom';
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
|
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
|
||||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
|
||||||
import type { ImageDTO, S } from 'services/api/types';
|
import type { ImageDTO, S } from 'services/api/types';
|
||||||
import { $socket } from 'services/events/stores';
|
import { $lastCanvasProgressImage, $socket } from 'services/events/stores';
|
||||||
import type { Equals } from 'tsafe';
|
import type { Equals } from 'tsafe';
|
||||||
import { assert, objectEntries } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
|
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
|
||||||
|
|
||||||
@@ -131,53 +125,14 @@ const SimpleActiveSession = memo(() => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isStaging = useAppSelector(selectIsStaging);
|
const isStaging = useAppSelector(selectIsStaging);
|
||||||
const socket = useStore($socket);
|
const socket = useStore($socket);
|
||||||
const [$progressImage] = useState(() => atom<EphemeralProgressImage | null>(null));
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!socket) {
|
if (!socket) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const onInvocationProgress = (event: S['InvocationProgressEvent']) => {
|
|
||||||
if (!event) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (event.origin !== 'canvas') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!event.image) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
$progressImage.set({ sessionId: event.session_id, image: event.image });
|
|
||||||
};
|
|
||||||
const onInvocationComplete = async (event: S['InvocationCompleteEvent']) => {
|
|
||||||
const progressImage = $progressImage.get();
|
|
||||||
if (!progressImage) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (progressImage.sessionId !== event.session_id) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!isCanvasOutputEvent(event)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
let imageDTO: ImageDTO | null = null;
|
|
||||||
for (const [_name, value] of objectEntries(event.result)) {
|
|
||||||
if (isImageField(value)) {
|
|
||||||
imageDTO = await getImageDTOSafe(value.image_name);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!imageDTO) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
flushSync(() => {
|
|
||||||
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
|
|
||||||
});
|
|
||||||
$progressImage.set(null);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onQueueItemStatusChanged = (event: S['QueueItemStatusChangedEvent']) => {
|
const onQueueItemStatusChanged = (event: S['QueueItemStatusChangedEvent']) => {
|
||||||
const progressImage = $progressImage.get();
|
const progressImage = $lastCanvasProgressImage.get();
|
||||||
if (!progressImage) {
|
if (!progressImage) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -187,20 +142,16 @@ const SimpleActiveSession = memo(() => {
|
|||||||
if (event.status !== 'canceled' && event.status !== 'failed') {
|
if (event.status !== 'canceled' && event.status !== 'failed') {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
$progressImage.set(null);
|
$lastCanvasProgressImage.set(null);
|
||||||
};
|
};
|
||||||
console.log('SUB session preview image listeners');
|
console.log('SUB session preview image listeners');
|
||||||
socket.on('invocation_progress', onInvocationProgress);
|
|
||||||
socket.on('invocation_complete', onInvocationComplete);
|
|
||||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
console.log('UNSUB session preview image listeners');
|
console.log('UNSUB session preview image listeners');
|
||||||
socket.off('invocation_progress', onInvocationProgress);
|
|
||||||
socket.off('invocation_complete', onInvocationComplete);
|
|
||||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||||
};
|
};
|
||||||
}, [$progressImage, dispatch, socket]);
|
}, [dispatch, socket]);
|
||||||
|
|
||||||
const onReset = useCallback(() => {
|
const onReset = useCallback(() => {
|
||||||
dispatch(canvasReset());
|
dispatch(canvasReset());
|
||||||
@@ -226,15 +177,15 @@ const SimpleActiveSession = memo(() => {
|
|||||||
</Text>
|
</Text>
|
||||||
<Button onClick={onReset}>reset</Button>
|
<Button onClick={onReset}>reset</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
<SelectedImage $progressImage={$progressImage} />
|
<SelectedImage />
|
||||||
<SessionImages />
|
<SessionImages />
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
SimpleActiveSession.displayName = 'SimpleActiveSession';
|
SimpleActiveSession.displayName = 'SimpleActiveSession';
|
||||||
|
|
||||||
const SelectedImage = memo(({ $progressImage }: { $progressImage: Atom<EphemeralProgressImage | null> }) => {
|
const SelectedImage = memo(() => {
|
||||||
const progressImage = useStore($progressImage);
|
const progressImage = useStore($lastCanvasProgressImage);
|
||||||
const selectedImage = useAppSelector(selectSelectedImage);
|
const selectedImage = useAppSelector(selectSelectedImage);
|
||||||
|
|
||||||
if (progressImage) {
|
if (progressImage) {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { canvasSessionStarted } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
|
||||||
import type { ParamsState, RgbaColor } from 'features/controlLayers/store/types';
|
import type { ParamsState, RgbaColor } from 'features/controlLayers/store/types';
|
||||||
import { getInitialParamsState } from 'features/controlLayers/store/types';
|
import { getInitialParamsState } from 'features/controlLayers/store/types';
|
||||||
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
|
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
|
||||||
@@ -25,7 +24,6 @@ import { clamp } from 'lodash-es';
|
|||||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
|
||||||
export const paramsSlice = createSlice({
|
export const paramsSlice = createSlice({
|
||||||
name: 'params',
|
name: 'params',
|
||||||
initialState: getInitialParamsState(),
|
initialState: getInitialParamsState(),
|
||||||
@@ -188,9 +186,6 @@ export const paramsSlice = createSlice({
|
|||||||
},
|
},
|
||||||
paramsReset: (state) => resetState(state),
|
paramsReset: (state) => resetState(state),
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
|
||||||
builder.addCase(canvasSessionStarted, (state) => resetState(state));
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const resetState = (state: ParamsState): ParamsState => {
|
const resetState = (state: ParamsState): ParamsState => {
|
||||||
|
|||||||
@@ -7,12 +7,13 @@ import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks
|
|||||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { isCanvasOutputEvent } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
|
import { flushSync } from 'react-dom';
|
||||||
import type { ApiTagDescription } from 'services/api';
|
import type { ApiTagDescription } from 'services/api';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||||
import type { ImageDTO, S } from 'services/api/types';
|
import type { ImageDTO, S } from 'services/api/types';
|
||||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||||
import { $lastProgressEvent } from 'services/events/stores';
|
import { $lastCanvasProgressImage, $lastProgressEvent } from 'services/events/stores';
|
||||||
import type { Param0 } from 'tsafe';
|
import type { Param0 } from 'tsafe';
|
||||||
import { objectEntries } from 'tsafe';
|
import { objectEntries } from 'tsafe';
|
||||||
import type { JsonObject } from 'type-fest';
|
import type { JsonObject } from 'type-fest';
|
||||||
@@ -176,7 +177,11 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
|
flushSync(() => {
|
||||||
|
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
|
||||||
|
});
|
||||||
|
|
||||||
|
$lastCanvasProgressImage.set(null);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
|
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events
|
|||||||
import type { Socket } from 'socket.io-client';
|
import type { Socket } from 'socket.io-client';
|
||||||
import type { JsonObject } from 'type-fest';
|
import type { JsonObject } from 'type-fest';
|
||||||
|
|
||||||
import { $lastProgressEvent } from './stores';
|
import { $lastCanvasProgressEvent, $lastProgressEvent } from './stores';
|
||||||
|
|
||||||
const log = logger('events');
|
const log = logger('events');
|
||||||
|
|
||||||
@@ -428,6 +428,10 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
|
|||||||
// If the queue item is completed, failed, or cancelled, we want to clear the last progress event
|
// If the queue item is completed, failed, or cancelled, we want to clear the last progress event
|
||||||
$lastProgressEvent.set(null);
|
$lastProgressEvent.set(null);
|
||||||
|
|
||||||
|
if (data.origin === 'canvas') {
|
||||||
|
$lastCanvasProgressEvent.set(null);
|
||||||
|
}
|
||||||
|
|
||||||
// When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published
|
// When a validation run is completed, we want to clear the validation run batch ID & set the workflow as published
|
||||||
const validationRunData = $validationRunData.get();
|
const validationRunData = $validationRunData.get();
|
||||||
if (!validationRunData || batch_status.batch_id !== validationRunData.batchId || status !== 'completed') {
|
if (!validationRunData || batch_status.batch_id !== validationRunData.batchId || status !== 'completed') {
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import type { ProgressImage } from 'features/nodes/types/common';
|
||||||
import { round } from 'lodash-es';
|
import { round } from 'lodash-es';
|
||||||
import { atom, computed, map } from 'nanostores';
|
import { atom, computed, map } from 'nanostores';
|
||||||
import type { S } from 'services/api/types';
|
import type { S } from 'services/api/types';
|
||||||
@@ -15,18 +16,33 @@ $lastProgressEvent.subscribe((event) => {
|
|||||||
switch (event.destination) {
|
switch (event.destination) {
|
||||||
case 'workflows':
|
case 'workflows':
|
||||||
$lastWorkflowsProgressEvent.set(event);
|
$lastWorkflowsProgressEvent.set(event);
|
||||||
|
if (event.image) {
|
||||||
|
$lastWorkflowsProgressImage.set({ sessionId: event.session_id, image: event.image });
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case 'upscaling':
|
case 'upscaling':
|
||||||
$lastUpscalingProgressEvent.set(event);
|
$lastUpscalingProgressEvent.set(event);
|
||||||
|
if (event.image) {
|
||||||
|
$lastUpscalingProgressImage.set({ sessionId: event.session_id, image: event.image });
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case 'canvas':
|
case 'canvas':
|
||||||
$lastCanvasProgressEvent.set(event);
|
$lastCanvasProgressEvent.set(event);
|
||||||
|
if (event.image) {
|
||||||
|
$lastCanvasProgressImage.set({ sessionId: event.session_id, image: event.image });
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
|
||||||
|
|
||||||
export const $lastCanvasProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
export const $lastCanvasProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
||||||
|
export const $lastCanvasProgressImage = atom<EphemeralProgressImage | null>(null);
|
||||||
export const $lastWorkflowsProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
export const $lastWorkflowsProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
||||||
|
export const $lastWorkflowsProgressImage = atom<EphemeralProgressImage | null>(null);
|
||||||
export const $lastUpscalingProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
export const $lastUpscalingProgressEvent = atom<S['InvocationProgressEvent'] | null>(null);
|
||||||
|
export const $lastUpscalingProgressImage = atom<EphemeralProgressImage | null>(null);
|
||||||
|
|
||||||
export const $progressImage = computed($lastProgressEvent, (val) => val?.image ?? null);
|
export const $progressImage = computed($lastProgressEvent, (val) => val?.image ?? null);
|
||||||
export const $hasProgressImage = computed($lastProgressEvent, (val) => Boolean(val?.image));
|
export const $hasProgressImage = computed($lastProgressEvent, (val) => Boolean(val?.image));
|
||||||
|
|||||||
Reference in New Issue
Block a user