mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 06:08:02 -05:00
Compare commits
1 Commits
controlnet
...
brandon/br
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ebb03a41c |
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
MainModel = "MainModelField"
|
||||
CogView4MainModel = "CogView4MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
BriaMainModel = "BriaMainModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
|
||||
@@ -125,6 +125,7 @@ class ModelProbe(object):
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"BriaPipeline": ModelType.Main,
|
||||
"FluxPipeline": ModelType.Main,
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
@@ -861,6 +862,8 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.StableDiffusion3
|
||||
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
|
||||
return BaseModelType.CogView4
|
||||
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
|
||||
return BaseModelType.Bria
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
|
||||
56
invokeai/backend/model_manager/load/model_loaders/bria.py
Normal file
56
invokeai/backend/model_manager/load/model_loaders/bria.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
class BriaDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load Bria main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, CheckpointConfigBase):
|
||||
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
|
||||
|
||||
if submodel_type is None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
|
||||
variant = repo_variant.value if repo_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
|
||||
dtype = self._torch_dtype
|
||||
try:
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
variant=variant,
|
||||
)
|
||||
except OSError as e:
|
||||
if variant and "no file named" in str(
|
||||
e
|
||||
): # try without the variant, just in case user's preferences changed
|
||||
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
|
||||
else:
|
||||
raise e
|
||||
|
||||
return result
|
||||
@@ -28,6 +28,7 @@ class BaseModelType(str, Enum):
|
||||
CogView4 = "cogview4"
|
||||
Imagen3 = "imagen3"
|
||||
ChatGPT4o = "chatgpt-4o"
|
||||
Bria = "bria"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
|
||||
@@ -23,6 +23,8 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isBriaMainModelFieldInputInstance,
|
||||
isBriaMainModelFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isCLIPGEmbedModelFieldInputInstance,
|
||||
@@ -105,6 +107,7 @@ import { assert } from 'tsafe';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
|
||||
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
|
||||
@@ -408,6 +411,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isBriaMainModelFieldInputTemplate(template)) {
|
||||
if (!isBriaMainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <BriaMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputTemplate(template)) {
|
||||
if (!isSD3MainModelFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useBriaModels } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate>;
|
||||
|
||||
const BriaMainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useBriaModels();
|
||||
const onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<ModelFieldCombobox
|
||||
value={field.value}
|
||||
modelConfigs={modelConfigs}
|
||||
isLoadingConfigs={isLoading}
|
||||
onChange={onChange}
|
||||
required={props.fieldTemplate.required}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BriaMainModelFieldInputComponent);
|
||||
@@ -77,9 +77,10 @@ const zBaseModel = z.enum([
|
||||
'cogview4',
|
||||
'imagen3',
|
||||
'chatgpt-4o',
|
||||
'bria',
|
||||
]);
|
||||
export type BaseModelType = z.infer<typeof zBaseModel>;
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
|
||||
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o', 'bria']);
|
||||
export type MainModelBase = z.infer<typeof zMainModelBase>;
|
||||
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
|
||||
const zModelType = z.enum([
|
||||
|
||||
@@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
LoRAModelField: 'teal.500',
|
||||
MainModelField: 'teal.500',
|
||||
FluxMainModelField: 'teal.500',
|
||||
BriaMainModelField: 'teal.500',
|
||||
SD3MainModelField: 'teal.500',
|
||||
CogView4MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
|
||||
@@ -184,6 +184,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBriaMainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BriaMainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -304,6 +308,7 @@ const zStatefulFieldType = z.union([
|
||||
zIntegerGeneratorFieldType,
|
||||
zStringGeneratorFieldType,
|
||||
zImageGeneratorFieldType,
|
||||
zBriaMainModelFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
|
||||
@@ -320,6 +325,7 @@ const modelFieldTypeNames = [
|
||||
zSD3MainModelFieldType.shape.name.value,
|
||||
zCogView4MainModelFieldType.shape.name.value,
|
||||
zFluxMainModelFieldType.shape.name.value,
|
||||
zBriaMainModelFieldType.shape.name.value,
|
||||
zSDXLRefinerModelFieldType.shape.name.value,
|
||||
zVAEModelFieldType.shape.name.value,
|
||||
zLoRAModelFieldType.shape.name.value,
|
||||
@@ -863,6 +869,26 @@ export const isFluxMainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<FluxMainModelFieldInputTemplate>('FluxMainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region BriaMainModelField
|
||||
const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
|
||||
const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zBriaMainModelFieldValue,
|
||||
});
|
||||
const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zBriaMainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zBriaMainModelFieldValue,
|
||||
});
|
||||
const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zBriaMainModelFieldType,
|
||||
});
|
||||
export type BriaMainModelFieldInputInstance = z.infer<typeof zBriaMainModelFieldInputInstance>;
|
||||
export type BriaMainModelFieldInputTemplate = z.infer<typeof zBriaMainModelFieldInputTemplate>;
|
||||
export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance);
|
||||
export const isBriaMainModelFieldInputTemplate =
|
||||
buildTemplateTypeGuard<BriaMainModelFieldInputTemplate>('BriaMainModelField');
|
||||
// #endregion
|
||||
|
||||
// #region SDXLRefinerModelField
|
||||
/** @alias */ // tells knip to ignore this duplicate export
|
||||
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
|
||||
@@ -1790,6 +1816,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zFluxMainModelFieldValue,
|
||||
zBriaMainModelFieldValue,
|
||||
zSD3MainModelFieldValue,
|
||||
zCogView4MainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
@@ -1837,6 +1864,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zModelIdentifierFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
zFluxMainModelFieldInputInstance,
|
||||
zBriaMainModelFieldInputInstance,
|
||||
zSD3MainModelFieldInputInstance,
|
||||
zCogView4MainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
@@ -1879,6 +1907,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zModelIdentifierFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
zFluxMainModelFieldInputTemplate,
|
||||
zBriaMainModelFieldInputTemplate,
|
||||
zSD3MainModelFieldInputTemplate,
|
||||
zCogView4MainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
@@ -1927,6 +1956,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
zFluxMainModelFieldOutputTemplate,
|
||||
zBriaMainModelFieldOutputTemplate,
|
||||
zSD3MainModelFieldOutputTemplate,
|
||||
zCogView4MainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
|
||||
@@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
SchedulerField: 'dpmpp_3m_k',
|
||||
SDXLMainModelField: undefined,
|
||||
FluxMainModelField: undefined,
|
||||
BriaMainModelField: undefined,
|
||||
SD3MainModelField: undefined,
|
||||
CogView4MainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
|
||||
@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
|
||||
import type {
|
||||
BoardFieldInputTemplate,
|
||||
BooleanFieldInputTemplate,
|
||||
BriaMainModelFieldInputTemplate,
|
||||
CLIPEmbedModelFieldInputTemplate,
|
||||
CLIPGEmbedModelFieldInputTemplate,
|
||||
CLIPLEmbedModelFieldInputTemplate,
|
||||
@@ -338,6 +339,20 @@ const buildFluxMainModelFieldInputTemplate: FieldInputTemplateBuilder<FluxMainMo
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBriaMainModelFieldInputTemplate: FieldInputTemplateBuilder<BriaMainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: BriaMainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -778,6 +793,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
|
||||
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
|
||||
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
|
||||
BriaMainModelField: buildBriaMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
} from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isBriaMainModelModelConfig,
|
||||
isCLIPEmbedModelConfig,
|
||||
isCLIPVisionModelConfig,
|
||||
isCogView4MainModelModelConfig,
|
||||
@@ -60,6 +61,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
|
||||
export const useBriaModels = buildModelsHook(isBriaMainModelModelConfig);
|
||||
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
|
||||
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
|
||||
@@ -2062,7 +2062,7 @@ export type components = {
|
||||
* @description Base model type.
|
||||
* @enum {string}
|
||||
*/
|
||||
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "chatgpt-4o";
|
||||
BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "chatgpt-4o" | "bria";
|
||||
/** Batch */
|
||||
Batch: {
|
||||
/**
|
||||
@@ -21258,7 +21258,7 @@ export type components = {
|
||||
* used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
* @enum {string}
|
||||
*/
|
||||
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
UIType: "MainModelField" | "CogView4MainModelField" | "FluxMainModelField" | "BriaMainModelField" | "SD3MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "FluxVAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "CLIPEmbedModelField" | "CLIPLEmbedModelField" | "CLIPGEmbedModelField" | "SpandrelImageToImageModelField" | "ControlLoRAModelField" | "SigLipModelField" | "FluxReduxModelField" | "LLaVAModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict";
|
||||
/** UNetField */
|
||||
UNetField: {
|
||||
/** @description Info to load unet submodel */
|
||||
|
||||
@@ -256,6 +256,10 @@ export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is Ma
|
||||
return config.type === 'main' && config.base === 'flux';
|
||||
};
|
||||
|
||||
export const isBriaMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'bria';
|
||||
};
|
||||
|
||||
export const isFluxFillMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'flux' && config.variant === 'inpaint';
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user