feat(ui): wip model handling and graph topology validation

This commit is contained in:
psychedelicious
2023-04-10 19:09:24 +10:00
parent d729d1c100
commit 80c555ef76
20 changed files with 1622 additions and 84 deletions

View File

@@ -4,6 +4,7 @@ import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent'
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
@@ -41,5 +42,9 @@ export const InputFieldComponent = (props: InputFieldComponentProps) => {
return <LatentsInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'model') {
return <ModelInputFieldComponent nodeId={nodeId} field={field} />;
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@@ -8,6 +8,7 @@ import {
HStack,
Tooltip,
Icon,
Code,
} from '@chakra-ui/react';
import { FaInfoCircle } from 'react-icons/fa';
import { Invocation } from '../types';
@@ -78,6 +79,7 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
>
<Flex flexDirection="column" gap={2}>
<>
<Code>{id}</Code>
<HStack justifyContent="space-between">
<Heading size="sm" fontWeight={500} color="base.100">
{title}

View File

@@ -1,15 +1,18 @@
import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react';
import { Box, HStack } from '@chakra-ui/react';
import { Flow } from './Flow';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { buildNodesGraph } from '../util/buildNodesGraph';
import { buildAdjacencyList } from '../util/isCyclic';
const NodeEditor = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
const adjacencyList = buildAdjacencyList(state.nodes.edges);
return (
<Box
sx={{
@@ -33,7 +36,10 @@ const NodeEditor = () => {
pointerEvents="none"
opacity={0.7}
>
{JSON.stringify(graph, undefined, 2)}
<HStack alignItems={'flex-start'} justifyContent="space-between">
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
<Box w="50%">{JSON.stringify(adjacencyList, null, 2)}</Box>
</HStack>
</Box>
</Box>
);

View File

@@ -0,0 +1,49 @@
import { Select } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { ModelInputField } from 'features/nodes/types';
import { isEqual, map } from 'lodash';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
(state: RootState) => state.models.modelList,
(modelList) => {
return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const availableModels = useAppSelector(availableModelsSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{availableModels.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@@ -10,6 +10,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
enum: 'enum',
ImageField: 'image',
LatentsField: 'latents',
model: 'model',
};
export const FIELDS: Record<FieldType, FieldUIConfig> = {
@@ -48,4 +49,9 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
model: {
color: 'teal',
title: 'Model',
description: 'Models are models.',
},
};

View File

@@ -92,7 +92,6 @@ const nodesSlice = createSlice({
},
extraReducers(builder) {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
console.log('schema received');
state.schema = action.payload;
state.invocations = parseSchema(action.payload);
});

View File

@@ -21,6 +21,10 @@ export type Invocation = {
* Description of the invocation
*/
description: string;
/**
* Invocation tags
*/
tags: string[];
/**
* Array of invocation inputs
*/
@@ -34,7 +38,15 @@ export type Invocation = {
};
export type FieldUIConfig = {
color: 'red' | 'orange' | 'yellow' | 'green' | 'blue' | 'purple' | 'pink';
color:
| 'red'
| 'orange'
| 'yellow'
| 'green'
| 'blue'
| 'purple'
| 'pink'
| 'teal';
title: string;
description: string;
};
@@ -46,7 +58,8 @@ export type FieldType =
| 'boolean'
| 'enum'
| 'image'
| 'latents';
| 'latents'
| 'model';
export type InputField =
| IntegerInputField
@@ -55,7 +68,8 @@ export type InputField =
| BooleanInputField
| ImageInputField
| LatentsInputField
| EnumInputField;
| EnumInputField
| ModelInputField;
export type OutputField = FieldBase;
@@ -116,3 +130,55 @@ export type EnumInputField = FieldBase & {
enumType: 'string' | 'integer' | 'number';
options: Array<string | number>;
};
export type ModelInputField = FieldBase & {
type: 'model';
value?: string;
};
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/
export type TypeHints = {
[fieldName: string]: FieldType;
};
export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation
ui?: {
tags?: string[];
type_hints?: TypeHints;
};
title: string;
properties: Omit<
NonNullable<OpenAPIV3.SchemaObject['properties']>,
'type'
> & {
type: Omit<OpenAPIV3.SchemaObject, 'default'> & { default: string };
};
};
export type InvocationSchemaType = {
default: string; // the type of the invocation
};
export type InvocationBaseSchemaObject = Omit<
OpenAPIV3.BaseSchemaObject,
'title' | 'type' | 'properties'
> &
InvocationSchemaExtra;
interface ArraySchemaObject extends InvocationBaseSchemaObject {
type: OpenAPIV3.ArraySchemaObjectType;
items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
}
interface NonArraySchemaObject extends InvocationBaseSchemaObject {
type?: OpenAPIV3.NonArraySchemaObjectType;
}
export type InvocationSchemaObject = ArraySchemaObject | NonArraySchemaObject;
export const isInvocationSchemaObject = (
obj: OpenAPIV3.ReferenceObject | InvocationSchemaObject
): obj is InvocationSchemaObject => !('$ref' in obj);

View File

@@ -1,6 +1,4 @@
import { Edge, Node } from 'reactflow';
import { Graph } from 'services/api';
import { Invocation } from '../types';
import { v4 as uuidv4 } from 'uuid';
import { reduce } from 'lodash';
import { RootState } from 'app/store';
@@ -62,5 +60,5 @@ export const buildNodesGraph = (state: RootState): Graph => {
edges: parsedEdges,
};
return graph as Graph;
return graph;
};

View File

@@ -0,0 +1,58 @@
import collections
class Graph(object):
def __init__(self, edges):
self.edges = edges
self.adj = Graph._build_adjacency_list(edges)
@staticmethod
def _build_adjacency_list(edges):
adj = collections.defaultdict(list)
for edge in edges:
adj[edge[0]].append(edge[1])
adj[edge[1]] # side effect only
return adj
def dfs(G):
discovered = set()
finished = set()
for u in G.adj:
if u not in discovered and u not in finished:
discovered, finished = dfs_visit(G, u, discovered, finished)
def dfs_visit(G, u, discovered, finished):
discovered.add(u)
for v in G.adj[u]:
# Detect cycles
if v in discovered:
print(f"Cycle detected: found a back edge from {u} to {v}.")
break
# Recurse into DFS tree
if v not in finished:
dfs_visit(G, v, discovered, finished)
discovered.remove(u)
finished.add(u)
return discovered, finished
if __name__ == "__main__":
G = Graph([
('u', 'v'),
('u', 'x'),
('v', 'y'),
('w', 'y'),
('w', 'z'),
('x', 'v'),
('y', 'x'),
('z', 'z')])
print(G.adj)
dfs(G)

View File

@@ -11,6 +11,10 @@ import {
OutputField,
StringInputField,
isSchemaObject,
ModelInputField,
TypeHints,
FieldType,
isReferenceObject,
} from '../types';
/**
@@ -137,6 +141,21 @@ const buildBooleanInputField = (
return field;
};
const buildModelInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): ModelInputField => {
const field: ModelInputField = {
type: 'model',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? '',
};
return field;
};
const buildImageInputField = (
input: OpenAPIV3.SchemaObject,
name: string
@@ -184,6 +203,28 @@ const buildEnumInputField = (
return field;
};
export const getFieldType = (
schemaObject: OpenAPIV3.SchemaObject,
name: string,
typeHints?: TypeHints
): FieldType | undefined => {
let rawFieldType = '';
if (typeHints && name in typeHints) {
rawFieldType = typeHints[name];
} else if (!schemaObject.type) {
rawFieldType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.enum) {
rawFieldType = 'enum';
} else if (schemaObject.type) {
rawFieldType = schemaObject.type;
}
return FIELD_TYPE_MAP[rawFieldType];
};
/**
* Builds an input field from an invocation schema property.
* @param schemaObject The schema object
@@ -191,36 +232,37 @@ const buildEnumInputField = (
*/
export const buildInputField = (
schemaObject: OpenAPIV3.SchemaObject,
name: string
name: string,
typeHints?: TypeHints
) => {
if (!schemaObject.type) {
// the this input/output is a ref! extract the ref string
const rawType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
const fieldType = getFieldType(schemaObject, name, typeHints);
const fieldType = FIELD_TYPE_MAP[rawType];
if (fieldType === 'image') {
return buildImageInputField(schemaObject, name);
}
if (fieldType === 'latents') {
return buildLatentsInputField(schemaObject, name);
}
if (!fieldType) {
throw `Field type "${fieldType}" is unknown!`;
}
if (schemaObject.enum) {
if (['image', 'ImageField'].includes(fieldType)) {
return buildImageInputField(schemaObject, name);
}
if (['latents', 'LatentsField'].includes(fieldType)) {
return buildLatentsInputField(schemaObject, name);
}
if (['model'].includes(fieldType)) {
return buildModelInputField(schemaObject, name);
}
if (['enum'].includes(fieldType)) {
return buildEnumInputField(schemaObject, name);
}
if (schemaObject.type === 'integer') {
if (['integer'].includes(fieldType)) {
return buildIntegerInputField(schemaObject, name);
}
if (schemaObject.type === 'number') {
if (['number', 'float'].includes(fieldType)) {
return buildFloatInputField(schemaObject, name);
}
if (schemaObject.type === 'string') {
if (['string'].includes(fieldType)) {
return buildStringInputField(schemaObject, name);
}
if (schemaObject.type === 'boolean') {
if (['boolean'].includes(fieldType)) {
return buildBooleanInputField(schemaObject, name);
}
@@ -235,7 +277,8 @@ export const buildInputField = (
*/
export const buildOutputFields = (
refObject: OpenAPIV3.ReferenceObject,
openAPI: OpenAPIV3.Document
openAPI: OpenAPIV3.Document,
typeHints?: TypeHints
): Record<string, OutputField> => {
// extract output schema name from ref
const outputSchemaName = refObject.$ref.split('/').slice(-1)[0];
@@ -251,23 +294,17 @@ export const buildOutputFields = (
!['type', 'id'].includes(propertyName) &&
isSchemaObject(property)
) {
let rawType: string;
const fieldType = getFieldType(property, propertyName, typeHints);
if (property.allOf) {
// we need to parse the ref to get the actual output field types
rawType = refObjectToFieldType(
property.allOf[0] as OpenAPIV3.ReferenceObject
);
} else {
// we can just use the property's type
rawType = property.type!;
if (!fieldType) {
throw `Field type "${fieldType}" is unknown!`;
}
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: FIELD_TYPE_MAP[rawType],
type: fieldType,
};
}

View File

@@ -0,0 +1,21 @@
import { Edge } from 'reactflow';
export type AdjacencyList = { [nodeId: string]: string[] };
export const buildAdjacencyList = (edges: Edge[]): AdjacencyList => {
const adjacencyList: AdjacencyList = {};
edges.forEach((edge) => {
if (!adjacencyList[edge.source]) {
adjacencyList[edge.source] = [];
}
if (!adjacencyList[edge.target]) {
adjacencyList[edge.target] = [];
}
adjacencyList[edge.source].push(edge.target);
});
return adjacencyList;
};

View File

@@ -3,7 +3,8 @@ import { OpenAPIV3 } from 'openapi-types';
import {
InputField,
Invocation,
isReferenceObject,
InvocationSchemaObject,
isInvocationSchemaObject,
isSchemaObject,
} from '../types';
import { buildInputField, buildOutputFields } from './invocationFieldBuilders';
@@ -20,65 +21,62 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
!key.includes('Iterate') &&
!key.includes('LoadImage') &&
!key.includes('Graph')
);
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const invocations = filteredSchemas.reduce<Record<string, Invocation>>(
(acc, schema: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject) => {
(acc, schema) => {
// only want SchemaObjects
if (isReferenceObject(schema)) {
return acc;
}
if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default;
const type = (
schema.properties!.type as OpenAPIV3.SchemaObject & { default: string }
).default;
const title = schema.title
.replace('Invocation', '')
.split(/(?=[A-Z])/) // split PascalCase into array
.join(' ');
const title = schema
.title!.replace('Invocation', '')
.split(/(?=[A-Z])/) // split PascalCase into array
.join(' ');
const typeHints = schema.ui?.type_hints;
const inputs = reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!['type', 'id'].includes(propertyName) &&
isSchemaObject(property)
) {
const field = buildInputField(property, propertyName);
const inputs = reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!['type', 'id'].includes(propertyName) &&
isSchemaObject(property)
) {
const field = buildInputField(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
if (field) {
inputsAccumulator[propertyName] = field;
}
}
}
return inputsAccumulator;
},
{} as Record<string, InputField>
);
return inputsAccumulator;
},
{} as Record<string, InputField>
);
const rawOutput = (
schema as OpenAPIV3.SchemaObject & {
output: OpenAPIV3.ReferenceObject;
}
).output;
const rawOutput = (schema as InvocationSchemaObject).output;
const outputs = buildOutputFields(rawOutput, openAPI);
const outputs = buildOutputFields(rawOutput, openAPI, typeHints);
const invocation: Invocation = {
title,
type,
description: schema.description ?? '',
inputs,
outputs,
};
const invocation: Invocation = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
acc[type] = invocation;
acc[type] = invocation;
}
return acc;
},
{}
);
console.debug('Generated invocations: ', invocations);
return invocations;
};

File diff suppressed because it is too large Load Diff

View File

@@ -12,6 +12,7 @@ export type { Body_upload_image } from './models/Body_upload_image';
export type { CkptModelInfo } from './models/CkptModelInfo';
export type { CollectInvocation } from './models/CollectInvocation';
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CropImageInvocation } from './models/CropImageInvocation';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
@@ -68,6 +69,7 @@ export { $Body_upload_image } from './schemas/$Body_upload_image';
export { $CkptModelInfo } from './schemas/$CkptModelInfo';
export { $CollectInvocation } from './schemas/$CollectInvocation';
export { $CollectInvocationOutput } from './schemas/$CollectInvocationOutput';
export { $CreateModelRequest } from './schemas/$CreateModelRequest';
export { $CropImageInvocation } from './schemas/$CropImageInvocation';
export { $CvInpaintInvocation } from './schemas/$CvInpaintInvocation';
export { $DiffusersModelInfo } from './schemas/$DiffusersModelInfo';

View File

@@ -0,0 +1,18 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { CkptModelInfo } from './CkptModelInfo';
import type { DiffusersModelInfo } from './DiffusersModelInfo';
export type CreateModelRequest = {
/**
* The name of the model
*/
name: string;
/**
* The model info
*/
info: (CkptModelInfo | DiffusersModelInfo);
};

View File

@@ -15,7 +15,7 @@ export type ImageMetadata = {
*/
width: number;
/**
* The width of the image in pixels
* The height of the image in pixels
*/
height: number;
/**

View File

@@ -0,0 +1,22 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $CreateModelRequest = {
properties: {
name: {
type: 'string',
description: `The name of the model`,
isRequired: true,
},
info: {
type: 'one-of',
description: `The model info`,
contains: [{
type: 'CkptModelInfo',
}, {
type: 'DiffusersModelInfo',
}],
isRequired: true,
},
},
} as const;

View File

@@ -16,7 +16,7 @@ export const $ImageMetadata = {
},
height: {
type: 'number',
description: `The width of the image in pixels`,
description: `The height of the image in pixels`,
isRequired: true,
},
sd_metadata: {

View File

@@ -1,6 +1,7 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { CreateModelRequest } from '../models/CreateModelRequest';
import type { ModelsList } from '../models/ModelsList';
import type { CancelablePromise } from '../core/CancelablePromise';
@@ -22,4 +23,50 @@ export class ModelsService {
});
}
/**
* Update Model
* Add Model
* @returns any Successful Response
* @throws ApiError
*/
public static updateModel({
requestBody,
}: {
requestBody: CreateModelRequest,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/models/',
body: requestBody,
mediaType: 'application/json',
errors: {
422: `Validation Error`,
},
});
}
/**
* Delete Model
* Delete Model
* @returns any Successful Response
* @throws ApiError
*/
public static delModel({
modelName,
}: {
modelName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'DELETE',
url: '/api/v1/models/{model_name}',
path: {
'model_name': modelName,
},
errors: {
404: `Model not found`,
422: `Validation Error`,
},
});
}
}

View File

@@ -6,6 +6,9 @@ export const receivedOpenAPISchema = createAsyncThunk(
async () => {
const response = await fetch(`openapi.json`);
const jsonData = (await response.json()) as OpenAPIV3.Document;
console.debug('OpenAPI schema: ', jsonData);
return jsonData;
}
);