mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): wip node editor
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties } from 'react';
|
||||
import { Handle, Position, Connection, HandleType } from 'reactflow';
|
||||
import { CSSProperties, useMemo } from 'react';
|
||||
import {
|
||||
Handle,
|
||||
Position,
|
||||
Connection,
|
||||
HandleType,
|
||||
useReactFlow,
|
||||
} from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
|
||||
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||
import { InputField, OutputField } from '../types';
|
||||
@@ -9,6 +15,7 @@ const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
opacity: 0.5,
|
||||
borderWidth: 0,
|
||||
};
|
||||
|
||||
@@ -20,6 +27,10 @@ const outputHandleStyles: CSSProperties = {
|
||||
right: '-0.5rem',
|
||||
};
|
||||
|
||||
const requiredConnectionStyles: CSSProperties = {
|
||||
opacity: 1,
|
||||
};
|
||||
|
||||
type FieldHandleProps = {
|
||||
nodeId: string;
|
||||
field: InputField | OutputField;
|
||||
@@ -30,7 +41,7 @@ type FieldHandleProps = {
|
||||
|
||||
export const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||
const { name, title, type, description } = field;
|
||||
const { name, title, type, description, connectionType } = field;
|
||||
|
||||
// this needs to iterate over every candicate target node, calculating graph cycles
|
||||
// WIP
|
||||
@@ -56,8 +67,9 @@ export const FieldHandle = (props: FieldHandleProps) => {
|
||||
style={{
|
||||
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
|
||||
...styles,
|
||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||
...handleBaseStyles,
|
||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||
...(connectionType === 'always' ? requiredConnectionStyles : {}),
|
||||
// ...connectionEventStyles,
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -24,6 +24,8 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
|
||||
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
// TODO: determine if a field/handle is connected and disable the input if so
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@@ -74,12 +76,14 @@ export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
|
||||
</HStack>
|
||||
<InputFieldComponent nodeId={id} field={input} />
|
||||
</FormControl>
|
||||
<FieldHandle
|
||||
nodeId={id}
|
||||
field={input}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
{input.connectionType !== 'never' && (
|
||||
<FieldHandle
|
||||
nodeId={id}
|
||||
field={input}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -73,11 +73,14 @@ export type InputField =
|
||||
|
||||
export type OutputField = FieldBase;
|
||||
|
||||
export type ConnectionType = 'never' | 'always';
|
||||
|
||||
export type FieldBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
description: string;
|
||||
type: FieldType;
|
||||
connectionType?: ConnectionType;
|
||||
};
|
||||
|
||||
export type NumberInvocationField = {
|
||||
|
||||
@@ -14,9 +14,16 @@ import {
|
||||
ModelInputField,
|
||||
TypeHints,
|
||||
FieldType,
|
||||
isReferenceObject,
|
||||
InputField,
|
||||
} from '../types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
|
||||
export type BuildInputFieldArg = {
|
||||
schemaObject: OpenAPIV3.SchemaObject;
|
||||
baseField: Pick<InputField, BaseFieldProperties>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Transforms an invocation output ref object to field type.
|
||||
* @param ref The ref string to transform
|
||||
@@ -29,178 +36,158 @@ export const refObjectToFieldType = (
|
||||
refObject: OpenAPIV3.ReferenceObject
|
||||
): keyof typeof FIELD_TYPE_MAP => refObject.$ref.split('/').slice(-1)[0];
|
||||
|
||||
const buildIntegerInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): IntegerInputField => {
|
||||
const field: IntegerInputField = {
|
||||
const buildIntegerInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerInputField => {
|
||||
const field: Omit<IntegerInputField, BaseFieldProperties> = {
|
||||
type: 'integer',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? 0,
|
||||
value: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (input.multipleOf !== undefined) {
|
||||
field.multipleOf = input.multipleOf;
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
field.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (input.maximum !== undefined) {
|
||||
field.maximum = input.maximum;
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
field.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (input.exclusiveMaximum !== undefined) {
|
||||
field.exclusiveMaximum = input.exclusiveMaximum;
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
field.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (input.minimum !== undefined) {
|
||||
field.minimum = input.minimum;
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
field.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (input.exclusiveMinimum !== undefined) {
|
||||
field.exclusiveMinimum = input.exclusiveMinimum;
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
field.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildFloatInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): FloatInputField => {
|
||||
const field: FloatInputField = {
|
||||
const buildFloatInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatInputField => {
|
||||
const field: Omit<FloatInputField, BaseFieldProperties> = {
|
||||
type: 'float',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? 0,
|
||||
value: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (input.multipleOf !== undefined) {
|
||||
field.multipleOf = input.multipleOf;
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
field.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (input.maximum !== undefined) {
|
||||
field.maximum = input.maximum;
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
field.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (input.exclusiveMaximum !== undefined) {
|
||||
field.exclusiveMaximum = input.exclusiveMaximum;
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
field.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (input.minimum !== undefined) {
|
||||
field.minimum = input.minimum;
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
field.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (input.exclusiveMinimum !== undefined) {
|
||||
field.exclusiveMinimum = input.exclusiveMinimum;
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
field.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildStringInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): StringInputField => {
|
||||
const field: StringInputField = {
|
||||
const buildStringInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringInputField => {
|
||||
const field: Omit<StringInputField, BaseFieldProperties> = {
|
||||
type: 'string',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? '',
|
||||
value: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
if (input.minLength !== undefined) {
|
||||
field.minLength = input.minLength;
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
field.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (input.maxLength !== undefined) {
|
||||
field.maxLength = input.maxLength;
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
field.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (input.pattern !== undefined) {
|
||||
field.pattern = input.pattern;
|
||||
if (schemaObject.pattern !== undefined) {
|
||||
field.pattern = schemaObject.pattern;
|
||||
}
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildBooleanInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): BooleanInputField => {
|
||||
const field: BooleanInputField = {
|
||||
const buildBooleanInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanInputField => {
|
||||
const field: Omit<BooleanInputField, BaseFieldProperties> = {
|
||||
type: 'boolean',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? false,
|
||||
value: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildModelInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): ModelInputField => {
|
||||
const field: ModelInputField = {
|
||||
const buildModelInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ModelInputField => {
|
||||
const field: Omit<ModelInputField, BaseFieldProperties> = {
|
||||
type: 'model',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? '',
|
||||
value: schemaObject.default ?? '',
|
||||
connectionType: 'never',
|
||||
};
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildImageInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): ImageInputField => {
|
||||
const field: ImageInputField = {
|
||||
const buildImageInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ImageInputField => {
|
||||
const field: Omit<ImageInputField, BaseFieldProperties> = {
|
||||
type: 'image',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? '',
|
||||
value: schemaObject.default ?? '',
|
||||
connectionType: 'always',
|
||||
};
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildLatentsInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): LatentsInputField => {
|
||||
const field: LatentsInputField = {
|
||||
const buildLatentsInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsInputField => {
|
||||
const field: Omit<LatentsInputField, BaseFieldProperties> = {
|
||||
type: 'latents',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
description: input.description ?? '',
|
||||
value: input.default ?? '',
|
||||
value: schemaObject.default ?? '',
|
||||
connectionType: 'always',
|
||||
};
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
const buildEnumInputField = (
|
||||
input: OpenAPIV3.SchemaObject,
|
||||
name: string
|
||||
): EnumInputField => {
|
||||
const field: EnumInputField = {
|
||||
const buildEnumInputField = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): EnumInputField => {
|
||||
const field: Omit<EnumInputField, BaseFieldProperties> = {
|
||||
...baseField,
|
||||
type: 'enum',
|
||||
name,
|
||||
title: input.title ?? '',
|
||||
value: input.default,
|
||||
enumType: (input.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
|
||||
options: input.enum ?? [],
|
||||
description: input.description ?? '',
|
||||
value: schemaObject.default,
|
||||
enumType: (schemaObject.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
|
||||
options: schemaObject.enum ?? [],
|
||||
};
|
||||
|
||||
return field;
|
||||
return { ...baseField, ...field };
|
||||
};
|
||||
|
||||
export const getFieldType = (
|
||||
@@ -241,29 +228,35 @@ export const buildInputField = (
|
||||
throw `Field type "${fieldType}" is unknown!`;
|
||||
}
|
||||
|
||||
if (['image', 'ImageField'].includes(fieldType)) {
|
||||
return buildImageInputField(schemaObject, name);
|
||||
const baseField = {
|
||||
name,
|
||||
title: schemaObject.title ?? '',
|
||||
description: schemaObject.description ?? '',
|
||||
};
|
||||
|
||||
if (['image'].includes(fieldType)) {
|
||||
return buildImageInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['latents', 'LatentsField'].includes(fieldType)) {
|
||||
return buildLatentsInputField(schemaObject, name);
|
||||
if (['latents'].includes(fieldType)) {
|
||||
return buildLatentsInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputField(schemaObject, name);
|
||||
return buildModelInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputField(schemaObject, name);
|
||||
return buildEnumInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['integer'].includes(fieldType)) {
|
||||
return buildIntegerInputField(schemaObject, name);
|
||||
return buildIntegerInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['number', 'float'].includes(fieldType)) {
|
||||
return buildFloatInputField(schemaObject, name);
|
||||
return buildFloatInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['string'].includes(fieldType)) {
|
||||
return buildStringInputField(schemaObject, name);
|
||||
return buildStringInputField({ schemaObject, baseField });
|
||||
}
|
||||
if (['boolean'].includes(fieldType)) {
|
||||
return buildBooleanInputField(schemaObject, name);
|
||||
return buildBooleanInputField({ schemaObject, baseField });
|
||||
}
|
||||
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user