mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 13:48:24 -05:00
Compare commits
4 Commits
main
...
psyche/exp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
661625e8d7 | ||
|
|
3b0f5ecd6b | ||
|
|
1730a0cd41 | ||
|
|
a32f3be4f1 |
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,6 +26,7 @@ from invokeai.app.invocations.fields import (
|
||||
SD3ConditioningField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -535,3 +536,27 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@invocation_output("any_output")
|
||||
class AnyOutput(BaseInvocationOutput):
|
||||
value: Any = OutputField(description="The output value", ui_type=UIType.Any)
|
||||
|
||||
|
||||
@invocation(
|
||||
"switcher",
|
||||
title="Switcher",
|
||||
tags=["primitives", "switcher"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SwitcherInvocation(BaseInvocation):
|
||||
a: Any = InputField(description="The first input", ui_type=UIType.Any)
|
||||
b: Any = InputField(description="The second input", ui_type=UIType.Any)
|
||||
switch: bool = InputField(
|
||||
description="Switch between the two inputs. If false, the first input is returned. If true, the second input is returned."
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnyOutput:
|
||||
value = self.b if self.switch else self.a
|
||||
return AnyOutput(value=value)
|
||||
|
||||
@@ -58,6 +58,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
|
||||
const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString);
|
||||
|
||||
const isTargetAnyType = targetType.name === 'AnyField';
|
||||
const isSourceAnyType = sourceType.name === 'AnyField';
|
||||
|
||||
// One of these must be true for the connection to be valid
|
||||
return (
|
||||
@@ -67,6 +68,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
|
||||
isGenericCollectionToAnyCollectionOrSingleOrCollection ||
|
||||
isCollectionToGenericCollection ||
|
||||
isSubTypeMatch ||
|
||||
isTargetAnyType
|
||||
isTargetAnyType ||
|
||||
isSourceAnyType
|
||||
);
|
||||
};
|
||||
|
||||
@@ -144,7 +144,7 @@ export const parseSchema = (
|
||||
|
||||
const fieldType = fieldTypeOverride ?? originalFieldType;
|
||||
if (!fieldType) {
|
||||
log.trace({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
|
||||
log.warn({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
|
||||
return inputsAccumulator;
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ export const parseSchema = (
|
||||
|
||||
const fieldType = fieldTypeOverride ?? originalFieldType;
|
||||
if (!fieldType) {
|
||||
log.trace({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
|
||||
log.warn({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
|
||||
return outputsAccumulator;
|
||||
}
|
||||
|
||||
@@ -269,7 +269,7 @@ const getFieldType = (
|
||||
} catch (e) {
|
||||
const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError';
|
||||
if (e instanceof FieldParseError) {
|
||||
log.warn(
|
||||
log.trace(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
@@ -282,7 +282,7 @@ const getFieldType = (
|
||||
})
|
||||
);
|
||||
} else {
|
||||
log.warn(
|
||||
log.trace(
|
||||
{
|
||||
node: type,
|
||||
field: propertyName,
|
||||
|
||||
Reference in New Issue
Block a user