Compare commits

...

1 Commits

Author SHA1 Message Date
Brandon Rising
1ebb03a41c Setup Probe and UI to accept bria main models 2025-04-28 23:03:58 -04:00
14 changed files with 173 additions and 3 deletions

View File

@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"

View File

@@ -125,6 +125,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"BriaPipeline": ModelType.Main,
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -861,6 +862,8 @@ class PipelineFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion3
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
return BaseModelType.CogView4
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
return BaseModelType.Bria
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

View File

@@ -0,0 +1,56 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
class BriaDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e
return result

View File

@@ -28,6 +28,7 @@ class BaseModelType(str, Enum):
CogView4 = "cogview4"
Imagen3 = "imagen3"
ChatGPT4o = "chatgpt-4o"
Bria = "bria"
class ModelType(str, Enum):

View File

@@ -23,6 +23,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isBriaMainModelFieldInputInstance,
isBriaMainModelFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
@@ -105,6 +107,7 @@ import { assert } from 'tsafe';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
@@ -408,6 +411,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isBriaMainModelFieldInputTemplate(template)) {
if (!isBriaMainModelFieldInputInstance(field)) {
return null;
}
return <BriaMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isSD3MainModelFieldInputTemplate(template)) {
if (!isSD3MainModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,44 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useBriaModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate>;
const BriaMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useBriaModels();
const onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(BriaMainModelFieldInputComponent);

View File

@@ -77,9 +77,10 @@ const zBaseModel = z.enum([
'cogview4',
'imagen3',
'chatgpt-4o',
'bria',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o', 'bria']);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([

View File

@@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
LoRAModelField: 'teal.500',
MainModelField: 'teal.500',
FluxMainModelField: 'teal.500',
BriaMainModelField: 'teal.500',
SD3MainModelField: 'teal.500',
CogView4MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',

View File

@@ -184,6 +184,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zBriaMainModelFieldType = zFieldTypeBase.extend({
name: z.literal('BriaMainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
@@ -304,6 +308,7 @@ const zStatefulFieldType = z.union([
zIntegerGeneratorFieldType,
zStringGeneratorFieldType,
zImageGeneratorFieldType,
zBriaMainModelFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
@@ -320,6 +325,7 @@ const modelFieldTypeNames = [
zSD3MainModelFieldType.shape.name.value,
zCogView4MainModelFieldType.shape.name.value,
zFluxMainModelFieldType.shape.name.value,
zBriaMainModelFieldType.shape.name.value,
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
@@ -863,6 +869,26 @@ export const isFluxMainModelFieldInputTemplate =
buildTemplateTypeGuard<FluxMainModelFieldInputTemplate>('FluxMainModelField');
// #endregion
// #region BriaMainModelField
const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zBriaMainModelFieldValue,
});
const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zBriaMainModelFieldType,
originalType: zFieldType.optional(),
default: zBriaMainModelFieldValue,
});
const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zBriaMainModelFieldType,
});
export type BriaMainModelFieldInputInstance = z.infer<typeof zBriaMainModelFieldInputInstance>;
export type BriaMainModelFieldInputTemplate = z.infer<typeof zBriaMainModelFieldInputTemplate>;
export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance);
export const isBriaMainModelFieldInputTemplate =
buildTemplateTypeGuard<BriaMainModelFieldInputTemplate>('BriaMainModelField');
// #endregion
// #region SDXLRefinerModelField
/** @alias */ // tells knip to ignore this duplicate export
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
@@ -1790,6 +1816,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zFluxMainModelFieldValue,
zBriaMainModelFieldValue,
zSD3MainModelFieldValue,
zCogView4MainModelFieldValue,
zSDXLRefinerModelFieldValue,
@@ -1837,6 +1864,7 @@ const zStatefulFieldInputInstance = z.union([
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
zFluxMainModelFieldInputInstance,
zBriaMainModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zCogView4MainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
@@ -1879,6 +1907,7 @@ const zStatefulFieldInputTemplate = z.union([
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
zFluxMainModelFieldInputTemplate,
zBriaMainModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zCogView4MainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
@@ -1927,6 +1956,7 @@ const zStatefulFieldOutputTemplate = z.union([
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,
zFluxMainModelFieldOutputTemplate,
zBriaMainModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zCogView4MainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,

View File

@@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SchedulerField: 'dpmpp_3m_k',
SDXLMainModelField: undefined,
FluxMainModelField: undefined,
BriaMainModelField: undefined,
SD3MainModelField: undefined,
CogView4MainModelField: undefined,
SDXLRefinerModelField: undefined,

View File

@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
BriaMainModelFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
CLIPGEmbedModelFieldInputTemplate,
CLIPLEmbedModelFieldInputTemplate,
@@ -338,6 +339,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
return template;
};
const buildBriaMainModelFieldInputTemplate: FieldInputTemplateBuilder<BriaMainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: BriaMainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -778,6 +793,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
BriaMainModelField: buildBriaMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,

View File

@@ -9,6 +9,7 @@ import {
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isBriaMainModelModelConfig,
isCLIPEmbedModelConfig,
isCLIPVisionModelConfig,
isCogView4MainModelModelConfig,
@@ -60,6 +61,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useBriaModels = buildModelsHook(isBriaMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);

View File

@@ -2062,7 +2062,7 @@ export type components = {
* @description Base model type.
* @enum {string}
*/
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "chatgpt-4o";
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "chatgpt-4o" | "bria";
/** Batch */
Batch: {
/**
@@ -21258,7 +21258,7 @@ export type components = {
* used, and the type will be ignored. They are included here for backwards compatibility.
* @enum {string}
*/
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
/** UNetField */
UNetField: {
/** @description Info to load unet submodel */

View File

@@ -256,6 +256,10 @@ export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is Ma
return config.type === 'main' && config.base === 'flux';
};
export const isBriaMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'bria';
};
export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint';
};