feat(nodes,ui): consolidate events, reduce network requests

This commit is contained in:
psychedelicious
2023-09-17 23:36:10 +10:00
parent 593d91815d
commit aab7c2c152
37 changed files with 198 additions and 403 deletions

View File

@@ -2,6 +2,7 @@ from typing import Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
@@ -25,10 +26,11 @@ from ..dependencies import ApiDependencies
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
class SessionQueueAndProcessorStatusResult(SessionQueueStatus, SessionProcessorStatus):
class SessionQueueAndProcessorStatus(BaseModel):
"""The overall status of session queue and processor"""
pass
queue: SessionQueueStatus
processor: SessionProcessorStatus
@session_queue_router.post(
@@ -87,23 +89,25 @@ async def list_queue_items(
@session_queue_router.put(
"/{queue_id}/resume",
"/{queue_id}/processor/resume",
operation_id="resume",
responses={200: {"model": SessionProcessorStatus}},
)
async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
) -> SessionProcessorStatus:
"""Resumes session processor"""
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
"/{queue_id}/pause",
"/{queue_id}/processor/pause",
operation_id="pause",
responses={200: {"model": SessionProcessorStatus}},
)
async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
) -> SessionProcessorStatus:
"""Pauses session processor"""
return ApiDependencies.invoker.services.session_processor.pause()
@@ -185,28 +189,16 @@ async def get_next_queue_item(
"/{queue_id}/status",
operation_id="get_queue_status",
responses={
200: {"model": SessionQueueStatus},
200: {"model": SessionQueueAndProcessorStatus},
},
)
async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueStatus:
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
@session_queue_router.get(
"/{queue_id}/processor/status",
operation_id="get_processor_status",
responses={
200: {"model": SessionProcessorStatus},
},
)
async def get_processor_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_processor.get_status()
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(

View File

@@ -22,10 +22,6 @@ class SocketIO:
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
self.__sio.on("subscribe_processor", handler=self._handle_sub_processor)
self.__sio.on("unsubscribe_processor", handler=self._handle_unsub_processor)
local_handler.register(event_name=EventServiceBase.processor_event, _func=self._handle_processor_event)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
@@ -55,16 +51,3 @@ class SocketIO:
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
self.__sio.enter_room(sid, data["queue_id"])
async def _handle_processor_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room="processor",
)
async def _handle_sub_processor(self, sid, *args, **kwargs):
self.__sio.enter_room(sid, "processor")
async def _handle_unsub_processor(self, sid, *args, **kwargs):
self.__sio.enter_room(sid, "processor")

View File

@@ -4,7 +4,6 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
@@ -243,7 +242,3 @@ class EventServiceBase:
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
)
def emit_processor_status_changed(self, processor_status: SessionProcessorStatus) -> None:
"""Emitted when the queue is cleared"""
self.__emit_processor_event(event_name="processor_status_changed", payload=processor_status.dict())

View File

@@ -13,12 +13,12 @@ class SessionProcessorBase(ABC):
"""
@abstractmethod
def resume(self) -> None:
def resume(self) -> SessionProcessorStatus:
"""Starts or resumes the session processor"""
pass
@abstractmethod
def pause(self) -> None:
def pause(self) -> SessionProcessorStatus:
"""Pauses the session processor"""
pass

View File

@@ -4,4 +4,3 @@ from pydantic import BaseModel, Field
class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
is_processing: bool = Field(description="Whether a session is being processed")
is_stop_pending: bool = Field(description="Whether processor is pending stopping")

View File

@@ -22,6 +22,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__invoker: Invoker = invoker
self.__queue_item: Optional[SessionQueueItem] = None
self.__resume_event = ThreadEvent()
self.__stop_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
@@ -33,7 +34,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
self._emit_status_changed()
def _poll_now(self) -> None:
self.__poll_now_event.set()
@@ -44,12 +44,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
name="session_processor",
target=self.__process,
kwargs=dict(
stop_event=self.__stop_event,
poll_now_event=self.__poll_now_event,
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
),
)
self.__thread.start()
self._emit_status_changed()
async def _on_session_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
@@ -75,41 +73,36 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now()
def _is_started(self) -> bool:
return self.__thread.is_alive()
return self.__resume_event.is_set()
def _is_processing(self) -> bool:
return self.__queue_item is not None
def _is_stop_pending(self) -> bool:
return self.__stop_event.is_set()
def _emit_status_changed(self) -> None:
self.__invoker.services.events.emit_processor_status_changed(self.get_status())
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._is_started(),
is_processing=self._is_processing(),
is_stop_pending=self._is_stop_pending(),
)
def resume(self) -> None:
if self._is_started():
return
self.__stop_event.clear()
self._emit_status_changed()
self._start_thread()
def resume(self) -> SessionProcessorStatus:
if not self.__resume_event.is_set():
self.__resume_event.set()
return self.get_status()
def pause(self) -> None:
self.__stop_event.set()
self._emit_status_changed()
def pause(self) -> SessionProcessorStatus:
if self.__resume_event.is_set():
self.__resume_event.clear()
return self.get_status()
def __process(
self,
stop_event: ThreadEvent,
poll_now_event: ThreadEvent,
resume_event: ThreadEvent,
):
try:
stop_event.clear()
resume_event.set()
self.__threadLimit.acquire()
queue_item: Optional[SessionQueueItem] = None
self.__invoker.services.logger
@@ -117,7 +110,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
# do not dequeue if there is already a session running
if self.__queue_item is None:
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
if queue_item is not None:
@@ -140,4 +133,3 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
self.__queue_item = None
self.__threadLimit.release()
self._emit_status_changed()

View File

@@ -232,6 +232,9 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
class SessionQueueStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
item_id: Optional[str] = Field(description="The current queue item id")
batch_id: Optional[str] = Field(description="The current queue item's batch id")
session_id: Optional[str] = Field(description="The current queue item's session id")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")

View File

@@ -787,17 +787,21 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
item_id=current_item.item_id if current_item else None,
session_id=current_item.session_id if current_item else None,
batch_id=current_item.batch_id if current_item else None,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),

View File

@@ -9,6 +9,7 @@ import {
import type { AppDispatch, RootState } from '../../store';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addFirstListImagesListener } from './listeners/addFirstListImagesListener.ts';
import { addAnyEnqueuedListener } from './listeners/anyEnqueued';
import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
import { addAppStartedListener } from './listeners/appStarted';
import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndImagesDeleted';
@@ -22,6 +23,9 @@ import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addEnqueueRequestedCanvasListener } from './listeners/enqueueRequestedCanvas';
import { addEnqueueRequestedLinear } from './listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from './listeners/enqueueRequestedNodes';
import {
addImageAddedToBoardFulfilledListener,
addImageAddedToBoardRejectedListener,
@@ -48,6 +52,7 @@ import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
import { addDynamicPromptsListener } from './listeners/promptChanged';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import {
addSessionCanceledFulfilledListener,
@@ -64,7 +69,6 @@ import {
addSessionInvokedPendingListener,
addSessionInvokedRejectedListener,
} from './listeners/sessionInvoked';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
@@ -74,19 +78,14 @@ import { addInvocationErrorEventListener as addInvocationErrorListener } from '.
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addEnqueueRequestedCanvasListener } from './listeners/enqueueRequestedCanvas';
import { addEnqueueRequestedNodes } from './listeners/enqueueRequestedNodes';
import { addEnqueueRequestedLinear } from './listeners/enqueueRequestedLinear';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addDynamicPromptsListener } from './listeners/promptChanged';
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
import { addProcessorStatusChangedEventListener } from './listeners/socketio/socketProcessorStatusChanged';
export const listenerMiddleware = createListenerMiddleware();
@@ -136,7 +135,7 @@ addImagesUnstarredListener();
addEnqueueRequestedCanvasListener();
addEnqueueRequestedNodes();
addEnqueueRequestedLinear();
addSessionReadyToInvokeListener();
addAnyEnqueuedListener();
// Canvas actions
addCanvasSavedToGalleryListener();
@@ -175,7 +174,6 @@ addModelLoadEventListener();
addSessionRetrievalErrorEventListener();
addInvocationRetrievalErrorEventListener();
addSocketQueueItemStatusChangedEventListener();
addProcessorStatusChangedEventListener();
// Session Created
addSessionCreatedPendingListener();

View File

@@ -0,0 +1,27 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..';
const matcher = isAnyOf(
queueApi.endpoints.enqueueBatch.matchFulfilled,
queueApi.endpoints.enqueueGraph.matchFulfilled
);
export const addAnyEnqueuedListener = () => {
startAppListening({
matcher,
effect: async (_, { dispatch, getState }) => {
const { data } = queueApi.endpoints.getQueueStatus.select()(getState());
if (!data || data.processor.is_started) {
return;
}
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
},
});
};

View File

@@ -1,57 +0,0 @@
import { logger } from 'app/logging/logger';
import { batchEnqueued } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..';
export const addBatchEnqueuedListener = () => {
startAppListening({
actionCreator: batchEnqueued,
effect: async (action, { dispatch }) => {
const log = logger('session');
const batchConfig = action.payload;
const { prepend } = batchConfig;
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
'Failed to enqueue batch'
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
},
});
};

View File

@@ -46,11 +46,7 @@ export const addControlNetImageProcessedListener = () => {
);
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug(
{ enqueueResult: parseify(enqueueResult) },
t('queue.graphQueued')

View File

@@ -139,11 +139,6 @@ export const addEnqueueRequestedCanvasListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');

View File

@@ -50,12 +50,6 @@ export const addEnqueueRequestedLinear = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({

View File

@@ -34,12 +34,6 @@ export const addEnqueueRequestedNodes = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({

View File

@@ -1,18 +0,0 @@
import { logger } from 'app/logging/logger';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionInvoked } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionReadyToInvokeListener = () => {
startAppListening({
actionCreator: sessionReadyToInvoke,
effect: (action, { getState, dispatch }) => {
const log = logger('session');
const { sessionId: session_id } = getState().system;
if (session_id) {
log.debug({ session_id }, `Session ready to invoke (${session_id})})`);
dispatch(sessionInvoked({ session_id }));
}
},
});
};

View File

@@ -1,19 +0,0 @@
import { logger } from 'app/logging/logger';
import { queueApi } from 'services/api/endpoints/queue';
import {
appSocketProcessorStatusChanged,
socketProcessorStatusChanged,
} from 'services/events/actions';
import { startAppListening } from '../..';
export const addProcessorStatusChangedEventListener = () => {
startAppListening({
actionCreator: socketProcessorStatusChanged,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.debug(action.payload, 'Processor status changed');
dispatch(appSocketProcessorStatusChanged(action.payload));
dispatch(queueApi.util.invalidateTags(['SessionProcessorStatus']));
},
});
};

View File

@@ -43,7 +43,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'SessionQueueStatus',
'SessionProcessorStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: batch_id },

View File

@@ -37,11 +37,6 @@ export const addUpscaleRequestedListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug(
{ enqueueResult: parseify(enqueueResult) },
t('queue.graphQueued')

View File

@@ -5,9 +5,8 @@ import { useTranslation } from 'react-i18next';
import { FaTimes } from 'react-icons/fa';
import {
useCancelQueueItemMutation,
useGetCurrentQueueItemQuery,
useGetQueueStatusQuery,
} from 'services/api/endpoints/queue';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import QueueButton from './common/QueueButton';
type Props = {
@@ -17,18 +16,17 @@ type Props = {
const CancelCurrentQueueItemButton = ({ asIconButton }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: currentQueueItem } = useGetCurrentQueueItemQuery();
const [cancelQueueItem] = useCancelQueueItemMutation({
const { data: queueStatus } = useGetQueueStatusQuery();
const [cancelQueueItem, { isLoading }] = useCancelQueueItemMutation({
fixedCacheKey: 'cancelQueueItem',
});
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const handleClick = useCallback(async () => {
if (!currentQueueItem) {
if (!queueStatus?.queue.item_id) {
return;
}
try {
await cancelQueueItem(currentQueueItem.item_id).unwrap();
await cancelQueueItem(queueStatus.queue.item_id).unwrap();
dispatch(
addToast({
title: t('queue.cancelSucceeded'),
@@ -43,14 +41,15 @@ const CancelCurrentQueueItemButton = ({ asIconButton }: Props) => {
})
);
}
}, [cancelQueueItem, currentQueueItem, dispatch, t]);
}, [cancelQueueItem, dispatch, queueStatus?.queue.item_id, t]);
return (
<QueueButton
isDisabled={!queueStatus?.queue.item_id}
isLoading={isLoading}
asIconButton={asIconButton}
label={t('queue.cancel')}
tooltip={t('queue.cancelTooltip')}
isDisabled={!currentQueueItem || isQueueMutationInProgress}
icon={<FaTimes />}
onClick={handleClick}
colorScheme="error"

View File

@@ -9,7 +9,6 @@ import {
} from 'services/api/endpoints/queue';
import { listCursorChanged, listPriorityChanged } from '../store/queueSlice';
import QueueButton from './common/QueueButton';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
type Props = {
asIconButton?: boolean;
@@ -19,10 +18,11 @@ const ClearQueueButton = ({ asIconButton }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: queueStatusData } = useGetQueueStatusQuery();
const { data: queueStatus } = useGetQueueStatusQuery();
const [clearQueue] = useClearQueueMutation({ fixedCacheKey: 'clearQueue' });
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const [clearQueue, { isLoading }] = useClearQueueMutation({
fixedCacheKey: 'clearQueue',
});
const handleClick = useCallback(async () => {
try {
@@ -47,7 +47,8 @@ const ClearQueueButton = ({ asIconButton }: Props) => {
return (
<QueueButton
isDisabled={!queueStatusData?.total || isQueueMutationInProgress}
isDisabled={!queueStatus?.queue.total}
isLoading={isLoading}
asIconButton={asIconButton}
label={t('queue.clear')}
tooltip={t('queue.clearTooltip')}

View File

@@ -4,10 +4,9 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaPause } from 'react-icons/fa';
import {
useGetProcessorStatusQuery,
useGetQueueStatusQuery,
usePauseProcessorMutation,
} from 'services/api/endpoints/queue';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import QueueButton from './common/QueueButton';
type Props = {
@@ -17,11 +16,10 @@ type Props = {
const PauseProcessorButton = ({ asIconButton }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: processorStatus } = useGetProcessorStatusQuery();
const [pauseProcessor] = usePauseProcessorMutation({
const { data: queueStatus } = useGetQueueStatusQuery();
const [pauseProcessor, { isLoading }] = usePauseProcessorMutation({
fixedCacheKey: 'pauseProcessor',
});
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const handleClick = useCallback(async () => {
try {
@@ -47,8 +45,8 @@ const PauseProcessorButton = ({ asIconButton }: Props) => {
asIconButton={asIconButton}
label={t('queue.pause')}
tooltip={t('queue.pauseTooltip')}
isDisabled={!processorStatus?.is_started || isQueueMutationInProgress}
isLoading={processorStatus?.is_stop_pending}
isDisabled={!queueStatus?.processor.is_started}
isLoading={isLoading}
icon={<FaPause />}
onClick={handleClick}
colorScheme="gold"

View File

@@ -7,7 +7,6 @@ import {
useGetQueueStatusQuery,
usePruneQueueMutation,
} from 'services/api/endpoints/queue';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import { listCursorChanged, listPriorityChanged } from '../store/queueSlice';
import QueueButton from './common/QueueButton';
@@ -18,15 +17,18 @@ type Props = {
const PruneQueueButton = ({ asIconButton }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [pruneQueue] = usePruneQueueMutation({ fixedCacheKey: 'pruneQueue' });
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const [pruneQueue, { isLoading }] = usePruneQueueMutation({
fixedCacheKey: 'pruneQueue',
});
const { count } = useGetQueueStatusQuery(undefined, {
selectFromResult: ({ data }) => {
if (!data) {
return { count: 0 };
}
return { count: data.completed + data.canceled + data.failed };
return {
count: data.queue.completed + data.queue.canceled + data.queue.failed,
};
},
});
@@ -53,7 +55,8 @@ const PruneQueueButton = ({ asIconButton }: Props) => {
return (
<QueueButton
isDisabled={!count || isQueueMutationInProgress}
isDisabled={!count}
isLoading={isLoading}
asIconButton={asIconButton}
label={t('queue.prune')}
tooltip={t('queue.pruneTooltip', { item_count: count })}

View File

@@ -7,16 +7,18 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import EnqueueButtonTooltip from './QueueButtonTooltip';
import { useEnqueueBatchMutation } from 'services/api/endpoints/queue';
import { useIsQueueEmpty } from '../hooks/useIsQueueEmpty';
import EnqueueButtonTooltip from './QueueButtonTooltip';
const QueueBackButton = () => {
const tabName = useAppSelector(activeTabNameSelector);
const { t } = useTranslation();
const { isReady } = useIsReadyToEnqueue();
const dispatch = useAppDispatch();
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch',
});
const isEmpty = useIsQueueEmpty();
const handleEnqueue = useCallback(() => {
@@ -28,15 +30,16 @@ const QueueBackButton = () => {
['ctrl+enter', 'meta+enter'],
handleEnqueue,
{
enabled: () => !isQueueMutationInProgress,
enabled: () => !isLoading,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
[isQueueMutationInProgress, tabName]
[tabName, isLoading]
);
return (
<IAIButton
isDisabled={!isReady || isQueueMutationInProgress}
isDisabled={!isReady}
isLoading={isLoading}
colorScheme="accent"
onClick={handleEnqueue}
tooltip={<EnqueueButtonTooltip />}

View File

@@ -55,7 +55,7 @@ const QueueCounts = memo(() => {
};
}
const { pending, in_progress } = data;
const { pending, in_progress } = data.queue;
return {
hasItems: pending + in_progress > 0,

View File

@@ -8,8 +8,8 @@ import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaBoltLightning } from 'react-icons/fa6';
import { useEnqueueBatchMutation } from 'services/api/endpoints/queue';
import { useIsQueueEmpty } from '../hooks/useIsQueueEmpty';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import EnqueueButtonTooltip from './QueueButtonTooltip';
const QueueFrontButton = () => {
@@ -17,7 +17,9 @@ const QueueFrontButton = () => {
const dispatch = useAppDispatch();
const { isReady } = useIsReadyToEnqueue();
const { t } = useTranslation();
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch',
});
const isEmpty = useIsQueueEmpty();
const handleEnqueue = useCallback(() => {
dispatch(clampSymmetrySteps());
@@ -28,17 +30,18 @@ const QueueFrontButton = () => {
['ctrl+shift+enter', 'meta+shift+enter'],
handleEnqueue,
{
enabled: () => !isQueueMutationInProgress,
enabled: () => !isLoading,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
[isQueueMutationInProgress, tabName]
[isLoading, tabName]
);
return (
<IAIIconButton
colorScheme="base"
aria-label={t('queue.queueFront')}
isDisabled={!isReady || isQueueMutationInProgress || isEmpty}
isDisabled={!isReady || isEmpty}
isLoading={isLoading}
onClick={handleEnqueue}
tooltip={<EnqueueButtonTooltip prepend />}
icon={<FaBoltLightning />}

View File

@@ -5,7 +5,7 @@ import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
const QueueStatusCard = () => {
const { t } = useTranslation();
const { data: queueStatusData } = useGetQueueStatusQuery();
const { data: queueStatus } = useGetQueueStatusQuery();
return (
<Flex
@@ -21,31 +21,31 @@ const QueueStatusCard = () => {
<Text as="span" fontWeight={600}>
{t('queue.pending')}:{' '}
</Text>
{queueStatusData?.pending}
{queueStatus?.queue.pending}
</Text>
<Text>
<Text as="span" fontWeight={600}>
{t('queue.inProgress')}:{' '}
</Text>
{queueStatusData?.in_progress}
{queueStatus?.queue.in_progress}
</Text>
<Text>
<Text as="span" fontWeight={600}>
{t('queue.completed')}:{' '}
</Text>
{queueStatusData?.completed}
{queueStatus?.queue.completed}
</Text>
<Text>
<Text as="span" fontWeight={600}>
{t('queue.failed')}:{' '}
</Text>
{queueStatusData?.failed}
{queueStatus?.queue.failed}
</Text>
<Text>
<Text as="span" fontWeight={600}>
{t('queue.canceled')}:{' '}
</Text>
{queueStatusData?.canceled}
{queueStatus?.queue.canceled}
</Text>
</Flex>
);

View File

@@ -4,10 +4,9 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
import {
useGetProcessorStatusQuery,
useGetQueueStatusQuery,
useResumeProcessorMutation,
} from 'services/api/endpoints/queue';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import QueueButton from './common/QueueButton';
type Props = {
@@ -15,13 +14,12 @@ type Props = {
};
const ResumeProcessorButton = ({ asIconButton }: Props) => {
const { data: processorStatus } = useGetProcessorStatusQuery();
const { data: queueStatus } = useGetQueueStatusQuery();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [resumeProcessor] = useResumeProcessorMutation({
const [resumeProcessor, { isLoading }] = useResumeProcessorMutation({
fixedCacheKey: 'resumeProcessor',
});
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const handleClick = useCallback(async () => {
try {
@@ -47,11 +45,8 @@ const ResumeProcessorButton = ({ asIconButton }: Props) => {
asIconButton={asIconButton}
label={t('queue.resume')}
tooltip={t('queue.resumeTooltip')}
isDisabled={
processorStatus?.is_started ||
processorStatus?.is_processing ||
isQueueMutationInProgress
}
isDisabled={queueStatus?.processor.is_started}
isLoading={isLoading}
icon={<FaPlay />}
onClick={handleClick}
colorScheme="green"

View File

@@ -6,7 +6,9 @@ export const useIsQueueEmpty = () => {
if (!data) {
return { isEmpty: true };
}
return { isEmpty: data.in_progress === 0 && data.pending === 0 };
return {
isEmpty: data.queue.in_progress === 0 && data.queue.pending === 0,
};
},
});
return isEmpty;

View File

@@ -1,15 +0,0 @@
import { useGetProcessorStatusQuery } from 'services/api/endpoints/queue';
export const useIsQueueStarted = () => {
const { isStarted } = useGetProcessorStatusQuery(undefined, {
selectFromResult: ({ data }) => {
if (!data) {
return { isStarted: false };
}
return { isStarted: data.is_started || data.is_processing };
},
});
return isStarted;
};

View File

@@ -5,7 +5,7 @@ import { SystemState } from 'features/system/store/systemSlice';
import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetProcessorStatusQuery } from 'services/api/endpoints/queue';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { systemSelector } from '../store/systemSelectors';
const progressBarSelector = createSelector(
@@ -24,22 +24,24 @@ const progressBarSelector = createSelector(
const ProgressBar = () => {
const { t } = useTranslation();
const { data: processorStatus } = useGetProcessorStatusQuery();
const { data: queueStatus } = useGetQueueStatusQuery();
const { currentStep, totalSteps, currentStatusHasSteps } =
useAppSelector(progressBarSelector);
const value = useMemo(() => {
if (currentStep && processorStatus?.is_processing) {
if (currentStep && Boolean(queueStatus?.queue.in_progress)) {
return Math.round((currentStep * 100) / totalSteps);
}
return 0;
}, [currentStep, processorStatus?.is_processing, totalSteps]);
}, [currentStep, queueStatus?.queue.in_progress, totalSteps]);
return (
<Progress
value={value}
aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={processorStatus?.is_processing && !currentStatusHasSteps}
isIndeterminate={
Boolean(queueStatus?.queue.in_progress) && !currentStatusHasSteps
}
h="full"
w="full"
borderRadius={2}

View File

@@ -19,7 +19,6 @@ import {
appSocketInvocationError,
appSocketInvocationRetrievalError,
appSocketInvocationStarted,
appSocketProcessorStatusChanged,
appSocketSessionRetrievalError,
} from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
@@ -196,12 +195,6 @@ export const systemSlice = createSlice({
},
},
extraReducers(builder) {
builder.addCase(appSocketProcessorStatusChanged, (state, action) => {
const { is_started, is_stop_pending } = action.payload.data;
if (!is_started && !is_stop_pending) {
state.progressImage = null;
}
});
/**
* Socket Connected
*/

View File

@@ -4,11 +4,11 @@ import {
ThunkDispatch,
createEntityAdapter,
} from '@reduxjs/toolkit';
import { $queueId } from 'features/queue/store/nanoStores';
import { listParamsReset } from 'features/queue/store/queueSlice';
import queryString from 'query-string';
import { ApiTagDescription, api } from '..';
import { components, paths } from '../schema';
import { $queueId } from 'features/queue/store/nanoStores';
import { listParamsReset } from 'features/queue/store/queueSlice';
const getListQueueItemsUrl = (
queryArgs?: paths['/api/v1/queue/{queue_id}/list']['get']['parameters']['query']
@@ -70,7 +70,6 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
],
@@ -95,7 +94,6 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
],
@@ -109,19 +107,25 @@ export const queueApi = api.injectEndpoints({
}
},
}),
resumeProcessor: build.mutation<void, void>({
resumeProcessor: build.mutation<
paths['/api/v1/queue/{queue_id}/processor/resume']['put']['responses']['200']['content']['application/json'],
void
>({
query: () => ({
url: `queue/${$queueId.get()}/resume`,
url: `queue/${$queueId.get()}/processor/resume`,
method: 'PUT',
}),
invalidatesTags: ['SessionProcessorStatus', 'CurrentSessionQueueItem'],
invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'],
}),
pauseProcessor: build.mutation<void, void>({
pauseProcessor: build.mutation<
paths['/api/v1/queue/{queue_id}/processor/pause']['put']['responses']['200']['content']['application/json'],
void
>({
query: () => ({
url: `queue/${$queueId.get()}/pause`,
url: `queue/${$queueId.get()}/processor/pause`,
method: 'PUT',
}),
invalidatesTags: ['SessionProcessorStatus', 'CurrentSessionQueueItem'],
invalidatesTags: ['CurrentSessionQueueItem', 'SessionQueueStatus'],
}),
pruneQueue: build.mutation<
paths['/api/v1/queue/{queue_id}/prune']['put']['responses']['200']['content']['application/json'],
@@ -133,7 +137,6 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'BatchStatus',
'SessionQueueItem',
'SessionQueueItemDTO',
@@ -217,16 +220,6 @@ export const queueApi = api.injectEndpoints({
}),
providesTags: ['SessionQueueStatus'],
}),
getProcessorStatus: build.query<
paths['/api/v1/queue/{queue_id}/processor/status']['get']['responses']['200']['content']['application/json'],
void
>({
query: () => ({
url: `queue/${$queueId.get()}/processor/status`,
method: 'GET',
}),
providesTags: ['SessionProcessorStatus'],
}),
getBatchStatus: build.query<
paths['/api/v1/queue/{queue_id}/b/{batch_id}/status']['get']['responses']['200']['content']['application/json'],
{ batch_id: string }
@@ -368,7 +361,6 @@ export const {
useGetNextQueueItemQuery,
useListQueueItemsQuery,
useCancelQueueItemMutation,
useGetProcessorStatusQuery,
useGetBatchStatusQuery,
} = queueApi;

View File

@@ -328,14 +328,14 @@ export type paths = {
*/
get: operations["list_queue_items"];
};
"/api/v1/queue/{queue_id}/resume": {
"/api/v1/queue/{queue_id}/processor/resume": {
/**
* Resume
* @description Resumes session processor
*/
put: operations["resume"];
};
"/api/v1/queue/{queue_id}/pause": {
"/api/v1/queue/{queue_id}/processor/pause": {
/**
* Pause
* @description Pauses session processor
@@ -384,13 +384,6 @@ export type paths = {
*/
get: operations["get_queue_status"];
};
"/api/v1/queue/{queue_id}/processor/status": {
/**
* Get Processor Status
* @description Gets the status of the session queue
*/
get: operations["get_processor_status"];
};
"/api/v1/queue/{queue_id}/b/{batch_id}/status": {
/**
* Get Batch Status
@@ -1177,11 +1170,6 @@ export type components = {
* @description The workflow to save with the image
*/
workflow?: string;
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
* Skipped Layers
* @description Number of layers to skip in text encoder
@@ -1194,6 +1182,11 @@ export type components = {
* @enum {string}
*/
type: "clip_skip";
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
/**
* ClipSkipInvocationOutput
@@ -6960,11 +6953,14 @@ export type components = {
* @description Whether a session is being processed
*/
is_processing: boolean;
/**
* Is Stop Pending
* @description Whether processor is pending stopping
*/
is_stop_pending: boolean;
};
/**
* SessionQueueAndProcessorStatus
* @description The overall status of session queue and processor
*/
SessionQueueAndProcessorStatus: {
queue: components["schemas"]["SessionQueueStatus"];
processor: components["schemas"]["SessionProcessorStatus"];
};
/**
* SessionQueueItem
@@ -7126,6 +7122,21 @@ export type components = {
* @description The ID of the queue
*/
queue_id: string;
/**
* Item Id
* @description The current queue item id
*/
item_id?: string;
/**
* Batch Id
* @description The current queue item's batch id
*/
batch_id?: string;
/**
* Session Id
* @description The current queue item's session id
*/
session_id?: string;
/**
* Pending
* @description Number of queue items with status 'pending'
@@ -8150,23 +8161,17 @@ export type components = {
ui_order?: number;
};
/**
* StableDiffusionOnnxModelFormat
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
@@ -8179,6 +8184,12 @@ export type components = {
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
};
responses: never;
parameters: never;
@@ -9707,7 +9718,7 @@ export type operations = {
/** @description Successful Response */
200: {
content: {
"application/json": unknown;
"application/json": components["schemas"]["SessionProcessorStatus"];
};
};
/** @description Validation Error */
@@ -9733,7 +9744,7 @@ export type operations = {
/** @description Successful Response */
200: {
content: {
"application/json": unknown;
"application/json": components["schemas"]["SessionProcessorStatus"];
};
};
/** @description Validation Error */
@@ -9894,33 +9905,7 @@ export type operations = {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["SessionQueueStatus"];
};
};
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* Get Processor Status
* @description Gets the status of the session queue
*/
get_processor_status: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
queue_id: string;
};
};
responses: {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["SessionProcessorStatus"];
"application/json": components["schemas"]["SessionQueueAndProcessorStatus"];
};
};
/** @description Validation Error */

View File

@@ -8,7 +8,6 @@ import {
InvocationStartedEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
ProcessorStatusChangedEvent,
QueueItemStatusChangedEvent,
SessionRetrievalErrorEvent,
} from 'services/events/types';
@@ -233,19 +232,3 @@ export const socketQueueItemStatusChanged = createAction<{
export const appSocketQueueItemStatusChanged = createAction<{
data: QueueItemStatusChangedEvent;
}>('socket/appSocketQueueItemStatusChanged');
/**
* Socket.IO Processor Status Changed
*
* Do not use. Only for use in middleware.
*/
export const socketProcessorStatusChanged = createAction<{
data: ProcessorStatusChangedEvent;
}>('socket/socketProcessorStatusChanged');
/**
* App-level Processor Status Changed
*/
export const appSocketProcessorStatusChanged = createAction<{
data: ProcessorStatusChangedEvent;
}>('socket/appSocketProcessorStatusChanged');

View File

@@ -148,17 +148,6 @@ export type QueueItemStatusChangedEvent = {
status: components['schemas']['SessionQueueItemDTO']['status'];
};
/**
* A `queue_status_changed` socket.io event.
*
* @example socket.on('queue_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
*/
export type ProcessorStatusChangedEvent = {
is_started: boolean;
is_processing: boolean;
is_stop_pending: boolean;
};
export type ClientEmitSubscribeSession = {
session: string;
};
@@ -188,7 +177,6 @@ export type ServerToClientEvents = {
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void;
processor_status_changed: (payload: ProcessorStatusChangedEvent) => void;
};
export type ClientToServerEvents = {
@@ -198,6 +186,4 @@ export type ClientToServerEvents = {
unsubscribe_session: (payload: ClientEmitUnsubscribeSession) => void;
subscribe_queue: (payload: ClientEmitSubscribeQueue) => void;
unsubscribe_queue: (payload: ClientEmitUnsubscribeQueue) => void;
subscribe_processor: () => void;
unsubscribe_processor: () => void;
};

View File

@@ -16,7 +16,6 @@ import {
socketInvocationStarted,
socketModelLoadCompleted,
socketModelLoadStarted,
socketProcessorStatusChanged,
socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from '../actions';
@@ -41,8 +40,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
dispatch(socketConnected());
const queue_id = $queueId.get();
socket.emit('subscribe_queue', { queue_id });
socket.emit('subscribe_processor');
});
socket.on('connect_error', (error) => {
@@ -172,8 +169,4 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
}
dispatch(socketQueueItemStatusChanged({ data }));
});
socket.on('processor_status_changed', (data) => {
dispatch(socketProcessorStatusChanged({ data }));
});
};