mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): "subscribe" to particular nodes
feels like a dirty hack but oh well it works
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { find } from 'lodash';
|
||||
import {
|
||||
Graph,
|
||||
ImageToImageInvocation,
|
||||
@@ -31,7 +32,12 @@ const buildBaseNode = (
|
||||
return mapTabToFunction(activeTabName)(state);
|
||||
};
|
||||
|
||||
export const buildGraph = (state: RootState): Graph => {
|
||||
type BuildGraphOutput = {
|
||||
graph: Graph;
|
||||
nodeIdsToSubscribe: string[];
|
||||
};
|
||||
|
||||
export const buildGraph = (state: RootState): BuildGraphOutput => {
|
||||
const { generation, postprocessing } = state;
|
||||
const { iterations } = generation;
|
||||
const { hiresFix, hiresStrength } = postprocessing;
|
||||
@@ -39,6 +45,7 @@ export const buildGraph = (state: RootState): Graph => {
|
||||
const baseNode = buildBaseNode(state);
|
||||
|
||||
let graph: Graph = { nodes: baseNode };
|
||||
const nodeIdsToSubscribe: string[] = [];
|
||||
|
||||
if (iterations > 1) {
|
||||
graph = buildIteration({ graph, iterations });
|
||||
@@ -56,7 +63,8 @@ export const buildGraph = (state: RootState): Graph => {
|
||||
},
|
||||
edges: [...(graph.edges || []), edge],
|
||||
};
|
||||
nodeIdsToSubscribe.push(Object.keys(node)[0]);
|
||||
}
|
||||
|
||||
return graph;
|
||||
return { graph, nodeIdsToSubscribe };
|
||||
};
|
||||
|
||||
@@ -88,6 +88,10 @@ export interface SystemState
|
||||
* Whether or not a scheduled cancelation is pending
|
||||
*/
|
||||
isCancelScheduled: boolean;
|
||||
/**
|
||||
* Array of node IDs that we want to handle when events received
|
||||
*/
|
||||
subscribedNodeIds: string[];
|
||||
}
|
||||
|
||||
const initialSystemState: SystemState = {
|
||||
@@ -134,6 +138,7 @@ const initialSystemState: SystemState = {
|
||||
sessionId: null,
|
||||
cancelType: 'immediate',
|
||||
isCancelScheduled: false,
|
||||
subscribedNodeIds: [],
|
||||
};
|
||||
|
||||
export const systemSlice = createSlice({
|
||||
@@ -325,6 +330,12 @@ export const systemSlice = createSlice({
|
||||
cancelTypeChanged: (state, action: PayloadAction<CancelType>) => {
|
||||
state.cancelType = action.payload;
|
||||
},
|
||||
/**
|
||||
* The array of subscribed node ids was changed
|
||||
*/
|
||||
subscribedNodeIdsSet: (state, action: PayloadAction<string[]>) => {
|
||||
state.subscribedNodeIds = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
@@ -524,6 +535,7 @@ export const {
|
||||
cancelScheduled,
|
||||
scheduledCancelAborted,
|
||||
cancelTypeChanged,
|
||||
subscribedNodeIdsSet,
|
||||
} = systemSlice.actions;
|
||||
|
||||
export default systemSlice.reducer;
|
||||
|
||||
@@ -112,6 +112,15 @@ export const socketMiddleware = () => {
|
||||
// Everything else only happens once we have created a session
|
||||
if (isFulfilledSessionCreatedAction(action)) {
|
||||
const oldSessionId = getState().system.sessionId;
|
||||
const subscribedNodeIds = getState().system.subscribedNodeIds;
|
||||
|
||||
const shouldHandleEvent = (id: string): boolean => {
|
||||
if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
|
||||
return true;
|
||||
}
|
||||
|
||||
return subscribedNodeIds.includes(id);
|
||||
};
|
||||
|
||||
if (oldSessionId) {
|
||||
// Unsubscribe when invocations complete
|
||||
@@ -150,28 +159,36 @@ export const socketMiddleware = () => {
|
||||
|
||||
// Set up listeners for the present subscription
|
||||
socket.on('invocation_started', (data: InvocationStartedEvent) => {
|
||||
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
|
||||
if (shouldHandleEvent(data.source_id)) {
|
||||
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('generator_progress', (data: GeneratorProgressEvent) => {
|
||||
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
|
||||
if (shouldHandleEvent(data.source_id)) {
|
||||
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('invocation_error', (data: InvocationErrorEvent) => {
|
||||
dispatch(invocationError({ data, timestamp: getTimestamp() }));
|
||||
if (shouldHandleEvent(data.source_id)) {
|
||||
dispatch(invocationError({ data, timestamp: getTimestamp() }));
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('invocation_complete', (data: InvocationCompleteEvent) => {
|
||||
const sessionId = data.graph_execution_state_id;
|
||||
if (shouldHandleEvent(data.source_id)) {
|
||||
const sessionId = data.graph_execution_state_id;
|
||||
|
||||
const { cancelType, isCancelScheduled } = getState().system;
|
||||
const { cancelType, isCancelScheduled } = getState().system;
|
||||
|
||||
// Handle scheduled cancelation
|
||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||
dispatch(sessionCanceled({ sessionId }));
|
||||
// Handle scheduled cancelation
|
||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||
dispatch(sessionCanceled({ sessionId }));
|
||||
}
|
||||
|
||||
dispatch(invocationComplete({ data, timestamp: getTimestamp() }));
|
||||
}
|
||||
|
||||
dispatch(invocationComplete({ data, timestamp: getTimestamp() }));
|
||||
});
|
||||
|
||||
// Finally we actually invoke the session, starting processing
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { GraphExecutionState } from '../api';
|
||||
import { Graph, GraphExecutionState } from '../api';
|
||||
|
||||
/**
|
||||
* A progress image, we get one for each step in the generation
|
||||
@@ -9,6 +9,12 @@ export type ProgressImage = {
|
||||
height: number;
|
||||
};
|
||||
|
||||
export type AnyInvocation = NonNullable<
|
||||
NonNullable<Graph['nodes']>[string]['type']
|
||||
>;
|
||||
|
||||
export type AnyResult = GraphExecutionState['results'][string];
|
||||
|
||||
/**
|
||||
* A `generator_progress` socket.io event.
|
||||
*
|
||||
@@ -16,7 +22,8 @@ export type ProgressImage = {
|
||||
*/
|
||||
export type GeneratorProgressEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation_id: string;
|
||||
invocation: AnyInvocation;
|
||||
source_id: string;
|
||||
progress_image?: ProgressImage;
|
||||
step: number;
|
||||
total_steps: number;
|
||||
@@ -31,8 +38,8 @@ export type GeneratorProgressEvent = {
|
||||
*/
|
||||
export type InvocationCompleteEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation_id: string;
|
||||
result: GraphExecutionState['results'][string];
|
||||
source_id: string;
|
||||
result: AnyResult;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -42,7 +49,8 @@ export type InvocationCompleteEvent = {
|
||||
*/
|
||||
export type InvocationErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation_id: string;
|
||||
invocation: AnyInvocation;
|
||||
source_id: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
@@ -53,7 +61,8 @@ export type InvocationErrorEvent = {
|
||||
*/
|
||||
export type InvocationStartedEvent = {
|
||||
graph_execution_state_id: string;
|
||||
invocation_id: string;
|
||||
invocation: AnyInvocation;
|
||||
source_id: string;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,23 +2,25 @@ import { createAppAsyncThunk } from 'app/storeUtils';
|
||||
import { SessionsService } from 'services/api';
|
||||
import { buildGraph } from 'common/util/buildGraph';
|
||||
import { isFulfilled } from '@reduxjs/toolkit';
|
||||
import { subscribedNodeIdsSet } from 'features/system/store/systemSlice';
|
||||
|
||||
type SessionCreatedArg = Parameters<
|
||||
(typeof SessionsService)['createSession']
|
||||
>[0];
|
||||
// type SessionCreatedArg = {
|
||||
// graph: Parameters<
|
||||
// (typeof SessionsService)['createSession']
|
||||
// >[0]['requestBody'];
|
||||
// nodeIdsToSubscribe?: string[];
|
||||
// };
|
||||
|
||||
/**
|
||||
* `SessionsService.createSession()` thunk
|
||||
*/
|
||||
export const sessionCreated = createAppAsyncThunk(
|
||||
'api/sessionCreated',
|
||||
async (arg: SessionCreatedArg['requestBody'], _thunkApi) => {
|
||||
let graph = arg;
|
||||
if (!arg) {
|
||||
const { getState } = _thunkApi;
|
||||
const state = getState();
|
||||
graph = buildGraph(state);
|
||||
}
|
||||
async (_arg, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const { graph, nodeIdsToSubscribe } = buildGraph(state);
|
||||
|
||||
dispatch(subscribedNodeIdsSet(nodeIdsToSubscribe));
|
||||
|
||||
const response = await SessionsService.createSession({
|
||||
requestBody: graph,
|
||||
|
||||
Reference in New Issue
Block a user