Compare commits

...

4 Commits

Author SHA1 Message Date
psychedelicious
661625e8d7 fix(nodes): inverted switch logic 2025-03-08 17:33:03 +10:00
psychedelicious
3b0f5ecd6b fix(ui): swap log levels when parsing field types
We were logging a warning when it should be a trace and vice versa
2025-03-08 09:10:43 +10:00
psychedelicious
1730a0cd41 experiment: allow Any type outputs to always pass connection validation 2025-03-08 09:10:43 +10:00
psychedelicious
a32f3be4f1 experiment: add untyped switcher node 2025-03-08 09:10:39 +10:00
3 changed files with 33 additions and 6 deletions

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from typing import Any, Optional
import torch
@@ -26,6 +26,7 @@ from invokeai.app.invocations.fields import (
SD3ConditioningField,
TensorField,
UIComponent,
UIType,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -535,3 +536,27 @@ class BoundingBoxInvocation(BaseInvocation):
# endregion
@invocation_output("any_output")
class AnyOutput(BaseInvocationOutput):
value: Any = OutputField(description="The output value", ui_type=UIType.Any)
@invocation(
"switcher",
title="Switcher",
tags=["primitives", "switcher"],
category="primitives",
version="1.0.0",
)
class SwitcherInvocation(BaseInvocation):
a: Any = InputField(description="The first input", ui_type=UIType.Any)
b: Any = InputField(description="The second input", ui_type=UIType.Any)
switch: bool = InputField(
description="Switch between the two inputs. If false, the first input is returned. If true, the second input is returned."
)
def invoke(self, context: InvocationContext) -> AnyOutput:
value = self.b if self.switch else self.a
return AnyOutput(value=value)

View File

@@ -58,6 +58,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString);
const isTargetAnyType = targetType.name === 'AnyField';
const isSourceAnyType = sourceType.name === 'AnyField';
// One of these must be true for the connection to be valid
return (
@@ -67,6 +68,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
isGenericCollectionToAnyCollectionOrSingleOrCollection ||
isCollectionToGenericCollection ||
isSubTypeMatch ||
isTargetAnyType
isTargetAnyType ||
isSourceAnyType
);
};

View File

@@ -144,7 +144,7 @@ export const parseSchema = (
const fieldType = fieldTypeOverride ?? originalFieldType;
if (!fieldType) {
log.trace({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
log.warn({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
return inputsAccumulator;
}
@@ -214,7 +214,7 @@ export const parseSchema = (
const fieldType = fieldTypeOverride ?? originalFieldType;
if (!fieldType) {
log.trace({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
log.warn({ node: type, field: propertyName, schema: parseify(property) }, 'Unable to parse field type');
return outputsAccumulator;
}
@@ -269,7 +269,7 @@ const getFieldType = (
} catch (e) {
const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError';
if (e instanceof FieldParseError) {
log.warn(
log.trace(
{
node: type,
field: propertyName,
@@ -282,7 +282,7 @@ const getFieldType = (
})
);
} else {
log.warn(
log.trace(
{
node: type,
field: propertyName,