feat(ui): wip node editor

This commit is contained in:
psychedelicious
2023-04-11 11:20:35 +10:00
parent ef890058b9
commit cf562f140c
4 changed files with 141 additions and 129 deletions

View File

@@ -1,6 +1,12 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties } from 'react';
import { Handle, Position, Connection, HandleType } from 'reactflow';
import { CSSProperties, useMemo } from 'react';
import {
Handle,
Position,
Connection,
HandleType,
useReactFlow,
} from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputField, OutputField } from '../types';
@@ -9,6 +15,7 @@ const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
opacity: 0.5,
borderWidth: 0,
};
@@ -20,6 +27,10 @@ const outputHandleStyles: CSSProperties = {
right: '-0.5rem',
};
const requiredConnectionStyles: CSSProperties = {
opacity: 1,
};
type FieldHandleProps = {
nodeId: string;
field: InputField | OutputField;
@@ -30,7 +41,7 @@ type FieldHandleProps = {
export const FieldHandle = (props: FieldHandleProps) => {
const { nodeId, field, isValidConnection, handleType, styles } = props;
const { name, title, type, description } = field;
const { name, title, type, description, connectionType } = field;
// this needs to iterate over every candicate target node, calculating graph cycles
// WIP
@@ -56,8 +67,9 @@ export const FieldHandle = (props: FieldHandleProps) => {
style={{
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
...styles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
...(connectionType === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>

View File

@@ -24,6 +24,8 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
const isValidConnection = useIsValidConnection();
// TODO: determine if a field/handle is connected and disable the input if so
return (
<Box
sx={{
@@ -74,12 +76,14 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
</HStack>
<InputFieldComponent nodeId={id} field={input} />
</FormControl>
<FieldHandle
nodeId={id}
field={input}
isValidConnection={isValidConnection}
handleType="target"
/>
{input.connectionType !== 'never' && (
<FieldHandle
nodeId={id}
field={input}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</Box>
);
})}

View File

@@ -73,11 +73,14 @@ export type InputField =
export type OutputField = FieldBase;
export type ConnectionType = 'never' | 'always';
export type FieldBase = {
name: string;
title: string;
description: string;
type: FieldType;
connectionType?: ConnectionType;
};
export type NumberInvocationField = {

View File

@@ -14,9 +14,16 @@ import {
ModelInputField,
TypeHints,
FieldType,
isReferenceObject,
InputField,
} from '../types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
export type BuildInputFieldArg = {
schemaObject: OpenAPIV3.SchemaObject;
baseField: Pick<InputField, BaseFieldProperties>;
};
/**
* Transforms an invocation output ref object to field type.
* @param ref The ref string to transform
@@ -29,178 +36,158 @@ export const refObjectToFieldType = (
refObject: OpenAPIV3.ReferenceObject
): keyof typeof FIELD_TYPE_MAP => refObject.$ref.split('/').slice(-1)[0];
const buildIntegerInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): IntegerInputField => {
const field: IntegerInputField = {
const buildIntegerInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerInputField => {
const field: Omit<IntegerInputField, BaseFieldProperties> = {
type: 'integer',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? 0,
value: schemaObject.default ?? 0,
};
if (input.multipleOf !== undefined) {
field.multipleOf = input.multipleOf;
if (schemaObject.multipleOf !== undefined) {
field.multipleOf = schemaObject.multipleOf;
}
if (input.maximum !== undefined) {
field.maximum = input.maximum;
if (schemaObject.maximum !== undefined) {
field.maximum = schemaObject.maximum;
}
if (input.exclusiveMaximum !== undefined) {
field.exclusiveMaximum = input.exclusiveMaximum;
if (schemaObject.exclusiveMaximum !== undefined) {
field.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (input.minimum !== undefined) {
field.minimum = input.minimum;
if (schemaObject.minimum !== undefined) {
field.minimum = schemaObject.minimum;
}
if (input.exclusiveMinimum !== undefined) {
field.exclusiveMinimum = input.exclusiveMinimum;
if (schemaObject.exclusiveMinimum !== undefined) {
field.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return field;
return { ...baseField, ...field };
};
const buildFloatInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): FloatInputField => {
const field: FloatInputField = {
const buildFloatInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatInputField => {
const field: Omit<FloatInputField, BaseFieldProperties> = {
type: 'float',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? 0,
value: schemaObject.default ?? 0,
};
if (input.multipleOf !== undefined) {
field.multipleOf = input.multipleOf;
if (schemaObject.multipleOf !== undefined) {
field.multipleOf = schemaObject.multipleOf;
}
if (input.maximum !== undefined) {
field.maximum = input.maximum;
if (schemaObject.maximum !== undefined) {
field.maximum = schemaObject.maximum;
}
if (input.exclusiveMaximum !== undefined) {
field.exclusiveMaximum = input.exclusiveMaximum;
if (schemaObject.exclusiveMaximum !== undefined) {
field.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (input.minimum !== undefined) {
field.minimum = input.minimum;
if (schemaObject.minimum !== undefined) {
field.minimum = schemaObject.minimum;
}
if (input.exclusiveMinimum !== undefined) {
field.exclusiveMinimum = input.exclusiveMinimum;
if (schemaObject.exclusiveMinimum !== undefined) {
field.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return field;
return { ...baseField, ...field };
};
const buildStringInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): StringInputField => {
const field: StringInputField = {
const buildStringInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringInputField => {
const field: Omit<StringInputField, BaseFieldProperties> = {
type: 'string',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? '',
value: schemaObject.default ?? '',
};
if (input.minLength !== undefined) {
field.minLength = input.minLength;
if (schemaObject.minLength !== undefined) {
field.minLength = schemaObject.minLength;
}
if (input.maxLength !== undefined) {
field.maxLength = input.maxLength;
if (schemaObject.maxLength !== undefined) {
field.maxLength = schemaObject.maxLength;
}
if (input.pattern !== undefined) {
field.pattern = input.pattern;
if (schemaObject.pattern !== undefined) {
field.pattern = schemaObject.pattern;
}
return field;
return { ...baseField, ...field };
};
const buildBooleanInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): BooleanInputField => {
const field: BooleanInputField = {
const buildBooleanInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanInputField => {
const field: Omit<BooleanInputField, BaseFieldProperties> = {
type: 'boolean',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? false,
value: schemaObject.default ?? false,
};
return field;
return { ...baseField, ...field };
};
const buildModelInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): ModelInputField => {
const field: ModelInputField = {
const buildModelInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ModelInputField => {
const field: Omit<ModelInputField, BaseFieldProperties> = {
type: 'model',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? '',
value: schemaObject.default ?? '',
connectionType: 'never',
};
return field;
return { ...baseField, ...field };
};
const buildImageInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): ImageInputField => {
const field: ImageInputField = {
const buildImageInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ImageInputField => {
const field: Omit<ImageInputField, BaseFieldProperties> = {
type: 'image',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? '',
value: schemaObject.default ?? '',
connectionType: 'always',
};
return field;
return { ...baseField, ...field };
};
const buildLatentsInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): LatentsInputField => {
const field: LatentsInputField = {
const buildLatentsInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsInputField => {
const field: Omit<LatentsInputField, BaseFieldProperties> = {
type: 'latents',
name,
title: input.title ?? '',
description: input.description ?? '',
value: input.default ?? '',
value: schemaObject.default ?? '',
connectionType: 'always',
};
return field;
return { ...baseField, ...field };
};
const buildEnumInputField = (
input: OpenAPIV3.SchemaObject,
name: string
): EnumInputField => {
const field: EnumInputField = {
const buildEnumInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): EnumInputField => {
const field: Omit<EnumInputField, BaseFieldProperties> = {
...baseField,
type: 'enum',
name,
title: input.title ?? '',
value: input.default,
enumType: (input.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
options: input.enum ?? [],
description: input.description ?? '',
value: schemaObject.default,
enumType: (schemaObject.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
options: schemaObject.enum ?? [],
};
return field;
return { ...baseField, ...field };
};
export const getFieldType = (
@@ -241,29 +228,35 @@ export const buildInputField = (
throw `Field type "${fieldType}" is unknown!`;
}
if (['image', 'ImageField'].includes(fieldType)) {
return buildImageInputField(schemaObject, name);
const baseField = {
name,
title: schemaObject.title ?? '',
description: schemaObject.description ?? '',
};
if (['image'].includes(fieldType)) {
return buildImageInputField({ schemaObject, baseField });
}
if (['latents', 'LatentsField'].includes(fieldType)) {
return buildLatentsInputField(schemaObject, name);
if (['latents'].includes(fieldType)) {
return buildLatentsInputField({ schemaObject, baseField });
}
if (['model'].includes(fieldType)) {
return buildModelInputField(schemaObject, name);
return buildModelInputField({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputField(schemaObject, name);
return buildEnumInputField({ schemaObject, baseField });
}
if (['integer'].includes(fieldType)) {
return buildIntegerInputField(schemaObject, name);
return buildIntegerInputField({ schemaObject, baseField });
}
if (['number', 'float'].includes(fieldType)) {
return buildFloatInputField(schemaObject, name);
return buildFloatInputField({ schemaObject, baseField });
}
if (['string'].includes(fieldType)) {
return buildStringInputField(schemaObject, name);
return buildStringInputField({ schemaObject, baseField });
}
if (['boolean'].includes(fieldType)) {
return buildBooleanInputField(schemaObject, name);
return buildBooleanInputField({ schemaObject, baseField });
}
return;