Setup Probe and UI to accept bria controlnet models

This commit is contained in:
Ilan Tchenak
2025-07-09 23:45:08 +03:00
committed by Ubuntu
parent 9e5e1ec0da
commit 7140f2ec72
12 changed files with 162 additions and 4 deletions

View File

@@ -43,6 +43,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
BriaControlNetModel = "BriaControlNetModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"

View File

@@ -126,6 +126,7 @@ class ModelProbe(object):
CLASS2TYPE = {
"BriaPipeline": ModelType.Main,
"BriaControlNetModel": ModelType.ControlNet,
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -1013,6 +1014,9 @@ class ControlNetFolderProbe(FolderProbeBase):
if config.get("_class_name", None) == "FluxControlNetModel":
return BaseModelType.Flux
if config.get("_class_name", None) == "BriaControlNetModel":
return BaseModelType.Bria
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
if dimension == 768:

View File

@@ -5,6 +5,8 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
ControlNetDiffusersConfig,
ControlNetCheckpointConfig,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@@ -17,6 +19,45 @@ from invokeai.backend.model_manager.taxonomy import (
)
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class BriaControlNetDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria control net models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
if submodel_type is None:
raise Exception("A submodel type must be provided when loading control net pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e
return result
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
class BriaDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria main models."""

View File

@@ -27,10 +27,12 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isChatGPT4oModelFieldInputInstance,
isChatGPT4oModelFieldInputTemplate,
isBriaControlNetModelFieldInputInstance,
isBriaControlNetModelFieldInputTemplate,
isBriaMainModelFieldInputInstance,
isBriaMainModelFieldInputTemplate,
isChatGPT4oModelFieldInputInstance,
isChatGPT4oModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
@@ -119,6 +121,7 @@ import { assert } from 'tsafe';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import BriaControlNetModelFieldInputComponent from './inputs/BriaControlNetModelFieldInputComponent';
import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
@@ -458,6 +461,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <BriaMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isBriaControlNetModelFieldInputTemplate(template)) {
if (!isBriaControlNetModelFieldInputInstance(field)) {
return null;
}
return <BriaControlNetModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSD3MainModelFieldInputTemplate(template)) {
if (!isSD3MainModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,47 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
BriaControlNetModelFieldInputInstance,
BriaControlNetModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useBriaModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<BriaControlNetModelFieldInputInstance, BriaControlNetModelFieldInputTemplate>;
const BriaControlNetModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useBriaModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(BriaControlNetModelFieldInputComponent);

View File

@@ -53,6 +53,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
BriaMainModelField: 'teal.500',
BriaControlNetModelField: 'teal.500',
SD3MainModelField: 'teal.500',
CogView4MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',

View File

@@ -189,6 +189,10 @@ const zBriaMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('BriaMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zBriaControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('BriaControlNetModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
@@ -330,6 +334,7 @@ const zStatefulFieldType = z.union([
zStringGeneratorFieldType,
zImageGeneratorFieldType,
zBriaMainModelFieldType,
zBriaControlNetModelFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
@@ -347,6 +352,7 @@ const modelFieldTypeNames = [
zCogView4MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
zBriaMainModelFieldType.shape.name.value,
zBriaControlNetModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
@@ -914,6 +920,26 @@ export const isBriaMainModelFieldInputTemplate =
buildTemplateTypeGuard<BriaMainModelFieldInputTemplate>('BriaMainModelField');
// #endregion
// #region BriaControlNetModelField
const zBriaControlNetModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zBriaControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBriaControlNetModelFieldValue,
});
const zBriaControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBriaControlNetModelFieldType,
originalType: zFieldType.optional(),
default: zBriaControlNetModelFieldValue,
});
const zBriaControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBriaControlNetModelFieldType,
});
export type BriaControlNetModelFieldInputInstance = z.infer<typeof zBriaControlNetModelFieldInputInstance>;
export type BriaControlNetModelFieldInputTemplate = z.infer<typeof zBriaControlNetModelFieldInputTemplate>;
export const isBriaControlNetModelFieldInputInstance = buildInstanceTypeGuard(zBriaControlNetModelFieldInputInstance);
export const isBriaControlNetModelFieldInputTemplate =
buildTemplateTypeGuard<BriaControlNetModelFieldInputTemplate>('BriaControlNetModelField');
// #endregion
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
@@ -1914,6 +1940,7 @@ export const zStatefulFieldValue = z.union([
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zBriaMainModelFieldValue,
zBriaControlNetModelFieldValue,
zSD3MainModelFieldValue,
zCogView4MainModelFieldValue,
zSDXLRefinerModelFieldValue,
@@ -1966,6 +1993,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zBriaMainModelFieldInputInstance,
zBriaControlNetModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zCogView4MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
@@ -2009,6 +2037,7 @@ const zStatefulFieldInputTemplate = z.union([
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zBriaMainModelFieldInputTemplate,
zBriaControlNetModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zCogView4MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
@@ -2062,6 +2091,7 @@ const zStatefulFieldOutputTemplate = z.union([
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zBriaMainModelFieldOutputTemplate,
zBriaControlNetModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zCogView4MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,

View File

@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
BriaMainModelField: undefined,
BriaControlNetModelField: undefined,
SD3MainModelField: undefined,
CogView4MainModelField: undefined,
SDXLRefinerModelField: undefined,

View File

@@ -3,8 +3,9 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
ChatGPT4oModelFieldInputTemplate,
BriaControlNetModelFieldInputTemplate,
BriaMainModelFieldInputTemplate,
ChatGPT4oModelFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
@@ -357,6 +358,20 @@ const buildBriaMainModelFieldInputTemplate: FieldInputTemplateBuilder<BriaMainMo
return template;
};
const buildBriaControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<BriaControlNetModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: BriaControlNetModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -850,6 +865,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
BriaMainModelField: buildBriaMainModelFieldInputTemplate,
BriaControlNetModelField: buildBriaControlNetModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,

View File

@@ -9,7 +9,9 @@ import {
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isBriaControlNetModelConfig,
isBriaMainModelModelConfig,
isChatGPT4oModelConfig,
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isCogView4MainModelModelConfig,
@@ -66,6 +68,7 @@ export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useBriaModels = buildModelsHook(isBriaMainModelModelConfig);
export const useBriaControlNetModels = buildModelsHook(isBriaControlNetModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);

View File

@@ -21320,7 +21320,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string}
*/
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "BriaControlNetModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */
UNetField: {
/** @description Info to load unet submodel */

View File

@@ -283,6 +283,10 @@ export const isBriaMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'bria';
};
export const isBriaControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
return config.type === 'controlnet' && config.base === 'bria';
};
export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint';
};