feat(ui): "subscribe" to particular nodes

feels like a dirty hack but oh well it works
This commit is contained in:
psychedelicious
2023-04-07 16:34:02 +10:00
parent d0e9ec267c
commit 1e09fdc8be
5 changed files with 76 additions and 28 deletions

View File

@@ -1,5 +1,6 @@
import { RootState } from 'app/store';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { find } from 'lodash';
import {
Graph,
ImageToImageInvocation,
@@ -31,7 +32,12 @@ const buildBaseNode = (
return mapTabToFunction(activeTabName)(state);
};
export const buildGraph = (state: RootState): Graph => {
type BuildGraphOutput = {
graph: Graph;
nodeIdsToSubscribe: string[];
};
export const buildGraph = (state: RootState): BuildGraphOutput => {
const { generation, postprocessing } = state;
const { iterations } = generation;
const { hiresFix, hiresStrength } = postprocessing;
@@ -39,6 +45,7 @@ export const buildGraph = (state: RootState): Graph => {
const baseNode = buildBaseNode(state);
let graph: Graph = { nodes: baseNode };
const nodeIdsToSubscribe: string[] = [];
if (iterations > 1) {
graph = buildIteration({ graph, iterations });
@@ -56,7 +63,8 @@ export const buildGraph = (state: RootState): Graph => {
},
edges: [...(graph.edges || []), edge],
};
nodeIdsToSubscribe.push(Object.keys(node)[0]);
}
return graph;
return { graph, nodeIdsToSubscribe };
};

View File

@@ -88,6 +88,10 @@ export interface SystemState
* Whether or not a scheduled cancelation is pending
*/
isCancelScheduled: boolean;
/**
* Array of node IDs that we want to handle when events received
*/
subscribedNodeIds: string[];
}
const initialSystemState: SystemState = {
@@ -134,6 +138,7 @@ const initialSystemState: SystemState = {
sessionId: null,
cancelType: 'immediate',
isCancelScheduled: false,
subscribedNodeIds: [],
};
export const systemSlice = createSlice({
@@ -325,6 +330,12 @@ export const systemSlice = createSlice({
cancelTypeChanged: (state, action: PayloadAction<CancelType>) => {
state.cancelType = action.payload;
},
/**
* The array of subscribed node ids was changed
*/
subscribedNodeIdsSet: (state, action: PayloadAction<string[]>) => {
state.subscribedNodeIds = action.payload;
},
},
extraReducers(builder) {
/**
@@ -524,6 +535,7 @@ export const {
cancelScheduled,
scheduledCancelAborted,
cancelTypeChanged,
subscribedNodeIdsSet,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@@ -112,6 +112,15 @@ export const socketMiddleware = () => {
// Everything else only happens once we have created a session
if (isFulfilledSessionCreatedAction(action)) {
const oldSessionId = getState().system.sessionId;
const subscribedNodeIds = getState().system.subscribedNodeIds;
const shouldHandleEvent = (id: string): boolean => {
if (subscribedNodeIds.length === 1 && subscribedNodeIds[0] === '*') {
return true;
}
return subscribedNodeIds.includes(id);
};
if (oldSessionId) {
// Unsubscribe when invocations complete
@@ -150,28 +159,36 @@ export const socketMiddleware = () => {
// Set up listeners for the present subscription
socket.on('invocation_started', (data: InvocationStartedEvent) => {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
if (shouldHandleEvent(data.source_id)) {
dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
}
});
socket.on('generator_progress', (data: GeneratorProgressEvent) => {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
if (shouldHandleEvent(data.source_id)) {
dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_error', (data: InvocationErrorEvent) => {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
if (shouldHandleEvent(data.source_id)) {
dispatch(invocationError({ data, timestamp: getTimestamp() }));
}
});
socket.on('invocation_complete', (data: InvocationCompleteEvent) => {
const sessionId = data.graph_execution_state_id;
if (shouldHandleEvent(data.source_id)) {
const sessionId = data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
const { cancelType, isCancelScheduled } = getState().system;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
dispatch(sessionCanceled({ sessionId }));
}
dispatch(invocationComplete({ data, timestamp: getTimestamp() }));
}
dispatch(invocationComplete({ data, timestamp: getTimestamp() }));
});
// Finally we actually invoke the session, starting processing

View File

@@ -1,4 +1,4 @@
import { GraphExecutionState } from '../api';
import { Graph, GraphExecutionState } from '../api';
/**
* A progress image, we get one for each step in the generation
@@ -9,6 +9,12 @@ export type ProgressImage = {
height: number;
};
export type AnyInvocation = NonNullable<
NonNullable<Graph['nodes']>[string]['type']
>;
export type AnyResult = GraphExecutionState['results'][string];
/**
* A `generator_progress` socket.io event.
*
@@ -16,7 +22,8 @@ export type ProgressImage = {
*/
export type GeneratorProgressEvent = {
graph_execution_state_id: string;
invocation_id: string;
invocation: AnyInvocation;
source_id: string;
progress_image?: ProgressImage;
step: number;
total_steps: number;
@@ -31,8 +38,8 @@ export type GeneratorProgressEvent = {
*/
export type InvocationCompleteEvent = {
graph_execution_state_id: string;
invocation_id: string;
result: GraphExecutionState['results'][string];
source_id: string;
result: AnyResult;
};
/**
@@ -42,7 +49,8 @@ export type InvocationCompleteEvent = {
*/
export type InvocationErrorEvent = {
graph_execution_state_id: string;
invocation_id: string;
invocation: AnyInvocation;
source_id: string;
error: string;
};
@@ -53,7 +61,8 @@ export type InvocationErrorEvent = {
*/
export type InvocationStartedEvent = {
graph_execution_state_id: string;
invocation_id: string;
invocation: AnyInvocation;
source_id: string;
};
/**

View File

@@ -2,23 +2,25 @@ import { createAppAsyncThunk } from 'app/storeUtils';
import { SessionsService } from 'services/api';
import { buildGraph } from 'common/util/buildGraph';
import { isFulfilled } from '@reduxjs/toolkit';
import { subscribedNodeIdsSet } from 'features/system/store/systemSlice';
type SessionCreatedArg = Parameters<
(typeof SessionsService)['createSession']
>[0];
// type SessionCreatedArg = {
// graph: Parameters<
// (typeof SessionsService)['createSession']
// >[0]['requestBody'];
// nodeIdsToSubscribe?: string[];
// };
/**
* `SessionsService.createSession()` thunk
*/
export const sessionCreated = createAppAsyncThunk(
'api/sessionCreated',
async (arg: SessionCreatedArg['requestBody'], _thunkApi) => {
let graph = arg;
if (!arg) {
const { getState } = _thunkApi;
const state = getState();
graph = buildGraph(state);
}
async (_arg, { dispatch, getState }) => {
const state = getState();
const { graph, nodeIdsToSubscribe } = buildGraph(state);
dispatch(subscribedNodeIdsSet(nodeIdsToSubscribe));
const response = await SessionsService.createSession({
requestBody: graph,