Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00).

Commit: feat(mm): make config_path optional

This commit is contained in:
@@ -48,7 +48,7 @@ from invokeai.backend.flux.sampling_utils import (
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
-from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
+from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelFormat
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -232,7 +232,7 @@ class FluxDenoiseInvocation(BaseInvocation):
         )

         transformer_config = context.models.get_config(self.transformer.transformer)
-        is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
+        is_schnell = transformer_config.variant is FluxVariantType.Schnell

         # Calculate the timestep schedule.
         timesteps = get_schedule(
@@ -277,7 +277,7 @@ class FluxDenoiseInvocation(BaseInvocation):

         # Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
         img_cond: torch.Tensor | None = None
-        is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint  # type: ignore
+        is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
         if is_flux_fill:
             img_cond = self._prep_flux_fill_img_cond(
                 context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
@@ -198,7 +198,6 @@ class ModelConfigBase(ABC, BaseModel):
         super().__init_subclass__(**kwargs)
         if issubclass(cls, LegacyProbeMixin):
             ModelConfigBase.USING_LEGACY_PROBE.add(cls)
         # Cannot use `elif isinstance(cls, UnknownModelConfig)` because UnknownModelConfig is not defined yet
         else:
             ModelConfigBase.USING_CLASSIFY_API.add(cls)

(NOTE: the hunk header indicates one line was removed here, but the removed line is not recoverable from this extraction; the lines above are shown as they appear after the change.)
@@ -346,11 +345,16 @@ class CheckpointConfigBase(ABC, BaseModel):
     """Base class for checkpoint-style models."""

     format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
-        description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
+        description="Format of the provided checkpoint model",
+        default=ModelFormat.Checkpoint,
     )
-    config_path: str = Field(description="path to the checkpoint model config file")
-    converted_at: Optional[float] = Field(
-        description="When this model was last converted to diffusers", default_factory=time.time
+    config_path: str | None = Field(
+        description="path to the checkpoint model config file",
+        default=None,
+    )
+    converted_at: float | None = Field(
+        description="When this model was last converted to diffusers",
+        default_factory=time.time,
     )
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ export const MainModelPicker = memo(() => {
     () =>
       selectedModelConfig &&
       isCheckpointMainModelConfig(selectedModelConfig) &&
-      selectedModelConfig.config_path === 'flux-dev',
+      selectedModelConfig.variant === 'flux_dev',
     [selectedModelConfig]
   );
|
||||
@@ -26,7 +26,7 @@ export const InitialStateMainModelPicker = memo(() => {
     () =>
       selectedModelConfig &&
       isCheckpointMainModelConfig(selectedModelConfig) &&
-      selectedModelConfig.config_path === 'flux-dev',
+      selectedModelConfig.variant === 'flux_dev',
     [selectedModelConfig]
   );
Reference in New Issue
Block a user