Compare commits

...

122 Commits

Author SHA1 Message Date
psychedelicious
e238b2682d chore: bump version to v6.9.0a3 2025-10-14 16:31:55 +11:00
psychedelicious
7296a0c911 chore: bump version to v6.9.0a2 2025-10-13 10:30:24 +11:00
psychedelicious
e38052155b fix(mm): fixes for migration 23
- Handle CLIP Embed and Main SD models that are missing the variant field
- Handle errors when calling the discriminator function; previously only
ValidationError was handled, but it could raise a ValueError or something else
- Better logging for config migration
2025-10-13 10:30:24 +11:00
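The hardened fallback described in this commit can be sketched as follows; the helper names are stand-ins, not the actual migration code:

```python
from pydantic import ValidationError

def discriminate_and_migrate(raw: dict) -> dict:
    # Stand-in for the real discriminator: picks a config class from the raw
    # fields and may raise ValidationError, ValueError, or anything else.
    raise ValueError("unrecognized model config")

def make_unknown_config(raw: dict) -> dict:
    # Stand-in fallback: keep the record but mark it unknown.
    return {**raw, "type": "unknown", "base": "unknown", "format": "unknown"}

def migrate_config_row(raw: dict) -> dict:
    try:
        return discriminate_and_migrate(raw)
    except ValidationError:
        return make_unknown_config(raw)
    except Exception:
        # Previously only ValidationError was handled; the discriminator can
        # also raise ValueError or other errors, which now fall back as well.
        return make_unknown_config(raw)
```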
psychedelicious
644bb98476 fix(ui): typegen schema sync 2025-10-13 10:30:23 +11:00
psychedelicious
fc10db40d0 tests(mm): windows CI issue 2025-10-13 10:30:23 +11:00
psychedelicious
6cea0c043e feat(mm): just delete the dir w/ rmtree when deleting model 2025-10-13 10:30:23 +11:00
psychedelicious
bab61fbca4 fix(mm): issue with deleting single file models 2025-10-13 10:30:23 +11:00
psychedelicious
abbc96de7a tests(mm): attempt to fix windows model id tests 2025-10-13 10:30:23 +11:00
psychedelicious
bcc4735024 docs: update model id readme 2025-10-13 10:30:23 +11:00
psychedelicious
975d7166cd chore: bump version to v6.9.0a1 2025-10-13 10:30:23 +11:00
psychedelicious
f1ba95d42e docs(db): update version numbers in migration comments 2025-10-13 10:30:09 +11:00
psychedelicious
e811ffc8e2 feat(ui): use translation string for model edit warning 2025-10-13 10:30:09 +11:00
psychedelicious
7ff73eb75b fix(mm): lora state dict loading in model id 2025-10-13 10:30:09 +11:00
psychedelicious
396f739b22 tidy: remove unused file 2025-10-13 10:30:09 +11:00
psychedelicious
ab1e15e4f5 tests(mm): flux state dict tests 2025-10-13 10:30:09 +11:00
psychedelicious
9d9625f8ab feat(ui): add warning for model settings edit 2025-10-13 10:30:09 +11:00
psychedelicious
3af42c56d2 feat: allow users to edit models freely 2025-10-13 10:30:09 +11:00
psychedelicious
e5935a39e4 tests(mm): fix remaining MM tests 2025-10-13 10:30:09 +11:00
psychedelicious
adc332b9e3 feat(mm): add flag for updating models to allow class changes 2025-10-13 10:30:09 +11:00
psychedelicious
d81a55401a feat(mm): use ValueError for model id sanity checks 2025-10-13 10:30:09 +11:00
psychedelicious
db2a8306c2 fix(mm): omit type/format/base when creating unknown config instance 2025-10-13 10:30:08 +11:00
psychedelicious
a16d0b8301 tests(mm): refactor model identification tests
Overhaul of model identification (probing) tests. Previously we didn't
test the correctness of probing except in a few narrow cases - now we
do.

See tests/model_identification/README.md for a detailed overview of the
new test setup. It includes instructions for adding a new test case. In
brief:

- Download the model you want to add as a test case
- Run a script against it to generate the test model files
- Fill in the expected model type/format/base/etc in the generated test
metadata JSON file

Included test cases:
- All starter models
- A handful of other models that I had installed
- Models present in the previous test cases as smoke tests, now also
tested for correctness
2025-10-13 10:30:08 +11:00
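A hedged sketch of what checking one generated test case might look like; the metadata file name and the identification entry point are assumptions, not the actual harness:

```python
import json
from pathlib import Path

def load_expected(case_dir: Path) -> dict:
    # Each generated case carries a metadata JSON whose expected
    # type/format/base were filled in by hand after running the script.
    return json.loads((case_dir / "metadata.json").read_text())

def check_case(case_dir: Path, identify_model) -> None:
    expected = load_expected(case_dir)
    config = identify_model(case_dir)  # stand-in for the real identification API
    assert config["type"] == expected["type"]
    assert config["format"] == expected["format"]
    assert config["base"] == expected["base"]
```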
psychedelicious
2206b28576 refactor(mm): continued iteration on model identification 2025-10-13 10:30:08 +11:00
psychedelicious
4f20a0db2e feat(mm): do not log when multiple non-unknown model matches 2025-10-13 10:30:08 +11:00
psychedelicious
d258af0a14 feat(mm): add method to get main model defaults from a base 2025-10-13 10:30:08 +11:00
psychedelicious
fc2175ae03 fix(mm): ModelOnDisk skips dirs when looking for weights
Previously a path w/ any of the known weights suffixes would be seen as
a weights file, even if it was a directory. We now check to ensure the
candidate path is actually a file before adding it to the list of
weights.
2025-10-13 10:30:08 +11:00
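A minimal sketch of the fix, with an assumed suffix set; the point is the added is_file() guard:

```python
from pathlib import Path

WEIGHTS_SUFFIXES = {".safetensors", ".ckpt", ".pt", ".bin"}  # assumed set

def weight_files(root: Path) -> list[Path]:
    return [
        p
        for p in root.rglob("*")
        # The is_file() guard is the fix: a directory named e.g. "foo.ckpt"
        # no longer counts as a weights file.
        if p.suffix in WEIGHTS_SUFFIXES and p.is_file()
    ]
```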
psychedelicious
d11cb34d22 fix(mm): vae checkpoint probe checking for dir instead of file 2025-10-13 10:30:08 +11:00
psychedelicious
381827fd54 fix(mm): false negative on flux lora 2025-10-13 10:30:08 +11:00
psychedelicious
17c3d15488 feat(db): run migrations 23 and 24 2025-10-13 10:30:08 +11:00
psychedelicious
5a681d51c9 fix(db): migration 23 fall back to unknown model when config change fails 2025-10-13 10:30:08 +11:00
psychedelicious
f57ee304bc fix(db): migration 22 insert only real cols 2025-10-13 10:30:08 +11:00
psychedelicious
55d7d2e396 fix(mm): pop base/type/format when creating unknown model config 2025-10-13 10:30:08 +11:00
psychedelicious
347a33f77c refactor(mm): split big migration into 3
Split the big migration that did all of these things into 3:

- Migration 22: Remove unique constraint on base/name/type in models
table
- Migration 23: Migrate configs to v6.8.0 schemas
- Migration 24: Normalize file storage
2025-10-13 10:30:07 +11:00
psychedelicious
34cb88ef23 fix(mm): duplicate import 2025-10-13 10:30:07 +11:00
psychedelicious
e676b9d075 feat(mm): add model config schema migration logic 2025-10-13 10:30:07 +11:00
psychedelicious
26dc155ad8 feat(mm): fix clip vision starter model bases, add ref to actual models 2025-10-13 10:30:07 +11:00
psychedelicious
b5aa31526f feat(mm): clearer naming for main config class hierarchy 2025-10-13 10:30:07 +11:00
psychedelicious
99d3f16eb4 docs(mm): add reminder for self for field migrations 2025-10-13 10:30:07 +11:00
psychedelicious
01ca74e622 feat(mm): add sanity checks before probing paths 2025-10-13 10:30:07 +11:00
psychedelicious
25619684c0 fix(mm): clip vision identification 2025-10-13 10:30:07 +11:00
psychedelicious
d336aa45f5 feat(mm): more flexible config matching utils 2025-10-13 10:30:07 +11:00
psychedelicious
303acdb4ac fix(mm): sdxl ip adapter identification 2025-10-13 10:30:07 +11:00
psychedelicious
2e5ec1c98b fix(mm): t5 identification 2025-10-13 10:30:07 +11:00
psychedelicious
56e31ca4ac fix(mm): ensure unknown model configs get unknown attrs 2025-10-13 10:30:07 +11:00
psychedelicious
74e4dd4393 docs(mm): remove extraneous comment 2025-10-13 10:30:07 +11:00
psychedelicious
233b286893 feat(mm): satisfy type checker in flux denoise 2025-10-13 10:30:07 +11:00
psychedelicious
07a667ad9f feat(mm): add helper method for legacy configs 2025-10-13 10:30:07 +11:00
psychedelicious
7437a14301 docs(mm): document flux variant attr 2025-10-13 10:30:07 +11:00
psychedelicious
09fef01786 docs(mm): update docstrings in factory.py 2025-10-13 10:30:07 +11:00
psychedelicious
83fe40e7ee fix(mm): inverted condition 2025-10-13 10:30:07 +11:00
psychedelicious
13b2f9d12b refactor(mm): remove legacy probe, new configs dir structure, update imports 2025-10-13 10:30:07 +11:00
psychedelicious
0214afc3d1 chore(ui): typegen 2025-10-13 10:30:07 +11:00
psychedelicious
e23ac6d813 docs(mm): add comments for identification utils 2025-10-13 10:30:07 +11:00
psychedelicious
9faffe93f9 refactor(mm): split configs into separate files 2025-10-13 10:30:07 +11:00
psychedelicious
edfd90f2a4 tidy(mm): consistent class names 2025-10-13 10:30:06 +11:00
psychedelicious
e48e354bf1 fix(mm): tag generation & scattered probe fixes 2025-10-13 10:30:06 +11:00
psychedelicious
4ded5b5a80 feat(mm): consistent naming for all model config classes 2025-10-13 10:30:06 +11:00
psychedelicious
ee5808355d refactor(mm): diffusers loras
2025-10-13 10:30:06 +11:00
psychedelicious
af305250cb refactor(mm): make config classes narrow
Simpler identification logic, less complexity when adding a new model, fewer
useless attrs that do not relate to the model arch, etc
2025-10-13 10:30:06 +11:00
psychedelicious
c065655a1d tidy(mm): flux lora format util 2025-10-13 10:30:06 +11:00
psychedelicious
a0a4eb9a5a tidy(mm): clean up ModelOnDisk caching 2025-10-13 10:30:06 +11:00
psychedelicious
c53c731371 tidy(mm): clean up model heuristic utils 2025-10-13 10:30:06 +11:00
psychedelicious
951635fbee feat(mm): wip port main models to new api 2025-10-13 10:30:06 +11:00
psychedelicious
044648fe61 tidy(mm): removed unused model merge class 2025-10-13 10:30:06 +11:00
psychedelicious
111782d6c9 docs(mm): add todos 2025-10-13 10:30:06 +11:00
psychedelicious
f5cbf60fc0 feat(mm): wip port of main models to new api 2025-10-13 10:30:06 +11:00
psychedelicious
395b7d8bbf feat(mm): wip port of main models to new api 2025-10-13 10:30:06 +11:00
psychedelicious
934b3f8b87 feat(mm): wip port of main models to new api 2025-10-13 10:30:06 +11:00
psychedelicious
9745c25b1b refactor(mm): add config validation utils, make it all consistent and clean 2025-10-13 10:30:06 +11:00
psychedelicious
925698a688 feat(mm): port cnet to new api 2025-10-13 10:30:06 +11:00
psychedelicious
96bbd8a26e fix(mm): t2i base determination 2025-10-13 10:30:06 +11:00
psychedelicious
eb1ed245fe tidy(ui): use Extract to get model config types 2025-10-13 10:30:06 +11:00
psychedelicious
a118700cc8 feat(mm): port flux "control lora" and t2i adapter to new api 2025-10-13 10:30:06 +11:00
psychedelicious
eaddd6f533 refactor(mm): continue iterating on config 2025-10-13 10:30:06 +11:00
psychedelicious
7ca0a0a0fd tidy(mm): skip optimistic override handling for now 2025-10-13 10:30:06 +11:00
psychedelicious
d185b85fb7 feat(mm): port ip adapter to new api 2025-10-13 10:30:06 +11:00
psychedelicious
a35a49f585 feat(mm): port flux redux to new api 2025-10-13 10:30:06 +11:00
psychedelicious
3b606b6d63 feat(mm): make match helpers more succinct 2025-10-13 10:30:05 +11:00
psychedelicious
d89472d3b1 feat(mm): port SigLIPDiffusersConfig to new api 2025-10-13 10:30:05 +11:00
psychedelicious
036ab04376 feat(mm): port CLIPVisionDiffusersConfig to new api 2025-10-13 10:30:05 +11:00
psychedelicious
e1a54badc1 fix(mm): fall back to UnknownModelConfig correctly 2025-10-13 10:30:05 +11:00
psychedelicious
bbecc86d0f tidy(mm): clarify that model id utils are private 2025-10-13 10:30:05 +11:00
psychedelicious
d4823b6869 fix(mm): abstractmethod bork 2025-10-13 10:30:05 +11:00
psychedelicious
3488975b2b refactor(mm): add model config parsing utils 2025-10-13 10:30:05 +11:00
psychedelicious
fd47da6842 refactor(mm): remove unused methods in config.py 2025-10-13 10:30:05 +11:00
psychedelicious
8399de9c25 refactor(mm): simplify model classification process
Previously, we had a multi-phase strategy to identify models from their
files on disk:
1. Run each model config classes' `matches()` method on the files. It
checks whether the model could possibly be identified as the candidate
model type. This was intended to be a quick check. Break on the first
match.
2. If we have a match, run the config class's `parse()` method. It
derives some additional model config attrs from the model files. This was
intended to encapsulate heavier operations that may require loading the
model into memory.
3. Derive the common model config attrs, like name, description,
calculate the hash, etc. Some of these are also heavier operations.

This strategy has some issues:
- It is not clear how the pieces fit together. There is some
back-and-forth between different methods and the config base class. It
is hard to trace the flow of logic until you fully wrap your head around
the system and therefore difficult to add a model architecture to the
probe.
- The assumption that we could do quick, lightweight checks before
heavier checks is incorrect. We often _must_ load the model state dict
in the `matches()` method. So there is no practical perf benefit to
splitting up the responsibility of `matches()` and `parse()`.
- Sometimes we need to do the same checks in `matches()` and `parse()`.
In these cases, splitting the logic has a negative perf impact
because we are doing the same work twice.
- As we introduce the concept of an "unknown" model config (i.e. a model
that we cannot identify, but still record in the db; see #8582), we will
_always_ run _all_ the checks for every model. Therefore we need not try
to defer heavier checks or resource-intensive ops like hashing. We are
going to do them anyways.
- There are situations where a model may match multiple configs. One
known case is SD pipeline models with merged LoRAs. In the old probe
API, we relied on the implicit order of checks to know that if a model
matched for pipeline _and_ LoRA, we prefer the pipeline match. But, in
the new API, we do not have this implicit ordering of checks. To resolve
this in a resilient way, we need to get all matches up front, then use
tie-breaker logic to figure out which should win (or add "differential
diagnosis" logic to the matchers).
- Field overrides weren't handled well by this strategy. They were only
applied at the very end, if a model matched successfully. This means we
cannot tell the system "Hey, this model is type X with base Y. Trust me
bro.". We cannot override the match logic. As we move towards letting
users correct mis-identified models (see #8582), this is a requirement.

We can simplify the process significantly and better support "unknown"
models.

Firstly, model config classes now have a single `from_model_on_disk()`
method that attempts to construct an instance of the class from the
model files. This replaces the `matches()` and `parse()` methods.

If we fail to create the config instance, a special exception is raised
that indicates why we think the files cannot be identified as the given
model config class.

Next, the flow for model identification is a bit simpler:
- Derive all the common fields up-front (name, desc, hash, etc).
- Merge in overrides.
- Call `from_model_on_disk()` for every config class, passing in the
fields. Overrides are handled in this method.
- Record the results for each config class and choose the best one.

The identification logic is a bit more verbose, with the special
exceptions and handling of overrides, but it is very clear what is
happening.

The one downside I can think of for this strategy is we do need to check
every model type, instead of stopping at the first match. It's a bit
less efficient. In practice, however, this isn't a hot code path, and
the improved clarity is worth far more than perf optimizations that the
end user will likely never notice.
2025-10-13 10:30:05 +11:00
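A hedged sketch of the new flow, with class and exception names modeled on the description above rather than the real implementation:

```python
from pathlib import Path

class NotAMatchError(Exception):
    """Raised when the files cannot be identified as the candidate config class."""

class ConfigA:
    @classmethod
    def from_model_on_disk(cls, path: Path, fields: dict) -> "ConfigA":
        raise NotAMatchError("state dict does not look like config A")

class ConfigB:
    @classmethod
    def from_model_on_disk(cls, path: Path, fields: dict) -> "ConfigB":
        return cls()

CONFIG_CLASSES = [ConfigA, ConfigB]

def identify(path: Path, overrides: dict):
    fields = {"name": path.name, "hash": "sha256:..."}  # common fields, derived up front
    fields.update(overrides)                            # overrides merged before matching
    results: dict[type, object] = {}
    for config_cls in CONFIG_CLASSES:                   # every class is tried; no early break
        try:
            results[config_cls] = config_cls.from_model_on_disk(path, fields)
        except NotAMatchError as e:
            results[config_cls] = e                     # record why it did not match
    matches = [r for r in results.values() if not isinstance(r, Exception)]
    return matches[0] if matches else None              # tie-breaker logic would go here
```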
psychedelicious
0fd58681a2 feat(mm): make config_path optional 2025-10-13 10:30:05 +11:00
psychedelicious
250163e6b7 feat(mm): port t5 to new API 2025-10-13 10:30:05 +11:00
psychedelicious
4b1450a4ff feat(mm): better errors when invalid model config found in db 2025-10-13 10:30:05 +11:00
psychedelicious
4e2145c6c4 tidy(mm): patcher types and import paths 2025-10-13 10:30:05 +11:00
psychedelicious
8a6d5f4f6a fix(mm): vae class inheritance and config_path 2025-10-13 10:30:05 +11:00
psychedelicious
06dcd290df feat(mm): port vae to new API 2025-10-13 10:30:05 +11:00
psychedelicious
73b6fae00e fix(mm): tis use existing weight_files method 2025-10-13 10:30:05 +11:00
psychedelicious
4ae20f4876 fix(mm): loader for clip embed 2025-10-13 10:30:05 +11:00
psychedelicious
f852c03ba5 fix(mm): parsing for spandrel 2025-10-13 10:30:05 +11:00
psychedelicious
8a14175ab2 feat(mm): port spandrel to new API 2025-10-13 10:30:05 +11:00
psychedelicious
9469bb05fe tidy(mm): remove unused probes 2025-10-13 10:30:05 +11:00
psychedelicious
8036bb0e8f feat(mm): port TIs to new API 2025-10-13 10:30:05 +11:00
psychedelicious
e72c78f7d4 refactor: port MM probes to new api
- Add concept of match certainty to new probe
- Port CLIP Embed models to new API
- Fiddle with stuff
2025-10-13 10:30:05 +11:00
psychedelicious
a8009b47e9 fix(mm): normalized multi-file/diffusers model installation no worky
now worky
2025-10-13 10:30:04 +11:00
psychedelicious
6294c294d0 feat(mm): add migration to flat model storage 2025-10-13 10:30:04 +11:00
psychedelicious
6f08a2bfb1 feat(mm): normalized model storage
Store models in a flat directory structure. Each model is in a dir named
its unique key (a UUID). Inside that dir is either the model file or the
model dir.
2025-10-13 10:30:04 +11:00
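A minimal sketch of the flat layout, assuming the key is a UUID hex string:

```python
from pathlib import Path
from uuid import uuid4

def install_destination(models_dir: Path, source: Path) -> tuple[str, Path]:
    key = uuid4().hex              # the model's unique key names its directory
    dest_dir = models_dir / key
    # A single-file model keeps its file name inside the key dir; a folder
    # (diffusers-style) model has its contents moved into the key dir itself.
    dest = dest_dir / source.name if source.is_file() else dest_dir
    return key, dest
```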
psychedelicious
84e4d313a8 fix(ui): wrong translation string 2025-10-13 10:30:04 +11:00
psychedelicious
092cff358a chore(ui): lint 2025-10-13 10:30:04 +11:00
psychedelicious
ca3ccf92bc tidy(ui): prefer types from zod schemas for model attrs 2025-10-13 10:30:04 +11:00
psychedelicious
7cdc821801 tests(mm): fix test for MM, leave the UnknownModelConfig class in the list of configs 2025-10-13 10:30:04 +11:00
psychedelicious
08853f9be2 chore(ui): typegen 2025-10-13 10:30:04 +11:00
psychedelicious
4897eebf5f docs: update config docstrings 2025-10-13 10:30:04 +11:00
psychedelicious
93a170a62c feat(ui): toast warning when installed model is unidentified 2025-10-13 10:30:04 +11:00
psychedelicious
facb02602c chore(ui): typegen 2025-10-13 10:30:04 +11:00
psychedelicious
62c456a1e4 feat(app): add the installed model config to install complete events 2025-10-13 10:30:04 +11:00
psychedelicious
51b2297a2b feat(ui): allow changing model format in MM 2025-10-13 10:30:04 +11:00
psychedelicious
64aaf9880a feat(app): add setting to allow unknown models 2025-10-13 10:30:04 +11:00
psychedelicious
9e509ffb56 feat(mm): omit model description instead of making it "base type filename model" 2025-10-13 10:30:04 +11:00
psychedelicious
6e9e8d6bd2 feat(ui): allow changing model type in MM, fix up base and variant selects 2025-10-13 10:30:04 +11:00
psychedelicious
eb6b3b8168 feat(ui): add unknown model base support in ui 2025-10-13 10:30:04 +11:00
psychedelicious
0f5beec657 chore(ui): typegen 2025-10-13 10:30:04 +11:00
psychedelicious
8474fd8342 feat(nodes): add unknown as model base 2025-10-13 10:30:04 +11:00
psychedelicious
0a3e6d3f88 refactor(ui): remove unused excludeSubmodels
I can't remember what this was for and don't see any reference to it.
Maybe it's just remnants from a previous implementation?
2025-10-13 10:30:03 +11:00
psychedelicious
7cc7d06f3c refactor(ui): more cleanup of model categories 2025-10-13 10:30:03 +11:00
psychedelicious
b26ab0b3f1 refactor(ui): move model categorisation-ish logic to central location, simplify model manager models list 2025-10-13 10:30:03 +11:00
psychedelicious
4ae6c903e3 feat(mm): add UnknownModelConfig 2025-10-13 10:30:03 +11:00
548 changed files with 11711 additions and 6217 deletions

.gitattributes (vendored)
View File

@@ -4,3 +4,4 @@
* text=auto
docker/** text eol=lf
tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text
tests/model_identification/stripped_models/** filter=lfs diff=lfs merge=lfs -text

View File

@@ -28,10 +28,12 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
MainCheckpointConfig,
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
@@ -44,6 +46,7 @@ from invokeai.backend.model_manager.starter_models import (
StarterModelBundle,
StarterModelWithoutDependencies,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
@@ -297,10 +300,8 @@ async def update_model_record(
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
try:
record_store.update_model(key, changes=changes)
config = installer.sync_model_path(key)
config = record_store.update_model(key, changes=changes, allow_class_change=True)
config = add_cover_image_to_model_config(config, ApiDependencies)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
@@ -743,9 +744,18 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
if not isinstance(
model_config,
(
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
),
):
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
logger.error(msg)
raise HTTPException(400, msg)
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem

View File

@@ -22,7 +22,7 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import CogView4ConditioningInfo

View File

@@ -13,8 +13,7 @@ from invokeai.app.invocations.model import (
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@invocation_output("cogview4_model_loader_output")

View File

@@ -20,9 +20,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelType, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -182,10 +180,11 @@ class CreateGradientMaskInvocation(BaseInvocation):
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
assert main_model_config.type is ModelType.Main
variant = getattr(main_model_config, "variant", None)
if variant is ModelVariantType.Inpaint or variant is FluxVariantType.DevFill:
mask = dilated_mask_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@@ -39,7 +39,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher

View File

@@ -48,7 +48,7 @@ from invokeai.backend.flux.sampling_utils import (
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -232,7 +232,8 @@ class FluxDenoiseInvocation(BaseInvocation):
)
transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
assert transformer_config.base is BaseModelType.Flux and transformer_config.type is ModelType.Main
is_schnell = transformer_config.variant is FluxVariantType.Schnell
# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -277,7 +278,7 @@ class FluxDenoiseInvocation(BaseInvocation):
# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
img_cond: torch.Tensor | None = None
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
if is_flux_fill:
img_cond = self._prep_flux_fill_img_cond(
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype

View File

@@ -16,10 +16,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
)
from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_FLUX_Config
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@@ -68,7 +65,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config)
# Note: There is an IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

View File

@@ -13,10 +13,8 @@ from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
)
from invokeai.backend.flux.util import get_flux_max_seq_length
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@@ -87,12 +85,12 @@ class FluxModelLoaderInvocation(BaseInvocation):
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
assert isinstance(transformer_config, Checkpoint_Config_Base)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
max_seq_len=get_flux_max_seq_length(transformer_config.variant),
)
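A hedged sketch of what get_flux_max_seq_length might look like in place of the old config_path-keyed dict; the schnell/dev lengths (256/512) are common FLUX values but assumed here:

```python
from enum import Enum

class FluxVariantType(str, Enum):
    Schnell = "schnell"
    Dev = "dev"
    DevFill = "dev_fill"

def get_flux_max_seq_length(variant: FluxVariantType) -> int:
    # schnell uses the shorter T5 sequence length; dev and dev-fill use 512.
    return 256 if variant is FluxVariantType.Schnell else 512
```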

View File

@@ -24,9 +24,9 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.starter_models import siglip
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice

View File

@@ -17,7 +17,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.model_manager import ModelFormat
from invokeai.backend.model_manager.taxonomy import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

View File

@@ -12,7 +12,7 @@ from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux

View File

@@ -23,7 +23,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice

View File

@@ -11,10 +11,10 @@ from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.ip_adapter import (
IPAdapter_Checkpoint_Config_Base,
IPAdapter_InvokeAI_Config_Base,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
@@ -123,9 +123,9 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:

View File

@@ -12,9 +12,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@@ -24,8 +22,9 @@ class ModelIdentifierField(BaseModel):
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
submodel_type: SubModelType | None = Field(
description="The submodel to load, if this is a main model",
default=None,
)
@classmethod

View File

@@ -23,7 +23,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo

View File

@@ -108,6 +108,7 @@ class InvokeAIAppConfig(BaseSettings):
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided as a Bearer token.
scan_models_on_startup: Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
"""
_root: Optional[Path] = PrivateAttr(default=None)
@@ -198,6 +199,7 @@ class InvokeAIAppConfig(BaseSettings):
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided as a Bearer token.")
scan_models_on_startup: bool = Field(default=False, description="Scan the models directory on startup, registering orphaned models. This is typically only used in conjunction with `use_memory_db` for testing purposes.")
unsafe_disable_picklescan: bool = Field(default=False, description="UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.")
allow_unknown_models: bool = Field(default=True, description="Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.")
# fmt: on
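A hedged usage sketch of the new flag; InvokeAIAppConfig is a pydantic settings class, so constructing it with the flag should work, but treat the exact usage as an assumption:

```python
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(allow_unknown_models=False)
# With the flag disabled, a model that cannot be identified is rejected at
# install time instead of being registered with `unknown` attributes.
assert config.allow_unknown_models is False
```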

View File

@@ -44,8 +44,8 @@ if TYPE_CHECKING:
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import SubModelType
class EventServiceBase:

View File

@@ -16,8 +16,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import SubModelType
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
@@ -546,11 +546,18 @@ class ModelInstallCompleteEvent(ModelEventBase):
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
config: AnyModelConfig = Field(description="The installed model's config")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=job.source, key=(job.config_out.key), total_bytes=job.total_bytes)
return cls(
id=job.id,
source=job.source,
key=(job.config_out.key),
total_bytes=job.total_bytes,
config=job.config_out,
)
@payload_schema.register

View File

@@ -12,7 +12,6 @@ from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
@@ -231,19 +230,6 @@ class ModelInstallServiceBase(ABC):
will block indefinitely until the installs complete.
"""
@abstractmethod
def sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
@abstractmethod
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
"""

View File

@@ -10,11 +10,17 @@ from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
class InvalidModelConfigException(Exception):
"""Raised when a model configuration is invalid."""
pass
class InstallStatus(str, Enum):
"""State of an install job running in the background."""

View File

@@ -5,6 +5,7 @@ import os
import re
import threading
import time
from copy import deepcopy
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
@@ -26,6 +27,7 @@ from invokeai.app.services.model_install.model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
InvalidModelConfigException,
LocalModelSource,
ModelInstallJob,
ModelSource,
@@ -34,13 +36,12 @@ from invokeai.app.services.model_install.model_install_common import (
)
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
from invokeai.backend.model_manager.configs.factory import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -180,28 +181,32 @@ class ModelInstallService(ModelInstallServiceBase):
self,
model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None,
) -> str: # noqa D102
) -> str:
model_path = Path(model_path)
config = config or ModelRecordChanges()
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
if preferred_name := config.name:
if Path(model_path).is_file():
# Careful! Don't use pathlib.Path(...).with_suffix - it will strip everything after the first dot.
preferred_name = f"{preferred_name}{model_path.suffix}"
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
)
dest_dir = self.app_config.models_path / info.key
try:
new_path = self._move_model(model_path, dest_path)
except FileExistsError as excp:
if dest_dir.exists():
raise FileExistsError(
f"Cannot install model {model_path.name} to {dest_dir}: destination already exists"
)
dest_dir.mkdir(parents=True)
dest_path = dest_dir / model_path.name if model_path.is_file() else dest_dir
if model_path.is_file():
move(model_path, dest_path)
elif model_path.is_dir():
# Move the contents of the directory, not the directory itself
for item in model_path.iterdir():
move(item, dest_dir / item.name)
except FileExistsError as e:
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp
f"A model named {model_path.name} is already installed at {dest_dir.as_posix()}"
) from e
return self._register(
new_path,
dest_path,
config,
info,
)
@@ -364,9 +369,18 @@ class ModelInstallService(ModelInstallServiceBase):
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
model_path = self.app_config.models_path / model.path
# Models are stored in a directory named by their key. To delete the model on disk, we delete the entire
# directory. However, the path we store in the model record may be either a file within the key directory,
# or the directory itself. So we have to handle both cases.
if model_path.is_file() or model_path.is_symlink():
model_path.unlink()
# Sanity check - file models should be in their own directory under the models dir. The parent of the
# file should be the model's directory, not the Invoke models dir!
assert model_path.parent != self.app_config.models_path
rmtree(model_path.parent)
elif model_path.is_dir():
# Sanity check - folder models should be in their own directory under the models dir. The path should
# not be the Invoke models dir itself!
assert model_path != self.app_config.models_path
rmtree(model_path)
self.unregister(key)
@@ -526,7 +540,7 @@ class ModelInstallService(ModelInstallServiceBase):
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
):
install_job.set_error(
InvalidModelConfigException(
ValueError(
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
@@ -589,66 +603,25 @@ class ModelInstallService(ModelInstallServiceBase):
found_models = search.search(self._app_config.models_path)
self._logger.info(f"{len(found_models)} new models registered")
def sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
old_path = self.app_config.models_path / model.path
if not old_path.is_relative_to(models_dir):
# The model is not in the models directory - we don't need to move it.
return model
new_path = models_dir / model.base.value / model.type.value / old_path.name
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
return model
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.path = new_path.relative_to(models_dir).as_posix()
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
if new_path.exists():
raise FileExistsError(f"Cannot move {old_path} to {new_path}: destination already exists")
new_path.parent.mkdir(parents=True, exist_ok=True)
move(old_path, new_path)
return new_path
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
config = config or ModelRecordChanges()
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
result = ModelConfigFactory.from_model_on_disk(
mod=model_path,
override_fields=deepcopy(fields),
hash_algo=hash_algo,
allow_unknown=self.app_config.allow_unknown_models,
)
if result.config is None:
self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}")
raise InvalidModelConfigException(f"Could not identify model for {model_path}")
elif isinstance(result.config, Unknown_Config):
self._logger.error(f"Could not identify model for {model_path}, detailed results: {result.details}")
return result.config
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
@@ -669,7 +642,7 @@ class ModelInstallService(ModelInstallServiceBase):
info.path = model_path.as_posix()
if isinstance(info, CheckpointConfigBase):
if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None:
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
# invoke-managed legacy config dir, we use a relative path.
legacy_config_path = self.app_config.legacy_conf_path / info.config_path

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType

View File

@@ -11,7 +11,7 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,

View File

@@ -1,12 +1,10 @@
"""Initialization file for model manager service."""
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.load import LoadedModel
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
"AnyModelConfig",
"LoadedModel",
]

View File

@@ -12,15 +12,14 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ControlAdapterDefaultSettings,
LoraModelDefaultSettings,
MainModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.lora import LoraModelDefaultSettings
from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
FluxVariantType,
ModelFormat,
ModelSourceType,
ModelType,
@@ -90,7 +89,9 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
description="The variant of the model.", default=None
)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)
@@ -126,12 +127,14 @@ class ModelRecordServiceBase(ABC):
pass
@abstractmethod
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated.
:param changes: A set of changes to apply to this model. Changes are validated before being written.
:param allow_class_change: If True, allows changes that would change the model config class. For example,
changing a LoRA into a Main model. This does not disable validation, so the changes must still be valid.
"""
pass

View File

@@ -58,10 +58,7 @@ from invokeai.app.services.model_records.model_records_base import (
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigFactory,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
@@ -137,15 +134,36 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
if allow_class_change:
# The changes may cause the model config class to change. To handle this, we need to construct the new
# class from scratch rather than trying to modify the existing instance in place.
#
# 1. Convert the existing record to a dict
# 2. Apply the changes to the dict
# 3. Attempt to create a new model config from the updated dict
json_serialized = record.model_dump_json()
# 1. Convert the existing record to a dict
record_as_dict = record.model_dump()
# 2. Apply the changes to the dict
for field_name in changes.model_fields_set:
record_as_dict[field_name] = getattr(changes, field_name)
# 3. Attempt to create a new model config from the updated dict
record = ModelConfigFactory.from_dict(record_as_dict)
# If we get this far, the updated model config is valid, so we can save it to the database.
json_serialized = record.model_dump_json()
else:
# We are not allowing the model config class to change, so we can just update the existing instance in
# place. If the changes are invalid for the existing class, an exception will be raised by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
cursor.execute(
"""--sql
@@ -172,7 +190,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
SELECT config FROM models
WHERE id=?;
""",
(key,),
@@ -180,14 +198,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
model = ModelConfigFactory.from_dict(json.loads(rows[0]))
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
SELECT config FROM models
WHERE hash=?;
""",
(hash,),
@@ -195,7 +213,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
model = ModelConfigFactory.from_dict(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
@@ -263,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
SELECT config
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
@@ -276,15 +294,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
results: list[AnyModelConfig] = []
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
model_config = ModelConfigFactory.from_dict(json.loads(row[0]))
except pydantic.ValidationError as e:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to roll back from a
# newer version of the app that added a new model type.
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
try:
name = json.loads(row[0]).get("name", "<unknown>")
except Exception:
name = "<unknown>"
self._logger.warning(
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
f"Skipping invalid model config in the database with name {name}. Ignoring this model. ({row_data})"
)
self._logger.warning(f"Validation error: {e}")
else:
results.append(model_config)
@@ -295,12 +318,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
SELECT config FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
@@ -308,12 +331,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
SELECT config FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
results = [ModelConfigFactory.from_dict(json.loads(x[0])) for x in cursor.fetchall()]
return results
def list_models(
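A note on the read path above: it now re-validates the stored JSON instead of threading `updated_at` through as a conversion timestamp. A minimal sketch of the round trip, assuming `ModelConfigFactory.from_dict` simply delegates the raw dict to the `AnyModelConfig` union validator (its body is not part of this hunk, and `record` stands in for any config instance):

import json

# Serialize on write, as update_model does above with model_dump_json()...
json_serialized = record.model_dump_json()
# ...then re-validate on read. The old make_config(..., timestamp=...) call only
# existed to backfill converted_at; from_dict is assumed to take the raw dict alone.
restored = ModelConfigFactory.from_dict(json.loads(json_serialized))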

View File

@@ -1,6 +1,6 @@
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
class ModelRelationshipsService(ModelRelationshipsServiceABC):

View File

@@ -19,10 +19,8 @@ from invokeai.app.services.model_records.model_records_base import UnknownModelE
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.step_callback import diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
)
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -558,7 +556,7 @@ class ModelsInterface(InvocationContextInterface):
The absolute path to the model.
"""
model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
model_path = Path(config_or_path.path) if isinstance(config_or_path, Config_Base) else Path(config_or_path)
if model_path.is_absolute():
return model_path.resolve()

View File

@@ -24,6 +24,9 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_22 import build_migration_22
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -65,6 +68,9 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.register_migration(build_migration_21())
migrator.register_migration(build_migration_22(app_config=config, logger=logger))
migrator.register_migration(build_migration_23(app_config=config, logger=logger))
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
migrator.run_migrations()
return db
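The contract these registrations rely on is simple: callbacks run in strict version order, and the stored schema version advances as each one succeeds. A toy, self-contained sketch of that contract (ToyMigration is an illustrative stand-in, not the real Migration class):

from dataclasses import dataclass
from typing import Callable

@dataclass
class ToyMigration:
    from_version: int
    to_version: int
    callback: Callable[[], None]

def run_migrations(version: int, migrations: list[ToyMigration]) -> int:
    # Apply each migration whose from_version matches the current schema version.
    for m in sorted(migrations, key=lambda m: m.from_version):
        if m.from_version == version:
            m.callback()
            version = m.to_version
    return version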

View File

@@ -0,0 +1,89 @@
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration22Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
self._models_dir = app_config.models_path.resolve()
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._logger.info("Removing UNIQUE(name, base, type) constraint from models table")
# Step 1: Rename the existing models table
cursor.execute("ALTER TABLE models RENAME TO models_old;")
# Step 2: Create the new models table without the UNIQUE(name, base, type) constraint
cursor.execute(
"""--sql
CREATE TABLE models (
id TEXT NOT NULL PRIMARY KEY,
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
file_size INTEGER GENERATED ALWAYS as (json_extract(config, '$.file_size')) VIRTUAL NOT NULL,
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Explicit unique constraint on path
UNIQUE(path)
);
"""
)
# Step 3: Copy all data from the old table to the new table
# Only copy the stored columns (id, config, created_at, updated_at), not the virtual columns
cursor.execute(
"INSERT INTO models (id, config, created_at, updated_at) "
"SELECT id, config, created_at, updated_at FROM models_old;"
)
# Step 4: Drop the old table
cursor.execute("DROP TABLE models_old;")
# Step 5: Recreate indexes
cursor.execute("CREATE INDEX IF NOT EXISTS base_index ON models(base);")
cursor.execute("CREATE INDEX IF NOT EXISTS type_index ON models(type);")
cursor.execute("CREATE INDEX IF NOT EXISTS name_index ON models(name);")
# Step 6: Recreate the updated_at trigger
cursor.execute(
"""--sql
CREATE TRIGGER models_updated_at
AFTER UPDATE
ON models FOR EACH ROW
BEGIN
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""Builds the migration object for migrating from version 21 to version 22.
This migration:
- Removes the UNIQUE constraint on the combination of (base, name, type) columns in the models table
- Adds an explicit UNIQUE constraint on the path column
"""
return Migration(
from_version=21,
to_version=22,
callback=Migration22Callback(app_config=app_config, logger=logger),
)
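The rebuilt table leans on SQLite generated columns: every queryable column except `config` itself is derived from the JSON, so the new UNIQUE(path) constraint tracks the serialized config automatically. A standalone sketch of that behavior with a toy two-column schema (generated columns require SQLite >= 3.31):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE models (
        id TEXT NOT NULL PRIMARY KEY,
        path TEXT GENERATED ALWAYS AS (json_extract(config, '$.path')) VIRTUAL NOT NULL,
        config TEXT NOT NULL,
        UNIQUE(path)
    );
    """
)
conn.execute("INSERT INTO models (id, config) VALUES (?, ?);", ("a", '{"path": "a/model.safetensors"}'))
print(conn.execute("SELECT path FROM models;").fetchone())  # ('a/model.safetensors',)
# Inserting a second row whose config has the same $.path raises
# sqlite3.IntegrityError: UNIQUE constraint failed: models.path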

View File

@@ -0,0 +1,193 @@
import json
import sqlite3
from copy import deepcopy
from logging import Logger
from typing import Any
from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, AnyModelConfigValidator
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
FluxVariantType,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
class Migration23Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
self._models_dir = app_config.models_path.resolve()
def __call__(self, cursor: sqlite3.Cursor) -> None:
# Grab all model records
cursor.execute("SELECT id, config FROM models;")
rows = cursor.fetchall()
migrated_count = 0
fallback_count = 0
for model_id, config_json in rows:
try:
# Migrate the config JSON to the latest schema
config_dict: dict[str, Any] = json.loads(config_json)
migrated_config = self._parse_and_migrate_config(config_dict)
if isinstance(migrated_config, Unknown_Config):
fallback_count += 1
else:
migrated_count += 1
# Write the migrated config back to the database
cursor.execute(
"UPDATE models SET config = ? WHERE id = ?;",
(migrated_config.model_dump_json(), model_id),
)
except ValidationError as e:
self._logger.error("Invalid config schema for model %s: %s", model_id, e)
raise
except json.JSONDecodeError as e:
self._logger.error("Invalid config JSON for model %s: %s", model_id, e)
raise
if migrated_count > 0 and fallback_count == 0:
self._logger.info(f"Migration complete: {migrated_count} model configs migrated")
elif migrated_count > 0 and fallback_count > 0:
self._logger.warning(
f"Migration complete: {migrated_count} model configs migrated, "
f"{fallback_count} model configs could not be migrated and were saved as unknown models",
)
elif migrated_count == 0 and fallback_count > 0:
self._logger.warning(
f"Migration complete: all {fallback_count} model configs could not be migrated and were saved as unknown models",
)
else:
self._logger.info("Migration complete: no model configs needed migration")
def _parse_and_migrate_config(self, config_dict: dict[str, Any]) -> AnyModelConfig:
# In v6.9.0 we made some improvements to the model taxonomy and the model config schemas. There are a few changes
# we need to make to old configs to bring them up to date.
type = config_dict.get("type")
format = config_dict.get("format")
base = config_dict.get("base")
if base == BaseModelType.Flux.value and type == ModelType.Main.value:
# Prior to v6.9.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
# variants.
#
# `config_path` was set to one of:
# - flux-dev
# - flux-dev-fill
# - flux-schnell
#
# `variant` was set to ModelVariantType.Inpaint for FLUX Fill models and ModelVariantType.Normal for all other FLUX
# models.
#
# We now use the `variant` field to directly represent the FLUX variant type, and `config_path` is no longer used.
# Extract and remove `config_path` if present.
config_path = config_dict.pop("config_path", None)
match config_path:
case "flux-dev":
config_dict["variant"] = FluxVariantType.Dev.value
case "flux-dev-fill":
config_dict["variant"] = FluxVariantType.DevFill.value
case "flux-schnell":
config_dict["variant"] = FluxVariantType.Schnell.value
case _:
# Unknown config_path - default to Dev variant
config_dict["variant"] = FluxVariantType.Dev.value
if (
base
in {
BaseModelType.StableDiffusion1.value,
BaseModelType.StableDiffusion2.value,
BaseModelType.StableDiffusionXL.value,
BaseModelType.StableDiffusionXLRefiner.value,
}
and type == ModelType.Main.value
):
# Prior to v6.9.0, the prediction_type field was optional and would default to Epsilon if not present.
# We now make it explicit and always present. Use the existing value if present, otherwise default to
# Epsilon, matching the probe logic.
#
# It's only on SD1.x, SD2.x, and SDXL main models.
config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value)
# Prior to v6.9.0, the variant field was optional and would default to Normal if not present.
# We now make it explicit and always present. Use the existing value if present, otherwise default to
# Normal. It's only on SD main models.
config_dict["variant"] = config_dict.get("variant", ModelVariantType.Normal.value)
if base == BaseModelType.Flux.value and type == ModelType.LoRA.value and format == ModelFormat.Diffusers.value:
# Prior to v6.9.0, we used the Diffusers format for FLUX LoRA models that used the diffusers _key_
# structure. This was misleading, as everywhere else in the application, we used the Diffusers format
# to indicate that the model files were in the Diffusers _file_ format (i.e. a directory containing
# the weights and config files).
#
# At runtime, we check the LoRA's state dict directly to determine the key structure, so we do not need
# to rely on the format field for this purpose. As of v6.9.0, we always use the LyCORIS format for single-
# file LoRAs, regardless of the key structure.
#
# This change allows LoRA model identification to not need a special case for FLUX LoRAs in the diffusers
# key format.
config_dict["format"] = ModelFormat.LyCORIS.value
if type == ModelType.CLIPVision.value:
# Prior to v6.9.0, some CLIP Vision models were associated with a specific base model architecture:
# - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL
# - CLIP-ViT-H-14-laion2B-s32B-b79K is the image encoder for SD1.5 IP Adapter and was associated with SD1.5
#
# While this made some sense at the time, it is more correct and flexible to treat CLIP Vision models
# as independent of any specific base model architecture.
config_dict["base"] = BaseModelType.Any.value
if type == ModelType.CLIPEmbed.value:
# Prior to v6.9.0, some CLIP Embed models did not have a variant set. The default was the L variant.
# We now make it explicit and always present. Use the existing value if present, otherwise default to
# L variant. Also, treat CLIP Embed models as independent of any specific base model architecture.
config_dict["base"] = BaseModelType.Any.value
config_dict["variant"] = config_dict.get("variant", ClipVariantType.L.value)
try:
migrated_config = AnyModelConfigValidator.validate_python(config_dict)
# This could be a ValidationError or any other error that occurs during validation. A failure to generate a
# union discriminator could raise a ValueError, for example. Who knows what else could fail - catch all.
except Exception as e:
self._logger.error("Failed to validate migrated config, attempting to save as unknown model: %s", e)
cloned_config_dict = deepcopy(config_dict)
cloned_config_dict.pop("base", None)
cloned_config_dict.pop("type", None)
cloned_config_dict.pop("format", None)
migrated_config = Unknown_Config(
**cloned_config_dict,
base=BaseModelType.Unknown,
type=ModelType.Unknown,
format=ModelFormat.Unknown,
)
return migrated_config
def build_migration_23(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""Builds the migration object for migrating from version 22 to version 23.
This migration updates model configurations to the latest config schemas for v6.9.0.
"""
return Migration(
from_version=22,
to_version=23,
callback=Migration23Callback(app_config=app_config, logger=logger),
)
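The fallback above boils down to: try the discriminated union, and on any failure pin the identity fields to "unknown" and keep the record. A standalone sketch of the pattern with pydantic (toy models, not the real config classes):

from typing import Any, Literal
from pydantic import BaseModel, TypeAdapter

class ToyMainConfig(BaseModel):
    type: Literal["main"]
    name: str

class ToyUnknownConfig(BaseModel):
    type: Literal["unknown"] = "unknown"
    name: str = "<unknown>"

adapter = TypeAdapter(ToyMainConfig)

def parse_with_fallback(config_dict: dict[str, Any]) -> ToyMainConfig | ToyUnknownConfig:
    try:
        return adapter.validate_python(config_dict)
    except Exception:
        # Mirror the migration's fallback: drop the discriminating field and keep
        # the record as an unknown model instead of failing the whole migration.
        cleaned = {k: v for k, v in config_dict.items() if k != "type"}
        return ToyUnknownConfig(**cleaned)

print(parse_with_fallback({"type": "main", "name": "foo"}))     # ToyMainConfig
print(parse_with_fallback({"type": "mystery", "name": "foo"}))  # ToyUnknownConfig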

View File

@@ -0,0 +1,240 @@
import json
import sqlite3
from logging import Logger
from pathlib import Path
from typing import NamedTuple
from pydantic import ValidationError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator
class NormalizeResult(NamedTuple):
new_relative_path: str | None
rollback_ops: list[tuple[Path, Path]]
class Migration24Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
self._models_dir = app_config.models_path.resolve()
def __call__(self, cursor: sqlite3.Cursor) -> None:
# Grab all model records
cursor.execute("SELECT id, config FROM models;")
rows = cursor.fetchall()
for model_id, config_json in rows:
try:
config = AnyModelConfigValidator.validate_json(config_json)
except ValidationError:
# This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
# for users, more likely for devs testing out migration paths.
self._logger.warning("Skipping model %s: invalid config schema", model_id)
continue
except json.JSONDecodeError:
# This should never happen, as we use pydantic to serialize the config to JSON.
self._logger.warning("Skipping model %s: invalid config JSON", model_id)
continue
# We'll use a savepoint so we can roll back the database update if something goes wrong, and a simple
# rollback of file operations if needed.
cursor.execute("SAVEPOINT migrate_model")
try:
new_relative_path, rollback_ops = self._normalize_model_storage(
key=config.key,
path_value=config.path,
)
except Exception as err:
self._logger.error("Error normalizing model %s: %s", config.key, err)
cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
cursor.execute("RELEASE SAVEPOINT migrate_model")
continue
if new_relative_path is None:
cursor.execute("RELEASE SAVEPOINT migrate_model")
continue
config.path = new_relative_path
try:
cursor.execute(
"UPDATE models SET config = ? WHERE id = ?;",
(config.model_dump_json(), model_id),
)
except Exception as err:
self._logger.error("Database update failed for model %s: %s", config.key, err)
cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
cursor.execute("RELEASE SAVEPOINT migrate_model")
self._rollback_file_ops(rollback_ops)
continue
cursor.execute("RELEASE SAVEPOINT migrate_model")
self._prune_empty_directories()
def _normalize_model_storage(self, key: str, path_value: str) -> NormalizeResult:
models_dir = self._models_dir
stored_path = Path(path_value)
relative_path: Path | None
if stored_path.is_absolute():
# If the stored path is absolute, we need to check if it's inside the models directory, which means it is
# an Invoke-managed model. If it's outside, it is user-managed and we leave it alone.
try:
relative_path = stored_path.resolve().relative_to(models_dir)
except ValueError:
self._logger.info("Leaving user-managed model %s at %s", key, stored_path)
return NormalizeResult(new_relative_path=None, rollback_ops=[])
else:
# Relative paths are always relative to the models directory and thus Invoke-managed.
relative_path = stored_path
# If the relative path is empty, assume something is wrong. Warn and skip.
if not relative_path.parts:
self._logger.warning("Skipping model %s: empty relative path", key)
return NormalizeResult(new_relative_path=None, rollback_ops=[])
# Sanity check: the path is now known to be relative, so it should exist inside the models directory.
absolute_path = (models_dir / relative_path).resolve()
if not absolute_path.exists():
self._logger.warning(
"Skipping model %s: expected model files at %s but nothing was found",
key,
absolute_path,
)
return NormalizeResult(new_relative_path=None, rollback_ops=[])
if relative_path.parts[0] == key:
# Already normalized. Still ensure the stored path is relative.
normalized_path = relative_path.as_posix()
# If the stored path is already the normalized path, no change is needed.
new_relative_path = normalized_path if stored_path.as_posix() != normalized_path else None
return NormalizeResult(new_relative_path=new_relative_path, rollback_ops=[])
# We'll store the file operations we perform so we can roll them back if needed.
rollback_ops: list[tuple[Path, Path]] = []
# Destination directory is models_dir/<key> - a flat directory structure.
destination_dir = models_dir / key
try:
if absolute_path.is_file():
destination_dir.mkdir(parents=True, exist_ok=True)
dest_file = destination_dir / absolute_path.name
# This really shouldn't happen.
if dest_file.exists():
self._logger.warning(
"Destination for model %s already exists at %s; skipping move",
key,
dest_file,
)
return NormalizeResult(new_relative_path=None, rollback_ops=[])
self._logger.info("Moving model file %s -> %s", absolute_path, dest_file)
# `Path.rename()` effectively moves the file or directory.
absolute_path.rename(dest_file)
rollback_ops.append((dest_file, absolute_path))
return NormalizeResult(
new_relative_path=(Path(key) / dest_file.name).as_posix(),
rollback_ops=rollback_ops,
)
if absolute_path.is_dir():
dest_path = destination_dir
# This really shouldn't happen.
if dest_path.exists():
self._logger.warning(
"Destination directory %s already exists for model %s; skipping",
dest_path,
key,
)
return NormalizeResult(new_relative_path=None, rollback_ops=[])
self._logger.info("Moving model directory %s -> %s", absolute_path, dest_path)
# `Path.rename()` effectively moves the file or directory.
absolute_path.rename(dest_path)
rollback_ops.append((dest_path, absolute_path))
return NormalizeResult(
new_relative_path=Path(key).as_posix(),
rollback_ops=rollback_ops,
)
# Maybe a broken symlink or something else weird?
self._logger.warning("Skipping model %s: path %s is neither a file nor directory", key, absolute_path)
return NormalizeResult(new_relative_path=None, rollback_ops=[])
except Exception:
self._rollback_file_ops(rollback_ops)
raise
def _rollback_file_ops(self, rollback_ops: list[tuple[Path, Path]]) -> None:
# This is a super-simple rollback that just reverses the move operations we performed.
for source, destination in reversed(rollback_ops):
try:
if source.exists():
source.rename(destination)
except Exception as err:
self._logger.error("Failed to rollback move %s -> %s: %s", source, destination, err)
def _prune_empty_directories(self) -> None:
# These directories are system directories we want to keep even if empty. Technically, the app should not
# have any problems if they were removed, since it recreates them as needed, but it's cleaner to leave them alone.
keep_names = {"model_images", ".download_cache"}
keep_dirs = {self._models_dir / name for name in keep_names}
removed_dirs: set[Path] = set()
# Walk the models directory tree from the bottom up, removing empty directories. We sort by path depth
# (number of parts) descending to ensure we visit children before parents.
for directory in sorted(self._models_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True):
if not directory.is_dir():
continue
if directory == self._models_dir:
continue
if any(directory == keep or keep in directory.parents for keep in keep_dirs):
continue
try:
next(directory.iterdir())
except StopIteration:
try:
directory.rmdir()
removed_dirs.add(directory)
self._logger.debug("Removed empty directory %s", directory)
except OSError:
# Directory not empty (or some other error) - skip it and move on.
self._logger.warning("Failed to prune directory %s - not empty?", directory)
continue
except OSError:
continue
self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir)
def build_migration_24(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""Builds the migration object for migrating from version 23 to version 24.
This migration normalizes on-disk model storage so that each model lives within
a directory named by its key inside the Invoke-managed models directory, and
updates database records to reference the new relative paths.
This migration behaves a bit differently than others. Because it involves FS operations, if we rolled the
DB back on any failure, we could leave the FS out of sync with the DB. Instead, we use savepoints
to roll back individual model updates on failure, and we roll back any FS operations we performed
for that model.
If a model cannot be migrated for any reason (invalid config, missing files, FS errors, DB errors), we log a
warning and skip it, leaving it in its original state and location. The model will still work, but it will be in
the "wrong" location on disk.
"""
return Migration(
from_version=23,
to_version=24,
callback=Migration24Callback(app_config=app_config, logger=logger),
)
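The savepoint dance above generalizes to any per-row migration running inside one outer transaction: roll back only the failing row, release either way. A stripped-down sketch of just the database half (transform is a hypothetical stand-in for the per-model work):

import sqlite3

def transform(config_json: str) -> str:
    # Hypothetical stand-in for the real per-model migration work.
    return config_json

def migrate_rows(cursor: sqlite3.Cursor) -> None:
    cursor.execute("SELECT id, config FROM models;")
    for model_id, config_json in cursor.fetchall():
        cursor.execute("SAVEPOINT migrate_model")
        try:
            cursor.execute(
                "UPDATE models SET config = ? WHERE id = ?;",
                (transform(config_json), model_id),
            )
        except Exception:
            # Undo only this row's changes; rows migrated earlier are untouched.
            cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
        finally:
            # ROLLBACK TO keeps the savepoint alive, so RELEASE is safe either way.
            cursor.execute("RELEASE SAVEPOINT migrate_model")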

View File

@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@@ -115,6 +116,13 @@ def get_openapi_func(
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
any_model_config_schema = AnyModelConfigValidator.json_schema(
mode="serialization",
ref_template="#/components/schemas/{model}",
)
move_defs_to_top_level(openapi_schema, any_model_config_schema)
openapi_schema["components"]["schemas"]["AnyModelConfig"] = any_model_config_schema
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)

View File

@@ -5,7 +5,7 @@ import torch
from invokeai.backend.flux.model import FluxParams
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
def is_state_dict_xlabs_controlnet(sd: dict[str | int, Any]) -> bool:
"""Is the state dict for an XLabs ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
@@ -25,7 +25,7 @@ def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
return False
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
def is_state_dict_instantx_controlnet(sd: dict[str | int, Any]) -> bool:
"""Is the state dict for an InstantX ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.

View File

@@ -1,10 +1,7 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from invokeai.backend.model_manager.legacy_probe import CkptType
from typing import Any
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
def get_flux_in_channels_from_state_dict(state_dict: dict[str | int, Any]) -> int | None:
"""Gets the in channels from the state dict."""
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use

View File

@@ -1,11 +1,11 @@
from typing import Any, Dict
from typing import Any
import torch
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams
def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
def is_state_dict_xlabs_ip_adapter(sd: dict[str | int, Any]) -> bool:
"""Is the state dict for an XLabs FLUX IP-Adapter model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
@@ -27,7 +27,7 @@ def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool:
return False
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str | int, torch.Tensor]) -> XlabsIpAdapterParams:
num_double_blocks = 0
context_dim = 0
hidden_dim = 0

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict
from typing import Any
def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_flux_redux(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely a FLUX Redux model."""
expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"}
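The body of this check is elided by the hunk, but the detector family shares one shape: compare the state dict's keys against a small signature. A sketch under that reading (toy function, not necessarily the diff's exact implementation):

from typing import Any

def toy_is_flux_redux(state_dict: dict[str | int, Any]) -> bool:
    expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"}
    # Reasonably high precision, not guaranteed perfect: an exact key-set match
    # on a tiny signature.
    return set(state_dict.keys()) == expected_keys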

View File

@@ -1,10 +1,11 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Dict, Literal
from typing import Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
@dataclass
@@ -41,30 +42,39 @@ PREFERED_KONTEXT_RESOLUTIONS = [
]
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,
"flux-schnell": 256,
_flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
FluxVariantType.Dev: 512,
FluxVariantType.DevFill: 512,
FluxVariantType.Schnell: 256,
}
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
}
def get_flux_max_seq_length(variant: AnyVariant) -> Literal[256, 512]:
try:
return _flux_max_seq_lengths[variant]
except KeyError as e:
raise ValueError(f"Unknown variant for FLUX max seq len: {variant}") from e
params = {
"flux-dev": FluxParams(
_flux_ae_params = AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
def get_flux_ae_params() -> AutoEncoderParams:
return _flux_ae_params
_flux_transformer_params: dict[AnyVariant, FluxParams] = {
FluxVariantType.Dev: FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
@@ -78,7 +88,7 @@ params = {
qkv_bias=True,
guidance_embed=True,
),
"flux-schnell": FluxParams(
FluxVariantType.Schnell: FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
@@ -92,7 +102,7 @@ params = {
qkv_bias=True,
guidance_embed=False,
),
"flux-dev-fill": FluxParams(
FluxVariantType.DevFill: FluxParams(
in_channels=384,
out_channels=64,
vec_in_dim=768,
@@ -108,3 +118,10 @@ params = {
guidance_embed=True,
),
}
def get_flux_transformers_params(variant: AnyVariant) -> FluxParams:
try:
return _flux_transformer_params[variant]
except KeyError as e:
raise ValueError(f"Unknown variant for FLUX transformer params: {variant}") from e

View File

@@ -1,45 +0,0 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from invokeai.backend.model_manager.config import (
AnyModelConfig,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
AnyVariant,
BaseModelType,
ClipVariantType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
__all__ = [
"AnyModelConfig",
"InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory",
"ModelProbe",
"ModelSearch",
"ModelConfigBase",
"AnyModel",
"AnyVariant",
"BaseModelType",
"ClipVariantType",
"ModelFormat",
"ModelRepoVariant",
"ModelSourceType",
"ModelType",
"ModelVariantType",
"SchedulerPredictionType",
"SubModelType",
]

View File

@@ -1,770 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import time
from abc import ABC, abstractmethod
from enum import Enum
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ClipVariantType,
FluxLoRAFormat,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
logger = logging.getLogger(__name__)
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognize this combination of model type and format."""
pass
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant = None
model_config = ConfigDict(protected_namespaces=())
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
cfg_rescale_multiplier: float | None = Field(
default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
model_config = ConfigDict(extra="forbid")
class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
model_config = ConfigDict(extra="forbid")
class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None
model_config = ConfigDict(extra="forbid")
class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
FAST = 0
MED = 1
SLOW = 2
class ModelConfigBase(ABC, BaseModel):
"""
Abstract Base class for model configurations.
To create a new config type, inherit from this class and implement its interface:
- (mandatory) override methods 'matches' and 'parse'
- (mandatory) define fields 'type' and 'format' as class attributes
- (optional) override method 'get_tag'
- (optional) override field _MATCH_SPEED
See MinimalConfigExample in test_model_probe.py for an example implementation.
"""
@staticmethod
def json_schema_extra(schema: dict[str, Any]) -> None:
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).")
path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
)
file_size: int = Field(description="The size of the model in bytes.")
name: str = Field(description="Name of the model.")
type: ModelType = Field(description="Model type")
format: ModelFormat = Field(description="Model format")
base: BaseModelType = Field(description="The base model.")
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
description: Optional[str] = Field(description="Model description", default=None)
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
USING_LEGACY_PROBE: ClassVar[set] = set()
USING_CLASSIFY_API: ClassVar[set] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
else:
ModelConfigBase.USING_CLASSIFY_API.add(cls)
@staticmethod
def all_config_classes():
subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
concrete = {cls for cls in subclasses if not isabstract(cls)}
return concrete
@staticmethod
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
if isinstance(mod, Path | str):
mod = ModelOnDisk(mod, hash_algo)
candidates = ModelConfigBase.USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
for config_cls in sorted_by_match_speed:
try:
if not config_cls.matches(mod):
continue
except Exception as e:
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:
type = cls.model_fields["type"].default.value
format = cls.model_fields["format"].default.value
return Tag(f"{type}.{format}")
@classmethod
@abstractmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
"""Returns a dictionary with the fields needed to construct the model.
Raises InvalidModelConfigException if the model is invalid.
"""
pass
@classmethod
@abstractmethod
def matches(cls, mod: ModelOnDisk) -> bool:
"""Performs a quick check to determine if the config matches the model.
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass
@staticmethod
def cast_overrides(overrides: dict[str, Any]):
"""Casts user overrides from str to Enum"""
if "type" in overrides:
overrides["type"] = ModelType(overrides["type"])
if "format" in overrides:
overrides["format"] = ModelFormat(overrides["format"])
if "base" in overrides:
overrides["base"] = BaseModelType(overrides["base"])
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])
if "variant" in overrides:
overrides["variant"] = ModelVariantType(overrides["variant"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
type = fields.get("type") or cls.model_fields["type"].default
base = fields.get("base") or cls.model_fields["base"].default
fields["path"] = mod.path.as_posix()
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = name = fields.get("name") or mod.name
fields["hash"] = fields.get("hash") or mod.hash()
fields["key"] = fields.get("key") or uuid_string()
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
fields["file_size"] = fields.get("file_size") or mod.size()
return cls(**fields)
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
class CheckpointConfigBase(ABC, BaseModel):
"""Base class for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
)
class DiffusersConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfigBase(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[LoraModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd, mod.metadata())
mod.cache[key] = value
return value
@classmethod
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
if cls.flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException("Unknown LoRA type")
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
metadata = mod.metadata()
return (
metadata.get("modelspec.sai_model_spec")
and metadata.get("ot_branch") == "omi_format"
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
if architecture == stable_diffusion_xl_1_lora:
base = BaseModelType.StableDiffusionXL
elif architecture == flux_dev_1_lora:
base = BaseModelType.Flux
else:
raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
return {"base": base}
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return False
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if isinstance(key, int):
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return True
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return True
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
return any(wf.exists() for wf in weight_files)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
class MainConfigBase(ABC, BaseModel):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: AnyVariant = ModelVariantType.Normal
class VideoConfigBase(ABC, BaseModel):
type: Literal[ModelType.Video] = ModelType.Video
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main diffusers models."""
pass
class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for IP Adapter diffusers format models."""
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
variant: ClipVariantType = Field(description="Clip variant for this model")
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
variant: Literal[ClipVariantType.G] = ClipVariantType.G
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
variant: Literal[ClipVariantType.L] = ClipVariantType.L
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return False
config_path = mod.path / "config.json"
try:
with open(config_path, "r") as file:
config = json.load(file)
except FileNotFoundError:
return False
architectures = config.get("architectures")
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": BaseModelType.Any,
"variant": ModelVariantType.Normal,
}
class ApiModelConfig(MainConfigBase, ModelConfigBase):
"""Model config for API-based models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
# API models are not stored on disk, so we can't match them.
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
class VideoApiModelConfig(VideoConfigBase, ModelConfigBase):
"""Model config for API-based video models."""
format: Literal[ModelFormat.Api] = ModelFormat.Api
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
# API models are not stored on disk, so we can't match them.
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
raise NotImplementedError("API models are not parsed from disk.")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
format_ = type_ = variant_ = None
if isinstance(v, dict):
format_ = v.get("format")
if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
variant_ = v.get("variant")
if isinstance(variant_, Enum):
variant_ = variant_.value
else:
format_ = v.format.value
type_ = v.type.value
variant_ = getattr(v, "variant", None)
if variant_:
variant_ = variant_.value
# Ideally, each config would be uniquely identified with a combination of fields
# i.e. (type, format, variant) without any special cases. Alas...
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
# To maintain compatibility, we default to ClipVariantType.L in this case.
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
variant_ = variant_ or ClipVariantType.L.value
return f"{type_}.{format_}.{variant_}"
return f"{type_}.{format_}"
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory:
@staticmethod
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
if isinstance(model, CheckpointConfigBase) and timestamp:
model.converted_at = timestamp
validate_hash(model.hash)
return model # type: ignore

View File

@@ -0,0 +1,245 @@
from abc import ABC, abstractmethod
from enum import Enum
from inspect import isabstract
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
Self,
Type,
)
from pydantic import BaseModel, ConfigDict, Field, Tag
from pydantic_core import PydanticUndefined
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
ModelType,
)
if TYPE_CHECKING:
pass
class Config_Base(ABC, BaseModel):
"""
Abstract base class for model configurations. A model config describes a specific combination of model base, type and
format, along with other metadata about the model. For example, a Stable Diffusion 1.x main model in checkpoint format
would have base=sd-1, type=main, format=checkpoint.
To create a new config type, inherit from this class and implement its interface:
- Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatchError. This method will be
called during model installation to determine the correct config class for a model.
- Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A
default must be provided for each of these fields.
If multiple combinations of base, type and format need to be supported, create a separate subclass for each.
See MinimalConfigExample in test_model_probe.py for an example implementation.
"""
# These fields are common to all model configs.
key: str = Field(
default_factory=uuid_string,
description="A unique key for this model.",
)
hash: str = Field(
description="The hash of the model file(s).",
)
path: str = Field(
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.",
)
file_size: int = Field(
description="The size of the model in bytes.",
)
name: str = Field(
description="Name of the model.",
)
description: str | None = Field(
default=None,
description="Model description",
)
source: str = Field(
description="The original source of the model (path, URL or repo_id).",
)
source_type: ModelSourceType = Field(
description="The type of source",
)
source_api_response: str | None = Field(
default=None,
description="The original API response from the source, as stringified JSON.",
)
cover_image: str | None = Field(
default=None,
description="Url for image to preview model",
)
usage_info: str | None = Field(
default=None,
description="Usage information for this model",
)
CONFIG_CLASSES: ClassVar[set[Type["Config_Base"]]] = set()
"""Set of all non-abstract subclasses of Config_Base, for use during model probing. In other words, this is the set
of all known model config types."""
model_config = ConfigDict(
validate_assignment=True,
json_schema_serialization_defaults_required=True,
json_schema_mode_override="serialization",
)
@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# Register non-abstract subclasses so we can iterate over them later during model probing. Note that
# isabstract() will return False if the class does not have any abstract methods, even if it inherits from ABC.
# We must check for ABC lest we unintentionally register some abstract model config classes.
if not isabstract(cls) and ABC not in cls.__bases__:
cls.CONFIG_CLASSES.add(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
# Ensure that model configs define 'base', 'type' and 'format' fields and provide defaults for them. Each
# subclass is expected to represent a single combination of base, type and format.
#
# This pydantic dunder method is called after the pydantic model for a class is created. The normal
# __init_subclass__ is too early to do this check.
for name in ("type", "base", "format"):
if name not in cls.model_fields:
raise NotImplementedError(f"{cls.__name__} must define a '{name}' field")
if cls.model_fields[name].default is PydanticUndefined:
raise NotImplementedError(f"{cls.__name__} must define a default for the '{name}' field")
@classmethod
def get_tag(cls) -> Tag:
"""Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized,
pydantic uses the tag to determine which subclass to instantiate.
The tag is a dot-separated string of the type, format, base and variant (if applicable).
"""
tag_strings: list[str] = []
for name in ("type", "format", "base", "variant"):
if field := cls.model_fields.get(name):
# The check in __pydantic_init_subclass__ ensures that type, format and base are always present with
# defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can
# check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by
# pydantic to indicate that no default was provided.
if field.default is not PydanticUndefined:
# We expect each of these fields has an Enum for its default; we want the value of the enum.
tag_strings.append(field.default.value)
return Tag(".".join(tag_strings))
@staticmethod
def get_model_discriminator_value(v: Any) -> str:
"""Computes the discriminator value for a model config discriminated union."""
# This is called by pydantic during deserialization and serialization to determine which model the data
# represents. It can get either a dict (during deserialization) or an instance of a Config_Base subclass
# (during serialization).
#
# See: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
if isinstance(v, Config_Base):
# We have an instance of a Config_Base subclass - use its tag directly.
return v.get_tag().tag
if isinstance(v, dict):
# We have a dict - attempt to compute a tag from its fields.
tag_strings: list[str] = []
if type_ := v.get("type"):
if isinstance(type_, Enum):
type_ = str(type_.value)
elif not isinstance(type_, str):
raise ValueError("Model config dict 'type' field must be a string or Enum")
tag_strings.append(type_)
if format_ := v.get("format"):
if isinstance(format_, Enum):
format_ = str(format_.value)
elif not isinstance(format_, str):
raise ValueError("Model config dict 'format' field must be a string or Enum")
tag_strings.append(format_)
if base_ := v.get("base"):
if isinstance(base_, Enum):
base_ = str(base_.value)
elif not isinstance(base_, str):
raise ValueError("Model config dict 'base' field must be a string or Enum")
tag_strings.append(base_)
# Special case: CLIP Embed models also need the variant to distinguish them.
if (
type_ == ModelType.CLIPEmbed.value
and format_ == ModelFormat.Diffusers.value
and base_ == BaseModelType.Any.value
):
if variant_ := v.get("variant"):
if isinstance(variant_, Enum):
variant_ = variant_.value
elif not isinstance(variant_, str):
raise ValueError("Model config dict 'variant' field must be a string or Enum")
tag_strings.append(variant_)
else:
raise ValueError("CLIP Embed model config dict must include a 'variant' field")
return ".".join(tag_strings)
else:
raise ValueError(
"Model config discriminator value must be computed from a dict or ModelConfigBase instance"
)
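    # For example, the dict {"type": "main", "format": "checkpoint", "base": "sd-1"}
    # produces the discriminator value "main.checkpoint.sd-1".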
@classmethod
@abstractmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
"""Given the model on disk and any override fields, attempt to construct an instance of this config class.
This method serves to identify whether the model on disk matches this config class, and if so, to extract any
additional metadata needed to instantiate the config.
Implementations should raise a NotAMatchError if the model does not match this config class."""
raise NotImplementedError(f"from_model_on_disk not implemented for {cls.__name__}")
class Checkpoint_Config_Base(ABC, BaseModel):
"""Base class for checkpoint-style models."""
config_path: str | None = Field(
description="Path to the config for this model, if any.",
default=None,
)
class Diffusers_Config_Base(ABC, BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
repo_variant: ModelRepoVariant = Field(ModelRepoVariant.Default)
@classmethod
def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant:
# get all files ending in .bin or .safetensors
weight_files = list(mod.path.glob("**/*.safetensors"))
weight_files.extend(list(mod.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant | None = None
model_config = ConfigDict(protected_namespaces=())
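To make the subclassing contract above concrete, here is a minimal sketch of a hypothetical config class (illustration only, not part of the diff). The class name and the file-suffix heuristic are invented; real implementations inspect config files or the state dict, as the modules below do.

from typing import Any, Literal, Self

from pydantic import Field

from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType


class Example_Checkpoint_SD1_Config(Config_Base):
    """Hypothetical config: an SD1 main model in checkpoint format."""

    # Literal type/base/format fields with defaults, as required by __pydantic_init_subclass__.
    type: Literal[ModelType.Main] = Field(default=ModelType.Main)
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        # Invented heuristic for illustration; raise NotAMatchError when the model doesn't fit.
        if mod.path.suffix not in {".ckpt", ".safetensors"}:
            raise NotAMatchError("not a single-file checkpoint")
        # The factory passes the common fields (hash, path, name, ...) via override_fields.
        return cls(**override_fields)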

View File

@@ -0,0 +1,91 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
ModelFormat,
ModelType,
)
def get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
try:
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
except Exception:
return None
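# Quick illustration of the mapping above (toy config dicts, for illustration only):
#   get_clip_variant_type_from_config({"hidden_size": 1280})  -> ClipVariantType.G
#   get_clip_variant_type_from_config({"hidden_size": 768})   -> ClipVariantType.L
#   get_clip_variant_type_from_config({"hidden_size": 512})   -> None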
class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base):
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
{
mod.path / "config.json",
mod.path / "text_encoder" / "config.json",
},
{
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
},
)
cls._validate_variant(mod)
return cls(**override_fields)
@classmethod
def _validate_variant(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model variant does not match this config class."""
expected_variant = cls.model_fields["variant"].default
config = get_config_dict_or_raise(
{
mod.path / "config.json",
mod.path / "text_encoder" / "config.json",
},
)
recognized_variant = get_clip_variant_type_from_config(config)
if recognized_variant is None:
raise NotAMatchError("unable to determine CLIP variant from config")
if expected_variant is not recognized_variant:
raise NotAMatchError(f"variant is {recognized_variant}, not {expected_variant}")
class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G)
class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base):
variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L)

View File

@@ -0,0 +1,57 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
get_class_name_from_config_dict_or_raise,
get_config_dict_or_raise,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for CLIPVision."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
cls.raise_if_config_doesnt_look_like_clip_vision(mod)
return cls(**override_fields)
@classmethod
def raise_if_config_doesnt_look_like_clip_vision(cls, mod: ModelOnDisk) -> None:
config_dict = get_config_dict_or_raise(mod.path / "config.json")
class_name = get_class_name_from_config_dict_or_raise(config_dict)
if class_name == "CLIPVisionModelWithProjection":
looks_like_clip_vision = True
elif class_name == "CLIPModel" and "vision_config" in config_dict:
looks_like_clip_vision = True
else:
looks_like_clip_vision = False
if not looks_like_clip_vision:
raise NotAMatchError(
f"config class name is {class_name}, not CLIPVisionModelWithProjection or CLIPModel with vision_config"
)
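    # For reference, either of these minimal, hypothetical config.json payloads would
    # pass the check above (real configs carry many more fields):
    #   {"_class_name": "CLIPVisionModelWithProjection", ...}
    #   {"_class_name": "CLIPModel", "vision_config": {...}, ...}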

View File

@@ -0,0 +1,230 @@
from typing import (
Literal,
Self,
)
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Any
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart anime": "lineart_anime_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"lineart": "lineart_image_processor",
"soft": "hed_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}
class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to the set of ControlNet processor nodes, but those change over time. Leaving this as a string is safer.
preprocessor: str | None
model_config = ConfigDict(extra="forbid")
@classmethod
def from_model_name(cls, model_name: str) -> Self:
model_name_lower = model_name.lower()
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name_lower:
return cls(preprocessor=v)
return cls(preprocessor=None)
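    # For example (illustrative name): from_model_name("controlnet-canny-sdxl-1.0")
    # returns preprocessor="canny_image_processor"; an unrecognized name yields
    # preprocessor=None.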
class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"ControlNetModel",
"FluxControlNetModel",
},
)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
if config_dict.get("_class_name") == "FluxControlNetModel":
return BaseModelType.Flux
dimension = config_dict.get("cross_attention_dim")
match dimension:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
# No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them
# anyway.
return BaseModelType.StableDiffusion2
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized cross_attention_dim {dimension}")
class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_controlnet(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"controlnet",
"control_model",
"input_blocks",
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
},
):
raise NotAMatchError("state dict does not look like a ControlNet checkpoint")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict):
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
# get_format()?
return BaseModelType.Flux
for key in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
):
if key not in state_dict:
continue
width = state_dict[key].shape[-1]
match width:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 2048:
return BaseModelType.StableDiffusionXL
case 1280:
return BaseModelType.StableDiffusionXL
case _:
pass
raise NotAMatchError("unable to determine base type from state dict")
class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
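As a rough sketch of the checkpoint base detection above (illustration only, not part of the diff), the probe keys off the trailing dimension of a known attention weight. A toy stand-in shows the mapping without loading a real checkpoint:

class _FakeTensor:
    """Stand-in for a tensor; the probe only reads .shape."""

    def __init__(self, width: int) -> None:
        self.shape = (320, width)


# A state dict whose known attention weight is 768 wide would be classified as SD1
# by _get_base_or_raise; 1024 maps to SD2, and 2048 or 1280 map to SDXL.
toy_state_dict = {
    "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": _FakeTensor(768),
}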

View File

@@ -0,0 +1,523 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import (
Union,
)
from pydantic import Discriminator, TypeAdapter, ValidationError
from typing_extensions import Annotated, Any
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_L_Config
from invokeai.backend.model_manager.configs.clip_vision import CLIPVision_Diffusers_Config
from invokeai.backend.model_manager.configs.controlnet import (
ControlAdapterDefaultSettings,
ControlNet_Checkpoint_FLUX_Config,
ControlNet_Checkpoint_SD1_Config,
ControlNet_Checkpoint_SD2_Config,
ControlNet_Checkpoint_SDXL_Config,
ControlNet_Diffusers_FLUX_Config,
ControlNet_Diffusers_SD1_Config,
ControlNet_Diffusers_SD2_Config,
ControlNet_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
from invokeai.backend.model_manager.configs.ip_adapter import (
IPAdapter_Checkpoint_FLUX_Config,
IPAdapter_Checkpoint_SD1_Config,
IPAdapter_Checkpoint_SD2_Config,
IPAdapter_Checkpoint_SDXL_Config,
IPAdapter_InvokeAI_SD1_Config,
IPAdapter_InvokeAI_SD2_Config,
IPAdapter_InvokeAI_SDXL_Config,
)
from invokeai.backend.model_manager.configs.llava_onevision import LlavaOnevision_Diffusers_Config
from invokeai.backend.model_manager.configs.lora import (
ControlLoRA_LyCORIS_FLUX_Config,
LoRA_Diffusers_FLUX_Config,
LoRA_Diffusers_SD1_Config,
LoRA_Diffusers_SD2_Config,
LoRA_Diffusers_SDXL_Config,
LoRA_LyCORIS_FLUX_Config,
LoRA_LyCORIS_SD1_Config,
LoRA_LyCORIS_SD2_Config,
LoRA_LyCORIS_SDXL_Config,
LoRA_OMI_FLUX_Config,
LoRA_OMI_SDXL_Config,
LoraModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.main import (
Main_BnBNF4_FLUX_Config,
Main_Checkpoint_FLUX_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
Main_Diffusers_CogView4_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SD3_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
Main_ExternalAPI_ChatGPT4o_Config,
Main_ExternalAPI_FluxKontext_Config,
Main_ExternalAPI_Gemini2_5_Config,
Main_ExternalAPI_Imagen3_Config,
Main_ExternalAPI_Imagen4_Config,
Main_GGUF_FLUX_Config,
MainModelDefaultSettings,
Video_ExternalAPI_Runway_Config,
Video_ExternalAPI_Veo3_Config,
)
from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config
from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config
from invokeai.backend.model_manager.configs.t2i_adapter import (
T2IAdapter_Diffusers_SD1_Config,
T2IAdapter_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
from invokeai.backend.model_manager.configs.textual_inversion import (
TI_File_SD1_Config,
TI_File_SD2_Config,
TI_File_SDXL_Config,
TI_Folder_SD1_Config,
TI_Folder_SD2_Config,
TI_Folder_SDXL_Config,
)
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.configs.vae import (
VAE_Checkpoint_FLUX_Config,
VAE_Checkpoint_SD1_Config,
VAE_Checkpoint_SD2_Config,
VAE_Checkpoint_SDXL_Config,
VAE_Diffusers_SD1_Config,
VAE_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelSourceType,
ModelType,
variant_type_adapter,
)
logger = logging.getLogger(__name__)
app_config = get_config()
# Known model file extensions for sanity checking
_MODEL_EXTENSIONS = {
".safetensors",
".ckpt",
".pt",
".pth",
".bin",
".gguf",
".onnx",
}
# Known config file names for diffusers/transformers models
_CONFIG_FILES = {
"model_index.json",
"config.json",
}
# Maximum number of files in a directory to be considered a model
_MAX_FILES_IN_MODEL_DIR = 50
# Maximum depth to search for model files in directories
_MAX_SEARCH_DEPTH = 2
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically from Config_Base.CONFIG_CLASSES
AnyModelConfig = Annotated[
Union[
# Main (Pipeline) - diffusers format
Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()],
Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()],
Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
# Main (Pipeline) - checkpoint format
Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
# Main (Pipeline) - quantized formats
Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
# VAE - checkpoint format
Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
# VAE - diffusers format
Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
# ControlNet - checkpoint format
Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()],
Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()],
# ControlNet - diffusers format
Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()],
Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()],
Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()],
Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()],
# LoRA - LyCORIS format
Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()],
Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
# LoRA - OMI format
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
# LoRA - diffusers format
Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()],
Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
# ControlLoRA - diffusers format
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
# T5 Encoder - all formats
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
# TI - file format
Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()],
# TI - folder format
Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()],
Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()],
Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()],
# IP Adapter - InvokeAI format
Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()],
Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()],
# IP Adapter - checkpoint format
Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()],
Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()],
# T2I Adapter - diffusers format
Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()],
Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()],
# Misc models
Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()],
Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()],
Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()],
Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()],
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
# Main - external API
Annotated[Main_ExternalAPI_ChatGPT4o_Config, Main_ExternalAPI_ChatGPT4o_Config.get_tag()],
Annotated[Main_ExternalAPI_Gemini2_5_Config, Main_ExternalAPI_Gemini2_5_Config.get_tag()],
Annotated[Main_ExternalAPI_Imagen3_Config, Main_ExternalAPI_Imagen3_Config.get_tag()],
Annotated[Main_ExternalAPI_Imagen4_Config, Main_ExternalAPI_Imagen4_Config.get_tag()],
Annotated[Main_ExternalAPI_FluxKontext_Config, Main_ExternalAPI_FluxKontext_Config.get_tag()],
# Video - external API
Annotated[Video_ExternalAPI_Veo3_Config, Video_ExternalAPI_Veo3_Config.get_tag()],
Annotated[Video_ExternalAPI_Runway_Config, Video_ExternalAPI_Runway_Config.get_tag()],
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],
Discriminator(Config_Base.get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig)
"""Pydantic TypeAdapter for the AnyModelConfig union, used for parsing and validation.
If you need to parse/validate a dict or JSON into an AnyModelConfig, you should probably use
ModelConfigFactory.from_dict or ModelConfigFactory.from_json instead as they may implement
additional logic in the future.
"""
@dataclass
class ModelClassificationResult:
"""Result of attempting to classify a model on disk into a specific model config.
Attributes:
match: The best matching model config, or None if no match was found.
results: A mapping of model config class names to either an instance of that class (if it matched)
or an Exception (if it didn't match or an error occurred during matching).
"""
config: AnyModelConfig | None
details: dict[str, AnyModelConfig | Exception]
@property
def all_matches(self) -> list[AnyModelConfig]:
"""Returns a list of all matching model configs found."""
return [r for r in self.details.values() if isinstance(r, Config_Base)]
@property
def match_count(self) -> int:
"""Returns the number of matching model configs found."""
return len(self.all_matches)
class ModelConfigFactory:
@staticmethod
def from_dict(fields: dict[str, Any]) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(fields)
return model
@staticmethod
def from_json(json: str | bytes | bytearray) -> AnyModelConfig:
"""Return the appropriate config object from json."""
model = AnyModelConfigValidator.validate_json(json)
return model
@staticmethod
def build_common_fields(
mod: ModelOnDisk,
override_fields: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Builds the common fields for all model configs.
Args:
mod: The model on disk to extract fields from.
override_fields: An optional dictionary of fields to override. These fields take precedence over the
values extracted from the model on disk.
- Casts string fields to their Enum types.
- Does not validate the fields against the model config schema.
"""
_overrides: dict[str, Any] = override_fields or {}
fields: dict[str, Any] = {}
if "type" in _overrides:
fields["type"] = ModelType(_overrides["type"])
if "format" in _overrides:
fields["format"] = ModelFormat(_overrides["format"])
if "base" in _overrides:
fields["base"] = BaseModelType(_overrides["base"])
if "source_type" in _overrides:
fields["source_type"] = ModelSourceType(_overrides["source_type"])
if "variant" in _overrides:
fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"])
fields["path"] = mod.path.as_posix()
fields["source"] = _overrides.get("source") or fields["path"]
fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path
fields["name"] = _overrides.get("name") or mod.name
fields["hash"] = _overrides.get("hash") or mod.hash()
fields["key"] = _overrides.get("key") or uuid_string()
fields["description"] = _overrides.get("description")
fields["file_size"] = _overrides.get("file_size") or mod.size()
return fields
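    # For instance (hypothetical override): build_common_fields(mod, {"name": "My Model"})
    # computes path/hash/key/file_size from disk as usual but uses the supplied name.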
@staticmethod
def _validate_path_looks_like_model(path: Path) -> None:
"""Perform basic sanity checks to ensure a path looks like a model.
This prevents wasting time trying to identify obviously non-model paths like
home directories or downloads folders. Raises ValueError if the path doesn't
pass basic checks.
Args:
path: The path to validate
Raises:
ValueError: If the path doesn't look like a model
"""
if path.is_file():
# For files, just check the extension
if path.suffix.lower() not in _MODEL_EXTENSIONS:
raise ValueError(
f"File extension {path.suffix} is not a recognized model format. "
f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
)
else:
# For directories, do a quick file count check with early exit
total_files = 0
# Ignore hidden files and directories
paths_to_check = (p for p in path.rglob("*") if not p.name.startswith("."))
for item in paths_to_check:
if item.is_file():
total_files += 1
if total_files > _MAX_FILES_IN_MODEL_DIR:
raise ValueError(
f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
"This looks like a general-purpose directory rather than a model. "
"Please provide a path to a specific model file or model directory."
)
# Check if it has config files at root (diffusers/transformers marker)
has_root_config = any((path / config).exists() for config in _CONFIG_FILES)
if has_root_config:
# Has a config file, looks like a valid model directory
return
# Otherwise, search for model files within depth limit
def find_model_files(current_path: Path, depth: int) -> bool:
if depth > _MAX_SEARCH_DEPTH:
return False
try:
for item in current_path.iterdir():
if item.is_file() and item.suffix.lower() in _MODEL_EXTENSIONS:
return True
elif item.is_dir() and find_model_files(item, depth + 1):
return True
except PermissionError:
pass
return False
if not find_model_files(path, 0):
raise ValueError(
f"No model files or config files found in directory {path}. "
f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
)
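    # Examples (hypothetical paths): a "Downloads" folder with hundreds of files fails
    # the file-count check; "model.safetensors" passes the extension check; a diffusers
    # folder with a root model_index.json passes the config-file check.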
@staticmethod
def matches_sort_key(m: AnyModelConfig) -> int:
"""Sort key function to prioritize model config matches in case of multiple matches."""
# It is possible that we have multiple matches. We need to prioritize them.
# Known cases where multiple matches can occur:
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
# a config.json file. Prefer the main model.
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
# sophisticated approach.
match m.type:
case ModelType.Main:
return 0
case ModelType.LoRA:
return 1
case ModelType.CLIPEmbed:
return 2
case _:
return 3
@staticmethod
def from_model_on_disk(
mod: str | Path | ModelOnDisk,
override_fields: dict[str, Any] | None = None,
hash_algo: HASHING_ALGORITHMS = "blake3_single",
allow_unknown: bool = True,
) -> ModelClassificationResult:
"""Classify a model on disk and return the best matching model config.
Args:
mod: The model on disk to classify. Can be a path (str or Path) or a ModelOnDisk instance.
override_fields: Optional dictionary of fields to override. These fields will take precedence
over the values extracted from the model on disk, but this cannot force a match if the
model on disk doesn't actually match the config class.
hash_algo: The hashing algorithm to use when computing the model hash if needed.
allow_unknown: If True (the default), fall back to Unknown_Config when no other config class matches.
Returns:
A ModelClassificationResult containing the best matching model config (or None if no match)
and a mapping of all attempted model config classes to either an instance of that class (if it matched)
or an Exception (if it didn't match or an error occurred during matching).
Raises:
ValueError: If the provided path doesn't look like a model.
"""
if isinstance(mod, Path | str):
mod = ModelOnDisk(Path(mod), hash_algo)
# Perform basic sanity checks before attempting any config matching
# This rejects obviously non-model paths early, saving time
ModelConfigFactory._validate_path_looks_like_model(mod.path)
# We will always need these fields to build any model config.
fields = ModelConfigFactory.build_common_fields(mod, override_fields)
# Store results as a mapping of config class to either an instance of that class or an exception
# that was raised when trying to build it.
details: dict[str, AnyModelConfig | Exception] = {}
# Try to build an instance of each model config class that uses the classify API.
# Each class will either return an instance of itself or raise NotAMatch if it doesn't match.
# Other exceptions may be raised if something unexpected happens during matching or building.
for candidate_class in filter(lambda x: x is not Unknown_Config, Config_Base.CONFIG_CLASSES):
candidate_name = candidate_class.__name__
try:
# Technically, from_model_on_disk returns a Config_Base, but in practice it will always be a member of
# the AnyModelConfig union.
details[candidate_name] = candidate_class.from_model_on_disk(mod, fields) # type: ignore
except NotAMatchError as e:
# This means the model didn't match this config class. It's not an error, just no match.
details[candidate_name] = e
except ValidationError as e:
# This means the model matched, but we couldn't create the pydantic model instance for the config.
# Maybe invalid overrides were provided?
details[candidate_name] = e
except Exception as e:
# Some other unexpected error occurred. Store the exception for reporting later.
details[candidate_name] = e
# Extract just the successful matches
matches = [r for r in details.values() if isinstance(r, Config_Base)]
if not matches:
if not allow_unknown:
# No matches and we are not allowed to fall back to Unknown_Config
return ModelClassificationResult(config=None, details=details)
else:
# Fall back to Unknown_Config
# This should always succeed as Unknown_Config.from_model_on_disk never raises NotAMatchError
config = Unknown_Config.from_model_on_disk(mod, fields)
details[Unknown_Config.__name__] = config
return ModelClassificationResult(config=config, details=details)
matches.sort(key=ModelConfigFactory.matches_sort_key)
config = matches[0]
# Now do any post-processing needed for specific model types/bases/etc.
match config.type:
case ModelType.Main:
config.default_settings = MainModelDefaultSettings.from_base(config.base)
case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
case ModelType.LoRA:
config.default_settings = LoraModelDefaultSettings()
case _:
pass
return ModelClassificationResult(config=config, details=details)
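    # Illustrative usage (hypothetical path):
    #   result = ModelConfigFactory.from_model_on_disk("/path/to/model.safetensors")
    #   if result.config is not None:
    #       ...  # best match, post-processed with per-type default settings
    #   else:
    #       ...  # inspect result.details to see why each candidate was rejected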
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart anime": "lineart_anime_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"lineart": "lineart_image_processor",
"soft": "hed_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}

View File

@@ -0,0 +1,40 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class FLUXRedux_Checkpoint_Config(Config_Base):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
if not is_state_dict_likely_flux_redux(mod.load_state_dict()):
raise NotAMatchError("model does not match FLUX Tools Redux heuristics")
return cls(**override_fields)

View File

@@ -0,0 +1,206 @@
import json
from functools import cache
from pathlib import Path
from pydantic import BaseModel, ValidationError
from pydantic_core import CoreSchema, SchemaValidator
from typing_extensions import Any
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
class NotAMatchError(Exception):
"""Exception for when a model does not match a config class.
Args:
reason: The reason why the model did not match.
"""
def __init__(self, reason: str):
super().__init__(reason)
def get_config_dict_or_raise(config_path: Path | set[Path]) -> dict[str, Any]:
"""Load the diffusers/transformers model config file and return it as a dictionary. The config file is expected
to be in JSON format.
Args:
config_path: The path to the config file, or a set of paths to try.
Returns:
The config file as a dictionary.
Raises:
NotAMatchError if the config file is missing or cannot be loaded.
"""
paths_to_check = config_path if isinstance(config_path, set) else {config_path}
problems: dict[Path, str] = {}
for p in paths_to_check:
if not p.exists():
problems[p] = "file does not exist"
continue
try:
with open(p, "r") as file:
config = json.load(file)
return config
except Exception as e:
problems[p] = str(e)
continue
raise NotAMatchError(f"unable to load config file(s): {problems}")
def get_class_name_from_config_dict_or_raise(config: Path | set[Path] | dict[str, Any]) -> str:
"""Load the diffusers/transformers model config file and return the class name.
Args:
config: The path to the config file, a set of paths to try, or an already-loaded config dict.
Returns:
The class name from the config file.
Raises:
NotAMatchError if the config file is missing or does not contain a valid class name.
"""
if not isinstance(config, dict):
config = get_config_dict_or_raise(config)
try:
if "_class_name" in config:
# This is a diffusers-style config
config_class_name = config["_class_name"]
elif "architectures" in config:
# This is a transformers-style config
config_class_name = config["architectures"][0]
else:
raise ValueError("missing _class_name or architectures field")
except Exception as e:
raise NotAMatchError(f"unable to determine class name from config file: {config}") from e
if not isinstance(config_class_name, str):
raise NotAMatchError(f"_class_name or architectures field is not a string: {config_class_name}")
return config_class_name
def raise_for_class_name(config: Path | set[Path] | dict[str, Any], class_name: str | set[str]) -> None:
"""Get the class name from the config file and raise NotAMatch if it is not in the expected set.
Args:
config: The path to the config file, a set of paths to try, or an already-loaded config dict.
class_name: The expected class name, or a set of expected class names.
Raises:
NotAMatchError if the class name is not in the expected set.
"""
class_name = {class_name} if isinstance(class_name, str) else class_name
actual_class_name = get_class_name_from_config_dict_or_raise(config)
if actual_class_name not in class_name:
raise NotAMatchError(f"invalid class name from config: {actual_class_name}")
def raise_for_override_fields(candidate_config_class: type[BaseModel], override_fields: dict[str, Any]) -> None:
"""Check if the provided override fields are valid for the config class using pydantic.
For example, if the candidate config class has a field "base" of type Literal[BaseModelType.StableDiffusion1], and
the override fields contain "base": BaseModelType.Flux, this function will raise NotAMatchError.
Internally, this function extracts the pydantic schema for each individual override field from the candidate config
class and validates the override value against that schema. Post-instantiation validators are not run.
Args:
candidate_config_class: The config class that is being tested.
override_fields: The override fields provided by the user.
Raises:
NotAMatchError if any override field is invalid for the config class.
"""
for field_name, override_value in override_fields.items():
if field_name not in candidate_config_class.model_fields:
raise NotAMatchError(f"unknown override field: {field_name}")
try:
PydanticFieldValidator.validate_field(candidate_config_class, field_name, override_value)
except ValidationError as e:
raise NotAMatchError(f"invalid override for field '{field_name}': {e}") from e
def raise_if_not_file(mod: ModelOnDisk) -> None:
"""Raise NotAMatch if the model path is not a file."""
if not mod.path.is_file():
raise NotAMatchError("model path is not a file")
def raise_if_not_dir(mod: ModelOnDisk) -> None:
"""Raise NotAMatch if the model path is not a directory."""
if not mod.path.is_dir():
raise NotAMatchError("model path is not a directory")
def state_dict_has_any_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
"""Returns true if the state dict has any of the specified keys."""
_keys = {keys} if isinstance(keys, str) else keys
return any(key in state_dict for key in _keys)
def state_dict_has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
def state_dict_has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
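# Toy illustration of the three helpers above (invented keys):
#   sd = {"lora_unet_down.alpha": 1.0}
#   state_dict_has_any_keys_exact(sd, "lora_unet_down.alpha")  -> True
#   state_dict_has_any_keys_starting_with(sd, {"lora_unet_"})  -> True
#   state_dict_has_any_keys_ending_with(sd, ".alpha")          -> True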
def common_config_paths(path: Path) -> set[Path]:
"""Returns common config file paths for models stored in directories."""
return {path / "config.json", path / "model_index.json"}
class PydanticFieldValidator:
"""Utility class for validating individual fields of a Pydantic model without instantiating the whole model.
See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
"""
@staticmethod
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
"""Find the Pydantic core schema for a specific field in a model."""
schema: CoreSchema = model.__pydantic_core_schema__.copy()
# We made only a shallow copy above; be careful not to mutate the original schema!
assert schema["type"] in ["definitions", "model"]
# find the field schema
field_schema = schema["schema"] # type: ignore
while "fields" not in field_schema:
field_schema = field_schema["schema"] # type: ignore
field_schema = field_schema["fields"][field_name]["schema"] # type: ignore
# if the original schema is a definition schema, replace the model schema with the field schema
if schema["type"] == "definitions":
schema["schema"] = field_schema
return schema
else:
return field_schema
@cache
@staticmethod
def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
"""Get a SchemaValidator for a specific field in a model."""
return SchemaValidator(PydanticFieldValidator.find_field_schema(model, field_name))
@staticmethod
def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
"""Validate a value for a specific field in a model."""
return PydanticFieldValidator.get_validator(model, field_name).validate_python(value)
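A brief sketch of using PydanticFieldValidator on its own (the _Demo model and its field are invented for illustration):

from typing import Literal

from pydantic import BaseModel


class _Demo(BaseModel):
    base: Literal["sd-1"]


PydanticFieldValidator.validate_field(_Demo, "base", "sd-1")  # returns "sd-1"
# PydanticFieldValidator.validate_field(_Demo, "base", "flux")  # raises ValidationError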

View File

@@ -0,0 +1,180 @@
from abc import ABC
from typing import (
Literal,
Self,
)
from pydantic import BaseModel, Field
from typing_extensions import Any
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class IPAdapter_Config_Base(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base):
"""Model config for IP Adapter diffusers format models."""
format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI)
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_has_weights_file(mod)
cls._validate_has_image_encoder_metadata_file(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None:
weights_file = mod.path / "ip_adapter.bin"
if not weights_file.exists():
raise NotAMatchError("missing ip_adapter.bin weights file")
@classmethod
def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None:
image_encoder_metadata_file = mod.path / "image_encoder.txt"
if not image_encoder_metadata_file.exists():
raise NotAMatchError("missing image_encoder.txt metadata file")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
try:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
except Exception as e:
raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e
match cross_attention_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_ip_adapter(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"image_proj.",
"ip_adapter.",
# XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.".
"ip_adapter_proj_model.",
},
):
raise NotAMatchError("model does not match Checkpoint IP Adapter heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
if is_state_dict_xlabs_ip_adapter(state_dict):
return BaseModelType.Flux
try:
cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1]
except Exception as e:
raise NotAMatchError(f"unable to determine cross attention dimension: {e}") from e
match cross_attention_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized cross attention dimension {cross_attention_dim}")
class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)

View File

@@ -0,0 +1,42 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
common_config_paths,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelType,
)
class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"LlavaOnevisionForConditionalGeneration",
},
)
return cls(**override_fields)

View File

@@ -0,0 +1,322 @@
from abc import ABC
from pathlib import Path
from typing import (
Any,
Literal,
Self,
)
from pydantic import BaseModel, ConfigDict, Field
from invokeai.backend.model_manager.configs.base import (
Config_Base,
)
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_ending_with,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
FluxLoRAFormat,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
model_config = ConfigDict(extra="forbid")
class LoRA_Config_Base(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
trigger_phrases: set[str] | None = Field(
default=None,
description="Set of trigger phrases for this model",
)
default_settings: LoraModelDefaultSettings | None = Field(
default=None,
description="Default settings for this model",
)
def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
# TODO(psyche): This import lives inside the function to avoid circular imports. Refactor later.
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
state_dict = mod.load_state_dict()
value = flux_format_from_state_dict(state_dict, mod.metadata())
return value
class LoRA_OMI_Config_Base(LoRA_Config_Base):
format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_omi_lora(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model metadata does not look like an OMI LoRA."""
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
raise NotAMatchError("model looks like ControlLoRA or Diffusers LoRA")
metadata = mod.metadata()
architecture = metadata.get("modelspec.architecture", "")
metadata_looks_like_omi_lora = (
bool(metadata.get("modelspec.sai_model_spec"))
and metadata.get("ot_branch") == "omi_format"
# Guard the split: an architecture string without "/" would otherwise raise IndexError
# instead of falling through to NotAMatchError.
and "/" in architecture
and architecture.split("/")[1].lower() == "lora"
)
if not metadata_looks_like_omi_lora:
raise NotAMatchError("metadata does not look like OMI LoRA")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
if architecture == stable_diffusion_xl_1_lora:
return BaseModelType.StableDiffusionXL
elif architecture == flux_dev_1_lora:
return BaseModelType.Flux
else:
raise NotAMatchError(f"unrecognised/unsupported architecture for OMI LoRA: {architecture}")
class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_lora(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
# First, rule out FLUX Control LoRA.
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control]:
raise NotAMatchError("model looks like Control LoRA")
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"lora_te_",
"lora_unet_",
"lora_te1_",
"lora_te2_",
"lora_transformer_",
},
)
has_key_with_lora_suffix = state_dict_has_any_keys_ending_with(
mod.load_state_dict(),
{
"to_k_lora.up.weight",
"to_q_lora.down.weight",
"lora_A.weight",
"lora_B.weight",
},
)
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
raise NotAMatchError("model does not match LyCORIS LoRA heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
if _get_flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise NotAMatchError(f"unrecognized token vector length {token_vector_length}")
class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class ControlAdapter_Config_Base(ABC, BaseModel):
default_settings: ControlAdapterDefaultSettings | None = Field(None)
class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base):
"""Model config for Control LoRA models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa)
format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)
trigger_phrases: set[str] | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_control_lora(mod)
return cls(**override_fields)
@classmethod
def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
state_dict = mod.load_state_dict()
if not is_state_dict_likely_flux_control(state_dict):
raise NotAMatchError("model state dict does not look like a Flux Control LoRA")
class LoRA_Diffusers_Config_Base(LoRA_Config_Base):
"""Model config for LoRA/Diffusers models."""
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
# the weights format. FLUX Diffusers LoRAs are single files.
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
if _get_flux_lora_format(mod):
return BaseModelType.Flux
# If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
path_to_weight_file = cls._get_weight_file_or_raise(mod)
state_dict = mod.load_state_dict(path_to_weight_file)
token_vector_length = lora_token_vector_length(state_dict)
match token_vector_length:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized token vector length {token_vector_length}")
@classmethod
def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
for wf in weight_files:
if wf.exists():
return wf
raise NotAMatchError("missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
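# Editor's illustrative sketch, not part of this diff: the LyCORIS heuristics above look
# only at key names, so (assuming these helpers inspect keys and ignore values) a toy state
# dict with dummy values is enough to see them fire.
_toy_lora_sd: dict[str, Any] = {
    "lora_unet_down_blocks_0_attentions_0.lora_down.weight": None,
    "mid_block.attentions.0.to_q_lora.up.weight": None,
}
assert state_dict_has_any_keys_starting_with(_toy_lora_sd, {"lora_unet_"})
assert not state_dict_has_any_keys_ending_with(_toy_lora_sd, {"lora_A.weight"})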

View File

@@ -0,0 +1,705 @@
from abc import ABC
from typing import Any, Literal, Self
from pydantic import BaseModel, ConfigDict, Field
from invokeai.backend.model_manager.configs.base import (
Checkpoint_Config_Base,
Config_Base,
Diffusers_Config_Base,
SubmodelDefinition,
)
from invokeai.backend.model_manager.configs.clip_embed import get_clip_variant_type_from_config
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_exact,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
FluxVariantType,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
cfg_rescale_multiplier: float | None = Field(
default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
model_config = ConfigDict(extra="forbid")
@classmethod
def from_base(cls, base: BaseModelType) -> Self | None:
match base:
case BaseModelType.StableDiffusion1:
return cls(width=512, height=512)
case BaseModelType.StableDiffusion2:
return cls(width=768, height=768)
case BaseModelType.StableDiffusionXL:
return cls(width=1024, height=1024)
case _:
# TODO(psyche): Do we want defaults for other base types?
return None
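# Editor's illustrative usage, not part of this diff:
#   MainModelDefaultSettings.from_base(BaseModelType.StableDiffusionXL)
#   -> MainModelDefaultSettings(width=1024, height=1024)
#   MainModelDefaultSettings.from_base(BaseModelType.Flux) -> None (no defaults defined)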
class Main_Config_Base(ABC, BaseModel):
type: Literal[ModelType.Main] = Field(default=ModelType.Main)
trigger_phrases: set[str] | None = Field(
default=None,
description="Set of trigger phrases for this model",
)
default_settings: MainModelDefaultSettings | None = Field(
default=None,
description="Default settings for this model",
)
def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool:
bnb_nf4_keys = {
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
"model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4",
}
return any(key in state_dict for key in bnb_nf4_keys)
def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
return any(isinstance(v, GGMLTensor) for v in state_dict.values())
def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
for key in state_dict.keys():
if isinstance(key, int):
continue
elif key.startswith(
(
"cond_stage_model.",
"first_stage_model.",
"model.diffusion_model.",
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
# This prefix is typically used to distinguish between multiple models bundled in a single file.
"model.diffusion_model.double_blocks.",
)
):
return True
elif key.startswith("double_blocks.") and "ip_adapter" not in key:
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be
# careful to avoid false positives on XLabs FLUX IP-Adapter models.
return True
return False
class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_base(mod)
prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, prediction_type=prediction_type, variant=variant)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
raise NotAMatchError("unable to determine base type from state dict")
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
base = cls.model_fields["base"].default
if base is BaseModelType.StableDiffusion2:
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in state_dict:
if state_dict["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif state_dict["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction
else:
return SchedulerPredictionType.Epsilon
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
base = cls.model_fields["base"].default
state_dict = mod.load_state_dict()
key_name = "model.diffusion_model.input_blocks.0.0.weight"
if key_name not in state_dict:
raise NotAMatchError("unable to determine model variant from state dict")
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
# Only SD2 has a depth variant
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'")
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
class Main_Checkpoint_SD1_Config(Main_SD_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class Main_Checkpoint_SD2_Config(Main_SD_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class Main_Checkpoint_SDXL_Config(Main_SD_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
# Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight".
#
# Known models that use the latter key:
# - https://civitai.com/models/885098?modelVersionId=990775
# - https://civitai.com/models/1018060?modelVersionId=1596255
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
#
# Input channels for known FLUX models:
# - Unquantized Dev and Schnell have in_channels=64
# - BNB-NF4 Dev and Schnell have in_channels=1
# - FLUX Fill has in_channels=384
# - Unsure of quantized FLUX Fill models
# - Unsure of GGUF-quantized models
in_channels = None
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
if key in state_dict:
in_channels = state_dict[key].shape[1]
break
if in_channels is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
return None
# Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of
# certain keys to distinguish between them.
is_flux_dev = (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
)
if is_flux_dev and in_channels == 384:
return FluxVariantType.DevFill
elif is_flux_dev:
return FluxVariantType.Dev
else:
# Must be a Schnell model...?
return FluxVariantType.Schnell
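def _flux_variant_examples() -> None:
    # Editor's illustrative sketch, not part of this diff. The tensor shapes are toy values;
    # only img_in.weight.shape[1] and the presence of the guidance key drive the heuristic.
    import torch

    sd: dict[str | int, Any] = {
        "img_in.weight": torch.zeros(3072, 64),
        "guidance_in.out_layer.weight": torch.zeros(3072, 3072),
    }
    assert _get_flux_variant(sd) is FluxVariantType.Dev
    del sd["guidance_in.out_layer.weight"]
    assert _get_flux_variant(sd) is FluxVariantType.Schnell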
class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_is_flux(mod)
cls._validate_does_not_look_like_bnb_quantized(mod)
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_exact(
mod.load_state_dict(),
{
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
},
):
raise NotAMatchError("state dict does not look like a FLUX checkpoint")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatchError("unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if has_bnb_nf4_keys:
raise NotAMatchError("state dict looks like bnb quantized nf4")
@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if has_ggml_tensors:
raise NotAMatchError("state dict looks like GGUF quantized")
class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_model_looks_like_bnb_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatchError("unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if not has_bnb_nf4_keys:
raise NotAMatchError("state dict does not look like bnb quantized nf4")
class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_looks_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
state_dict = mod.load_state_dict()
variant = _get_flux_variant(state_dict)
if variant is None:
# TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant,
# but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX
# model, we should figure out a good fallback value.
raise NotAMatchError("unable to determine model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if not has_ggml_tensors:
raise NotAMatchError("state dict does not look like GGUF quantized")
class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
prediction_type: SchedulerPredictionType = Field()
variant: ModelVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
# SD 1.x and 2.x
"StableDiffusionPipeline",
"StableDiffusionInpaintPipeline",
# SDXL
"StableDiffusionXLPipeline",
"StableDiffusionXLInpaintPipeline",
# SDXL Refiner
"StableDiffusionXLImg2ImgPipeline",
# TODO(psyche): Do we actually support LCM models? I don't see this class used anywhere in the codebase.
"LatentConsistencyModelPipeline",
},
)
cls._validate_base(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
variant=variant,
prediction_type=prediction_type,
repo_variant=repo_variant,
)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Handle pipelines with a UNet (i.e. SD 1.x, SD 2.x, SDXL).
unet_conf = get_config_dict_or_raise(mod.path / "unet" / "config.json")
cross_attention_dim = unet_conf.get("cross_attention_dim")
match cross_attention_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXLRefiner
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized cross_attention_dim {cross_attention_dim}")
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType:
scheduler_conf = get_config_dict_or_raise(mod.path / "scheduler" / "scheduler_config.json")
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
match prediction_type:
case "v_prediction":
return SchedulerPredictionType.VPrediction
case "epsilon":
return SchedulerPredictionType.Epsilon
case _:
raise NotAMatchError(f"unrecognized scheduler prediction_type {prediction_type}")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType:
base = cls.model_fields["base"].default
unet_config = get_config_dict_or_raise(mod.path / "unet" / "config.json")
in_channels = unet_config.get("in_channels")
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
# Only SD2 has a depth variant
assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'"
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatchError(f"unrecognized unet in_channels {in_channels} for base '{base}'")
class Main_Diffusers_SD1_Config(Main_SD_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1)
class Main_Diffusers_SD2_Config(Main_SD_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2)
class Main_Diffusers_SDXL_Config(Main_SD_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL)
class Main_Diffusers_SDXLRefiner_Config(Main_SD_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner)
class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3)
submodels: dict[SubModelType, SubmodelDefinition] | None = Field(
description="Loadable submodels in this model",
default=None,
)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# This check implies the base type - no further validation needed.
raise_for_class_name(
common_config_paths(mod.path),
{
"StableDiffusion3Pipeline",
"SD3Transformer2DModel",
},
)
submodels = override_fields.get("submodels") or cls._get_submodels_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
submodels=submodels,
repo_variant=repo_variant,
)
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
config = get_config_dict_or_raise(common_config_paths(mod.path))
submodels: dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
# Anything that starts with an underscore is top-level metadata, not a submodel
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
# The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
# The value is something like ["diffusers", "SD3Transformer2DModel"]
_library_name, class_name = value
match class_name:
case "CLIPTextModelWithProjection":
model_type = ModelType.CLIPEmbed
path_or_prefix = (mod.path / key).resolve().as_posix()
# We need to read the config to determine the variant of the CLIP model.
clip_embed_config = get_config_dict_or_raise(
{
mod.path / key / "config.json",
mod.path / key / "model_index.json",
}
)
variant = get_clip_variant_type_from_config(clip_embed_config)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
model_type=model_type,
variant=variant,
)
case "SD3Transformer2DModel":
model_type = ModelType.Main
path_or_prefix = (mod.path / key).resolve().as_posix()
variant = None
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
model_type=model_type,
variant=variant,
)
case _:
pass
return submodels
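# For reference (editor's note, not part of this diff), a trimmed SD3.5 model_index.json:
# {
#   "_class_name": "StableDiffusion3Pipeline",
#   "text_encoder": ["transformers", "CLIPTextModelWithProjection"],
#   "transformer": ["diffusers", "SD3Transformer2DModel"],
#   "vae": ["diffusers", "AutoencoderKL"]
# }
# Underscore-prefixed keys are metadata; two-element list values are candidate submodels.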
class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# This check implies the base type - no further validation needed.
raise_for_class_name(
common_config_paths(mod.path),
{
"CogView4Pipeline",
},
)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
repo_variant=repo_variant,
)
class ExternalAPI_Config_Base(ABC, BaseModel):
"""Model config for API-based models."""
format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise NotAMatchError("External API models cannot be built from disk")
class Main_ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o)
class Main_ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5)
class Main_ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3)
class Main_ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4)
class Main_ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
class Video_Config_Base(ABC, BaseModel):
type: Literal[ModelType.Video] = Field(default=ModelType.Video)
trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None)
default_settings: MainModelDefaultSettings | None = Field(
description="Default settings for this model", default=None
)
class Video_ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
base: Literal[BaseModelType.Veo3] = Field(default=BaseModelType.Veo3)
class Video_ExternalAPI_Runway_Config(ExternalAPI_Config_Base, Video_Config_Base, Config_Base):
base: Literal[BaseModelType.Runway] = Field(default=BaseModelType.Runway)
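# Editor's note, not part of this diff: because ExternalAPI_Config_Base.from_model_on_disk
# unconditionally raises NotAMatchError, API-backed configs can sit in the same candidate
# list as disk-based configs without ever matching a local path.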

View File

@@ -0,0 +1,44 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
common_config_paths,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"SiglipModel",
},
)
return cls(**override_fields)

View File

@@ -0,0 +1,54 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
class Spandrel_Checkpoint_Config(Config_Base):
"""Model config for Spandrel Image to Image models."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_spandrel_loads_model(mod)
return cls(**override_fields)
@classmethod
def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None:
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(mod.path)
except Exception as e:
raise NotAMatchError("model does not match SpandrelImageToImage heuristics") from e

View File

@@ -0,0 +1,79 @@
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"T2IAdapter",
},
)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
adapter_type = config_dict.get("adapter_type")
match adapter_type:
case "full_adapter_xl":
return BaseModelType.StableDiffusionXL
case "full_adapter" | "light_adapter":
return BaseModelType.StableDiffusion1
case _:
raise NotAMatchError(f"unrecognized adapter_type '{adapter_type}'")
class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
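# Editor's illustrative sketch, not part of this diff: the same adapter_type mapping as a
# standalone table, convenient when eyeballing a downloaded config.json.
_ADAPTER_TYPE_TO_BASE: dict[str, BaseModelType] = {
    "full_adapter_xl": BaseModelType.StableDiffusionXL,
    "full_adapter": BaseModelType.StableDiffusion1,
    "light_adapter": BaseModelType.StableDiffusion1,
}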

View File

@@ -0,0 +1,80 @@
from typing import Any, Literal, Self
from pydantic import Field
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
state_dict_has_any_keys_ending_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
class T5Encoder_T5Encoder_Config(Config_Base):
"""Configuration for T5 Encoder models in a bespoke, diffusers-like format. The model weights are expected to be in
a folder called text_encoder_2 inside the model directory, with a config file named model.safetensors.index.json."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
expected_config_path = mod.path / "text_encoder_2" / "config.json"
expected_class_name = "T5EncoderModel"
raise_for_class_name(expected_config_path, expected_class_name)
cls.raise_if_doesnt_have_unquantized_config_file(mod)
return cls(**override_fields)
@classmethod
def raise_if_doesnt_have_unquantized_config_file(cls, mod: ModelOnDisk) -> None:
has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists()
if not has_unquantized_config:
raise NotAMatchError("missing text_encoder_2/model.safetensors.index.json")
class T5Encoder_BnBLLMint8_Config(Config_Base):
"""Configuration for T5 Encoder models quantized by bitsandbytes' LLM.int8."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
expected_config_path = mod.path / "text_encoder_2" / "config.json"
expected_class_name = "T5EncoderModel"
raise_for_class_name(expected_config_path, expected_class_name)
cls.raise_if_filename_doesnt_look_like_bnb_quantized(mod)
cls.raise_if_state_dict_doesnt_look_like_bnb_quantized(mod)
return cls(**override_fields)
@classmethod
def raise_if_filename_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
filename_looks_like_bnb = any("llm_int8" in x.as_posix() for x in mod.weight_files())
if not filename_looks_like_bnb:
raise NotAMatchError("filename does not look like bnb quantized llm_int8")
@classmethod
def raise_if_state_dict_doesnt_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_scb_key_suffix = state_dict_has_any_keys_ending_with(mod.load_state_dict(), "SCB")
if not has_scb_key_suffix:
raise NotAMatchError("state dict does not look like bnb quantized llm_int8")

View File

@@ -0,0 +1,156 @@
from abc import ABC
from pathlib import Path
from typing import (
Literal,
Self,
)
import torch
from pydantic import BaseModel, Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
class TI_Config_Base(ABC, BaseModel):
type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
@classmethod
def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None:
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod, path)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool:
try:
p = path or mod.path
if not p.exists():
return False
if p.is_dir():
return False
if p.name in [f"learned_embeds.{s}" for s in mod.weight_files()]:
return True
state_dict = mod.load_state_dict(p)
# Heuristic: textual inversion embeddings have these keys
if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()):
return True
# Heuristic: small state dict with all tensor values
if len(state_dict) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()):
return True
return False
except Exception:
return False
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType:
p = path or mod.path
try:
state_dict = mod.load_state_dict(p)
except Exception as e:
raise NotAMatchError(f"unable to load state dict from {p}: {e}") from e
try:
if "string_to_token" in state_dict:
token_dim = list(state_dict["string_to_param"].values())[0].shape[-1]
elif "emb_params" in state_dict:
token_dim = state_dict["emb_params"].shape[-1]
elif "clip_g" in state_dict:
token_dim = state_dict["clip_g"].shape[-1]
else:
token_dim = list(state_dict.values())[0].shape[0]
except Exception as e:
raise NotAMatchError(f"unable to determine token dimension from state dict in {p}: {e}") from e
match token_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatchError(f"unrecognized token dimension {token_dim}")
class TI_File_Config_Base(TI_Config_Base):
"""Model config for textual inversion embeddings."""
format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
if not cls._file_looks_like_embedding(mod):
raise NotAMatchError("model does not look like a textual inversion embedding file")
cls._validate_base(mod)
return cls(**override_fields)
class TI_File_SD1_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_File_SD2_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class TI_Folder_Config_Base(TI_Config_Base):
"""Model config for textual inversion embeddings."""
format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
for p in mod.weight_files():
if cls._file_looks_like_embedding(mod, p):
cls._validate_base(mod, p)
return cls(**override_fields)
raise NotAMatchError("model does not look like a textual inversion embedding folder")
class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
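# Editor's illustrative helper, not part of this diff: the token-dimension table used by
# TI_Config_Base._get_base_or_raise, extracted for readability.
def guess_ti_base_from_token_dim(token_dim: int) -> BaseModelType:
    table = {
        768: BaseModelType.StableDiffusion1,
        1024: BaseModelType.StableDiffusion2,
        1280: BaseModelType.StableDiffusionXL,
    }
    if token_dim not in table:
        raise NotAMatchError(f"unrecognized token dimension {token_dim}")
    return table[token_dim]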

View File

@@ -0,0 +1,44 @@
from copy import deepcopy
from typing import Any, Literal, Self
from pydantic import Field
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
app_config = get_config()
class Unknown_Config(Config_Base):
"""Model config for unknown models, used as a fallback when we cannot positively identify a model."""
base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
"""Create an Unknown_Config for models that couldn't be positively identified.
Note: Basic path validation (file extensions, directory structure) is already
performed by ModelConfigFactory before this method is called.
"""
cloned_override_fields = deepcopy(override_fields)
cloned_override_fields.pop("base", None)
cloned_override_fields.pop("type", None)
cloned_override_fields.pop("format", None)
return cls(
**cloned_override_fields,
# Override the type/format/base to ensure it's marked as unknown.
base=BaseModelType.Unknown,
type=ModelType.Unknown,
format=ModelFormat.Unknown,
)
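# Editor's illustrative usage, not part of this diff: Unknown_Config is the terminal
# fallback, so a hypothetical factory can still record a model after every concrete config
# class has raised NotAMatchError, rather than failing the install outright:
#   config = Unknown_Config.from_model_on_disk(mod, override_fields={"name": mod.path.name})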

View File

@@ -0,0 +1,163 @@
import re
from typing import (
Literal,
Self,
)
from pydantic import Field
from typing_extensions import Any
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_config_dict_or_raise,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelType,
)
REGEX_TO_BASE: dict[str, BaseModelType] = {
r"xl": BaseModelType.StableDiffusionXL,
r"sd2": BaseModelType.StableDiffusion2,
r"vae": BaseModelType.StableDiffusion1,
r"FLUX.1-schnell_ae": BaseModelType.Flux,
}
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_vae(mod)
cls._validate_base(mod)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"encoder.conv_in",
"decoder.conv_in",
},
):
raise NotAMatchError("model does not match Checkpoint VAE heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
for regexp, base in REGEX_TO_BASE.items():
if re.search(regexp, mod.path.name, re.IGNORECASE):
return base
raise NotAMatchError("cannot determine base type")
class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"AutoencoderKL",
"AutoencoderTiny",
},
)
# Unfortunately it is difficult to distinguish SD1 and SDXL VAEs by config alone, so we may need to
# guess based on name if the config is inconclusive.
override_name = override_fields.get("name")
cls._validate_base(mod, override_name)
return cls(**override_fields)
@classmethod
def _validate_base(cls, mod: ModelOnDisk, override_name: str | None = None) -> None:
"""Raise `NotAMatch` if the model base does not match this config class."""
expected_base = cls.model_fields["base"].default
recognized_base = cls._get_base_or_raise(mod, override_name)
if expected_base is not recognized_base:
raise NotAMatchError(f"base is {recognized_base}, not {expected_base}")
@classmethod
def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool:
# Heuristic: these config values distinguish Stability's SD 1.x VAE from their SDXL VAE.
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
@classmethod
def _name_looks_like_sdxl(cls, mod: ModelOnDisk, override_name: str | None = None) -> bool:
# Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best
# we can do is guess based on name.
return bool(re.search(r"xl\b", override_name or mod.path.name, re.IGNORECASE))
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk, override_name: str | None = None) -> BaseModelType:
config_dict = get_config_dict_or_raise(common_config_paths(mod.path))
if cls._config_looks_like_sdxl(config_dict):
return BaseModelType.StableDiffusionXL
elif cls._name_looks_like_sdxl(mod, override_name):
return BaseModelType.StableDiffusionXL
else:
# TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO.
return BaseModelType.StableDiffusion1
class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
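
The two heuristics above can be read together as a single decision: prefer the config check, fall back to the filename, then assume SD1. A minimal standalone sketch (the enum values and regex table are illustrative stand-ins for REGEX_TO_BASE, not the project's actual definitions):

import re
from enum import Enum

class Base(str, Enum):
    SD1 = "sd-1"
    SDXL = "sdxl"
    FLUX = "flux"

# Hypothetical name -> base table in the spirit of REGEX_TO_BASE.
NAME_PATTERNS: dict[str, Base] = {r"flux": Base.FLUX, r"xl\b": Base.SDXL}

def guess_vae_base(name: str, config: dict) -> Base:
    # Config heuristic: Stability's SDXL VAE ships with scaling_factor 0.13025.
    if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]:
        return Base.SDXL
    # Name heuristic: fall back to filename matching.
    for pattern, base in NAME_PATTERNS.items():
        if re.search(pattern, name, re.IGNORECASE):
            return base
    # Mirrors the TODO above: assume SD1 when nothing else matches.
    return Base.SD1

assert guess_vae_base("sdxl-vae-fp16-fix.safetensors", {}) is Base.SDXL
assert guess_vae_base("vae-ft-mse-840000.ckpt", {"scaling_factor": 0.13025, "sample_size": 1024}) is Base.SDXL
assert guess_vae_base("vae-ft-mse-840000.ckpt", {}) is Base.SD1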

File diff suppressed because it is too large

View File

@@ -12,9 +12,7 @@ from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
@@ -91,14 +89,6 @@ class LoadedModel(LoadedModelWithoutConfig):
self.config = config
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
#
# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to
# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented.
class ModelLoaderBase(ABC):
"""Abstract base class for loading models into RAM/VRAM."""

View File

@@ -6,7 +6,8 @@ from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
@@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
model_path = self._get_model_path(model_config)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
cache_record = self._load_and_cache(model_config, submodel_type)
@@ -90,7 +91,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None,
)
# This needs to be implemented in the subclass

View File

@@ -18,10 +18,8 @@ Use like this:
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ModelConfigBase,
)
from invokeai.backend.model_manager.configs.base import Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import ModelLoaderBase
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
@@ -40,7 +38,7 @@ class ModelLoaderRegistryBase(ABC):
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
@@ -84,7 +82,7 @@ class ModelLoaderRegistry(ModelLoaderRegistryBase):
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
) -> Tuple[Type[ModelLoaderBase], Config_Base, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
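
The lookup above builds a key for the specific base first. A sketch of the presumed two-step fallback, assuming a second catch-all key with base "any" is tried when the specific base has no registered loader:

registry: dict[tuple[str, str, str], type] = {}

def register(base: str, type_: str, fmt: str):
    def deco(cls: type) -> type:
        registry[(base, type_, fmt)] = cls
        return cls
    return deco

def get_implementation(base: str, type_: str, fmt: str) -> type:
    key1 = (base, type_, fmt)   # loader registered for this specific base...
    key2 = ("any", type_, fmt)  # ...else one registered for the catch-all base
    impl = registry.get(key1) or registry.get(key2)
    if impl is None:
        raise ValueError(f"No loader registered for {key1}")
    return impl

@register("any", "vae", "checkpoint")
class AnyBaseVAELoader:
    pass

assert get_implementation("sd-1", "vae", "checkpoint") is AnyBaseVAELoader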

View File

@@ -3,10 +3,8 @@ from typing import Optional
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_manager.config import (
AnyModelConfig,
DiffusersConfigBase,
)
from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
@@ -21,7 +19,7 @@ class ClipVisionLoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, DiffusersConfigBase):
if not isinstance(config, Diffusers_Config_Base):
raise ValueError("Only DiffusersConfigBase models are currently supported here.")
if submodel_type is not None:

View File

@@ -3,11 +3,8 @@ from typing import Optional
import torch
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
)
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
@@ -28,7 +25,7 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
if isinstance(config, Checkpoint_Config_Base):
raise NotImplementedError("CheckpointConfigBase is not implemented for CogView4 models.")
if submodel_type is None:
@@ -36,7 +33,7 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value

View File

@@ -5,10 +5,8 @@ from typing import Optional
from diffusers import ControlNetModel
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ControlNetCheckpointConfig,
)
from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
@@ -46,7 +44,7 @@ class ControlNetLoader(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
if isinstance(config, ControlNet_Checkpoint_Config_Base):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,

View File

@@ -33,27 +33,29 @@ from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
FluxReduxConfig,
IPAdapterCheckpointConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_Config_Base
from invokeai.backend.model_manager.configs.controlnet import (
ControlNet_Checkpoint_Config_Base,
ControlNet_Diffusers_Config_Base,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoint_Config
from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_Config_Base
from invokeai.backend.model_manager.configs.main import (
Main_BnBNF4_FLUX_Config,
Main_Checkpoint_FLUX_Config,
Main_GGUF_FLUX_Config,
)
from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
FluxVariantType,
ModelFormat,
ModelType,
SubModelType,
@@ -85,12 +87,12 @@ class FluxVAELoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, VAECheckpointConfig):
if not isinstance(config, VAE_Checkpoint_Config_Base):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = AutoEncoder(ae_params[config.config_path])
model = AutoEncoder(get_flux_ae_params())
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
# VAE is broken in float16, which mps defaults to
@@ -107,7 +109,7 @@ class FluxVAELoader(ModelLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
class CLIPDiffusersLoader(ModelLoader):
"""Class to load main models."""
def _load_model(
@@ -115,7 +117,7 @@ class ClipCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
if not isinstance(config, CLIPEmbed_Diffusers_Config_Base):
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
@@ -138,7 +140,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
if not isinstance(config, T5Encoder_BnBLLMint8_Config):
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
if not bnb_available:
raise ImportError(
@@ -185,7 +187,7 @@ class T5EncoderCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
if not isinstance(config, T5Encoder_T5Encoder_Config):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
@@ -210,7 +212,7 @@ class FluxCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
@@ -225,11 +227,11 @@ class FluxCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
assert isinstance(config, Main_Checkpoint_FLUX_Config)
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = Flux(get_flux_transformers_params(config.variant))
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
@@ -252,7 +254,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
@@ -267,11 +269,11 @@ class FluxGGUFCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainGGUFCheckpointConfig)
assert isinstance(config, Main_GGUF_FLUX_Config)
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = Flux(get_flux_transformers_params(config.variant))
# HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
@@ -298,7 +300,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
@@ -313,7 +315,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
assert isinstance(config, Main_BnBNF4_FLUX_Config)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
@@ -322,7 +324,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = Flux(get_flux_transformers_params(config.variant))
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
@@ -341,9 +343,9 @@ class FluxControlnetModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
if isinstance(config, ControlNet_Checkpoint_Config_Base):
model_path = Path(config.path)
elif isinstance(config, ControlNetDiffusersConfig):
elif isinstance(config, ControlNet_Diffusers_Config_Base):
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
else:
@@ -362,7 +364,7 @@ class FluxControlnetModel(ModelLoader):
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = XLabsControlNetFlux(params["flux-dev"])
model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev))
model.load_state_dict(sd, assign=True)
return model
@@ -388,7 +390,7 @@ class FluxIpAdapterModel(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, IPAdapterCheckpointConfig):
if not isinstance(config, IPAdapter_Checkpoint_Config_Base):
raise ValueError(f"Unexpected model config type: {type(config)}.")
sd = load_file(Path(config.path))
@@ -411,7 +413,7 @@ class FluxReduxModelLoader(ModelLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, FluxReduxConfig):
if not isinstance(config, FLUXRedux_Checkpoint_Config):
raise ValueError(f"Unexpected model config type: {type(config)}.")
sd = load_file(Path(config.path))
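
A recurring change in this file is replacing `params[config.config_path]` lookups with `get_flux_transformers_params(config.variant)`, i.e. keying the hyperparameters on the FLUX variant enum rather than a legacy config path. A sketch of the idea, with a hypothetical parameter table standing in for the real FLUX hyperparameters:

from enum import Enum

class FluxVariant(str, Enum):
    Schnell = "schnell"
    Dev = "dev"

# Illustrative table; the real one holds the full transformer hyperparameters.
_TRANSFORMER_PARAMS: dict[FluxVariant, dict] = {
    FluxVariant.Schnell: {"guidance_embed": False},
    FluxVariant.Dev: {"guidance_embed": True},
}

def get_transformer_params(variant: FluxVariant) -> dict:
    return _TRANSFORMER_PARAMS[variant]

assert get_transformer_params(FluxVariant.Dev)["guidance_embed"] is True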

View File

@@ -8,7 +8,8 @@ from typing import Any, Optional
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
from invokeai.backend.model_manager.configs.base import Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
@@ -33,7 +34,7 @@ class GenericDiffusersLoader(ModelLoader):
model_class = self.get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
try:
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
@@ -56,9 +57,7 @@ class GenericDiffusersLoader(ModelLoader):
module, class_name = config[submodel_type.value]
result = self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException(
f'The "{submodel_type}" submodel is not available for this model.'
) from e
raise ValueError(f'The "{submodel_type}" submodel is not available for this model.') from e
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
@@ -67,9 +66,9 @@ class GenericDiffusersLoader(ModelLoader):
elif class_name := config.get("architectures"):
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
raise RuntimeError("Unable to decipher Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
raise ValueError("An expected config.json file is missing from this model.") from e
assert result is not None
return result

View File

@@ -7,7 +7,7 @@ from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.raw_model import RawModel

View File

@@ -3,9 +3,7 @@ from typing import Optional
from transformers import LlavaOnevisionForConditionalGeneration
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType

View File

@@ -9,7 +9,7 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -30,6 +30,7 @@ from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
@@ -96,15 +97,19 @@ class LoRALoader(ModelLoader):
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
if config.format is ModelFormat.OMI:
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
# is a popular choice. For example, in the diffusers training scripts:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
#
# We assume the same for LyCORIS models in diffusers key format.
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
elif config.format is ModelFormat.LyCORIS:
if is_state_dict_likely_in_flux_diffusers_format(state_dict=state_dict):
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
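
A sketch of the resulting detection order for LyCORIS-format FLUX LoRAs, with trivial stand-in predicates for the real detectors:

def looks_like_diffusers(sd: dict) -> bool:
    # PEFT-style keys end in lora_A.weight / lora_B.weight.
    return bool(sd) and all(str(k).endswith(("lora_A.weight", "lora_B.weight")) for k in sd)

def looks_like_kohya(sd: dict) -> bool:
    return any(str(k).startswith("lora_unet_") for k in sd)

def pick_flux_converter(sd: dict) -> str:
    if looks_like_diffusers(sd):
        return "diffusers"  # alpha=None, treated as alpha=rank downstream
    if looks_like_kohya(sd):
        return "kohya"
    return "onetrainer"

assert pick_flux_converter({"transformer.x.lora_A.weight": None}) == "diffusers"
assert pick_flux_converter({"lora_unet_x.alpha": None}) == "kohya"

Checking the diffusers/PEFT shape first matters: its key format is the most constrained, so testing it before the Kohya and OneTrainer detectors avoids misclassification.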

View File

@@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (

View File

@@ -3,9 +3,7 @@ from typing import Optional
from transformers import SiglipVisionModel
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType

View File

@@ -3,9 +3,7 @@ from typing import Optional
import torch
from invokeai.backend.model_manager.config import (
AnyModelConfig,
)
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType

View File

@@ -4,18 +4,24 @@
from pathlib import Path
from typing import Optional
from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import (
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -58,7 +64,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
if isinstance(config, Checkpoint_Config_Base):
return self._load_from_singlefile(config, submodel_type)
if submodel_type is None:
@@ -66,7 +72,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
try:
@@ -107,7 +113,19 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
ModelVariantType.Normal: StableDiffusionXLPipeline,
},
}
assert isinstance(config, MainCheckpointConfig)
assert isinstance(
config,
(
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
),
)
try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:
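
The widened assert above narrows the config union before the nested lookup. A sketch of the `load_classes` dispatch itself, keyed on base and then variant, with stand-in class names:

class SD1Pipeline: ...
class SD1InpaintPipeline: ...

LOAD_CLASSES: dict[str, dict[str, type]] = {
    "sd-1": {"normal": SD1Pipeline, "inpaint": SD1InpaintPipeline},
}

def pick_pipeline(base: str, variant: str) -> type:
    try:
        return LOAD_CLASSES[base][variant]
    except KeyError as e:
        raise ValueError(f"No single-file pipeline for base={base}, variant={variant}") from e

assert pick_pipeline("sd-1", "inpaint") is SD1InpaintPipeline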

View File

@@ -4,7 +4,7 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (

View File

@@ -3,9 +3,10 @@
from typing import Optional
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
@@ -27,7 +28,7 @@ class VAELoader(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
if isinstance(config, VAE_Checkpoint_Config_Base):
return AutoencoderKL.from_single_file(
config.path,
torch_dtype=self._torch_dtype,

View File

@@ -1,163 +0,0 @@
"""
invokeai.backend.model_manager.merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to the models tables
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, ModelVariantType
from invokeai.backend.model_manager.config import MainDiffusersConfig
from invokeai.backend.util.devices import TorchDevice
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
class ModelMerger(object):
"""Wrapper class for model merge function."""
def __init__(self, installer: ModelInstallServiceBase):
"""
Initialize a ModelMerger object with the model installer.
"""
self._installer = installer
self._dtype = TorchDevice.choose_torch_dtype()
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
variant: Optional[str] = None,
**kwargs: Any,
) -> Any: # pipe.merge is an untyped function.
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
dtype = torch.float16 if variant == "fp16" else self._dtype
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
pipe = AutoPipelineForText2Image.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
torch_dtype=dtype,
variant=variant,
) # type: ignore
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
force=force,
torch_dtype=dtype,
variant=variant,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_save(
self,
model_keys: List[str],
merged_model_name: str,
alpha: float = 0.5,
force: bool = False,
interp: Optional[MergeInterpolationMethod] = None,
merge_dest_directory: Optional[Path] = None,
variant: Optional[str] = None,
**kwargs: Any,
) -> AnyModelConfig:
"""
:param models: up to three models, designated by their registered InvokeAI model name
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths: List[Path] = []
model_names: List[str] = []
config = self._installer.app_config
store = self._installer.record_store
base_models: Set[BaseModelType] = set()
variant = None if self._installer.app_config.precision == "float32" else "fp16"
assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, (
"When merging three models, only the 'add_difference' merge method is supported"
)
for key in model_keys:
info = store.get_model(key)
model_names.append(info.name)
assert isinstance(info, MainDiffusersConfig), (
f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
)
assert info.variant == ModelVariantType("normal"), (
f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
)
# tally base models used
base_models.add(info.base)
model_paths.extend([config.models_path / info.path])
assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
base_model = base_models.pop()
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs)
dump_path = (
Path(merge_dest_directory)
if merge_dest_directory
else config.models_path / base_model.value / ModelType.Main.value
)
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
dtype = torch.float16 if variant == "fp16" else self._dtype
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
# register model and get its unique key
key = self._installer.register_path(dump_path)
# update model's config
model_config = self._installer.record_store.get_model(key)
model_config.name = merged_model_name
model_config.description = f"Merge of models {', '.join(model_names)}"
self._installer.record_store.update_model(
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
)
return model_config

View File

@@ -30,7 +30,8 @@ class ModelOnDisk:
self.hash_algo = hash_algo
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
# This prevents redundant computations during matching and parsing
self.cache = {"_CACHED_STATE_DICTS": {}}
self._state_dict_cache: dict[Path, Any] = {}
self._metadata_cache: dict[Path, Any] = {}
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
@@ -44,16 +45,21 @@ class ModelOnDisk:
if self.path.is_file():
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
return {f for f in self.path.rglob("*") if f.suffix in extensions and f.is_file()}
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
path = path or self.path
if path in self._metadata_cache:
return self._metadata_cache[path]
try:
with safe_open(self.path, framework="pt", device="cpu") as f:
metadata = f.metadata()
assert isinstance(metadata, dict)
return metadata
except Exception:
return {}
metadata = {}
self._metadata_cache[path] = metadata
return metadata
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
@@ -73,10 +79,8 @@ class ModelOnDisk:
return ModelRepoVariant.Default
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
sd_cache = self.cache["_CACHED_STATE_DICTS"]
if path in sd_cache:
return sd_cache[path]
if path in self._state_dict_cache:
return self._state_dict_cache[path]
path = self.resolve_weight_file(path)
@@ -111,7 +115,7 @@ class ModelOnDisk:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
self._state_dict_cache[path] = state_dict
return state_dict
def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
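
A sketch of the cache refactor above: one typed dict per concern, keyed by path, instead of a shared cache dict with a nested "_CACHED_STATE_DICTS" entry:

from pathlib import Path
from typing import Any

class ModelOnDiskSketch:
    def __init__(self, path: Path):
        self.path = path
        self._state_dict_cache: dict[Path, Any] = {}
        self._metadata_cache: dict[Path, Any] = {}

    def metadata(self, path: Path | None = None) -> dict[str, str]:
        path = path or self.path
        if path in self._metadata_cache:
            return self._metadata_cache[path]
        metadata: dict[str, str] = {}  # the real method reads the safetensors header here
        self._metadata_cache[path] = metadata
        return metadata

m = ModelOnDiskSketch(Path("model.safetensors"))
assert m.metadata() == {} and m.metadata() is m.metadata()  # second call hits the cache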

View File

@@ -0,0 +1,93 @@
from dataclasses import dataclass
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
@dataclass(frozen=True)
class LegacyConfigKey:
type: ModelType
base: BaseModelType
variant: ModelVariantType | None = None
pred: SchedulerPredictionType | None = None
@classmethod
def from_model_config(cls, config: AnyModelConfig) -> "LegacyConfigKey":
variant = getattr(config, "variant", None)
pred = getattr(config, "prediction_type", None)
return cls(type=config.type, base=config.base, variant=variant, pred=pred)
LEGACY_CONFIG_MAP: dict[LegacyConfigKey, str] = {
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion1,
ModelVariantType.Normal,
SchedulerPredictionType.Epsilon,
): "stable-diffusion/v1-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion1,
ModelVariantType.Normal,
SchedulerPredictionType.VPrediction,
): "stable-diffusion/v1-inference-v.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion1,
ModelVariantType.Inpaint,
): "stable-diffusion/v1-inpainting-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Normal,
SchedulerPredictionType.Epsilon,
): "stable-diffusion/v2-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Normal,
SchedulerPredictionType.VPrediction,
): "stable-diffusion/v2-inference-v.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Inpaint,
SchedulerPredictionType.Epsilon,
): "stable-diffusion/v2-inpainting-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Inpaint,
SchedulerPredictionType.VPrediction,
): "stable-diffusion/v2-inpainting-inference-v.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusion2,
ModelVariantType.Depth,
): "stable-diffusion/v2-midas-inference.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusionXL,
ModelVariantType.Normal,
): "stable-diffusion/sd_xl_base.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusionXL,
ModelVariantType.Inpaint,
): "stable-diffusion/sd_xl_inpaint.yaml",
LegacyConfigKey(
ModelType.Main,
BaseModelType.StableDiffusionXLRefiner,
ModelVariantType.Normal,
): "stable-diffusion/sd_xl_refiner.yaml",
LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion1): "controlnet/cldm_v15.yaml",
LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion2): "controlnet/cldm_v21.yaml",
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion1): "stable-diffusion/v1-inference.yaml",
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion2): "stable-diffusion/v2-inference.yaml",
LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusionXL): "stable-diffusion/sd_xl_base.yaml",
}
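
A hypothetical lookup against the map above, building a key from config-like fields and returning None when no legacy YAML applies:

from dataclasses import dataclass

@dataclass(frozen=True)
class Key:
    type: str
    base: str
    variant: str | None = None
    pred: str | None = None

LEGACY = {
    Key("main", "sd-1", "normal", "epsilon"): "stable-diffusion/v1-inference.yaml",
}

def legacy_config_for(type_: str, base: str, variant: str | None = None, pred: str | None = None) -> str | None:
    return LEGACY.get(Key(type_, base, variant, pred))

assert legacy_config_for("main", "sd-1", "normal", "epsilon") == "stable-diffusion/v1-inference.yaml"
assert legacy_config_for("vae", "flux") is None  # no legacy YAML -> None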

View File

@@ -37,16 +37,20 @@ cyberrealistic_negative = StarterModel(
)
# region CLIP Image Encoders
# This is CLIP-ViT-H-14-laion2B-s32B-b79K
ip_adapter_sd_image_encoder = StarterModel(
name="IP Adapter SD1.5 Image Encoder",
base=BaseModelType.StableDiffusion1,
base=BaseModelType.Any,
source="InvokeAI/ip_adapter_sd_image_encoder",
description="IP Adapter SD Image Encoder",
type=ModelType.CLIPVision,
)
# This is CLIP-ViT-bigG-14-laion2B-39B-b160k
ip_adapter_sdxl_image_encoder = StarterModel(
name="IP Adapter SDXL Image Encoder",
base=BaseModelType.StableDiffusionXL,
base=BaseModelType.Any,
source="InvokeAI/ip_adapter_sdxl_image_encoder",
description="IP Adapter SDXL Image Encoder",
type=ModelType.CLIPVision,

View File

@@ -1,38 +1,70 @@
from enum import Enum
from typing import Dict, TypeAlias, Union
import diffusers
import onnxruntime as ort
import torch
from diffusers import ModelMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from pydantic import TypeAdapter
from invokeai.backend.raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
AnyModel: TypeAlias = Union[
ModelMixin,
RawModel,
torch.nn.Module,
Dict[str, torch.Tensor],
DiffusionPipeline,
ort.InferenceSession,
]
"""Type alias for any kind of runtime, in-memory model representation. For example, a torch module or diffusers pipeline."""
class BaseModelType(str, Enum):
"""Base model type."""
"""An enumeration of base model architectures. For example, Stable Diffusion 1.x, Stable Diffusion 2.x, FLUX, etc.
Every model config must have a base architecture type.
Not all models are associated with a base architecture. For example, CLIP models are their own thing, not related
to any particular model architecture. To simplify internal APIs and make it easier to work with models, we use a
fallback/null value `BaseModelType.Any` for these models, instead of making the model base optional."""
Any = "any"
"""`Any` is essentially a fallback/null value for models with no base architecture association.
For example, CLIP models are not related to Stable Diffusion, FLUX, or any other model arch."""
StableDiffusion1 = "sd-1"
"""Indicates the model is associated with the Stable Diffusion 1.x model architecture, including 1.4 and 1.5."""
StableDiffusion2 = "sd-2"
"""Indicates the model is associated with the Stable Diffusion 2.x model architecture, including 2.0 and 2.1."""
StableDiffusion3 = "sd-3"
"""Indicates the model is associated with the Stable Diffusion 3.5 model architecture."""
StableDiffusionXL = "sdxl"
"""Indicates the model is associated with the Stable Diffusion XL model architecture."""
StableDiffusionXLRefiner = "sdxl-refiner"
"""Indicates the model is associated with the Stable Diffusion XL Refiner model architecture."""
Flux = "flux"
"""Indicates the model is associated with FLUX.1 model architecture, including FLUX Dev, Schnell and Fill."""
CogView4 = "cogview4"
"""Indicates the model is associated with CogView 4 model architecture."""
Imagen3 = "imagen3"
"""Indicates the model is associated with Google Imagen 3 model architecture. This is an external API model."""
Imagen4 = "imagen4"
"""Indicates the model is associated with Google Imagen 4 model architecture. This is an external API model."""
Gemini2_5 = "gemini-2.5"
"""Indicates the model is associated with Google Gemini 2.5 Flash Image model architecture. This is an external API model."""
ChatGPT4o = "chatgpt-4o"
"""Indicates the model is associated with OpenAI ChatGPT 4o Image model architecture. This is an external API model."""
FluxKontext = "flux-kontext"
"""Indicates the model is associated with FLUX Kontext model architecture. This is an external API model; local FLUX
Kontext models use the base `Flux`."""
Veo3 = "veo3"
"""Indicates the model is associated with Google Veo 3 video model architecture. This is an external API model."""
Runway = "runway"
"""Indicates the model is associated with Runway video model architecture. This is an external API model."""
Unknown = "unknown"
"""Indicates the model's base architecture is unknown."""
class ModelType(str, Enum):
@@ -55,6 +87,7 @@ class ModelType(str, Enum):
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
Video = "video"
Unknown = "unknown"
class SubModelType(str, Enum):
@@ -90,6 +123,12 @@ class ModelVariantType(str, Enum):
Depth = "depth"
class FluxVariantType(str, Enum):
Schnell = "schnell"
Dev = "dev"
DevFill = "dev_fill"
class ModelFormat(str, Enum):
"""Storage format of model."""
@@ -107,6 +146,7 @@ class ModelFormat(str, Enum):
BnbQuantizednf4b = "bnb_quantized_nf4b"
GGUFQuantized = "gguf_quantized"
Api = "api"
Unknown = "unknown"
class SchedulerPredictionType(str, Enum):
@@ -146,4 +186,7 @@ class FluxLoRAFormat(str, Enum):
AIToolkit = "flux.aitoolkit"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType](
ModelVariantType | ClipVariantType | FluxVariantType
)
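
A sketch of how a pydantic TypeAdapter over a union of string enums behaves, mirroring variant_type_adapter above (the enum names here are abbreviated stand-ins):

from enum import Enum
from pydantic import TypeAdapter

class ModelVariant(str, Enum):
    Normal = "normal"

class FluxVariant(str, Enum):
    Schnell = "schnell"
    Dev = "dev"

adapter = TypeAdapter[ModelVariant | FluxVariant](ModelVariant | FluxVariant)

# Raw strings validate to the matching enum member across the union.
assert adapter.validate_python("schnell") is FluxVariant.Schnell
assert adapter.validate_python("normal") is ModelVariant.Normal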

View File

@@ -8,7 +8,8 @@ from typing import Any, Dict, Optional, Set, Tuple
from PIL import Image
from invokeai.app.util.thumbnails import make_thumbnail
from invokeai.backend.model_manager.config import AnyModelConfig, ModelType
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.taxonomy import ModelType
logger = logging.getLogger(__name__)

View File

@@ -83,14 +83,14 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
return checkpoint
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
def lora_token_vector_length(checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]:
lora_token_vector_length = None
if "." not in key:
@@ -136,6 +136,8 @@ def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[in
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if isinstance(key, int):
continue
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (

View File

@@ -5,10 +5,10 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Generator, Iterator, List, Optional, Tuple, Type, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
@@ -146,7 +146,7 @@ class ModelPatcher:
cls,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
clip_skip: int,
) -> None:
) -> Generator[None, Any, Any]:
skipped_layers = []
try:
for _i in range(clip_skip):
@@ -164,7 +164,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
) -> None:
) -> Generator[None, Any, Any]:
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?

View File

@@ -12,7 +12,10 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util import InvokeAILogger
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
def is_state_dict_likely_in_flux_aitoolkit_format(
state_dict: dict[str | int, Any],
metadata: dict[str, Any] | None = None,
) -> bool:
if metadata:
try:
software = json.loads(metadata.get("software", "{}"))
@@ -20,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], me
return False
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str))
@dataclass
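
A sketch of the metadata probe above: parse the "software" field from the safetensors metadata when present, otherwise fall back to inspecting the root key (the narrow except clause is an assumption; the real code may catch more broadly):

import json

def likely_aitoolkit(state_dict: dict, metadata: dict | None = None) -> bool:
    if metadata:
        try:
            software = json.loads(metadata.get("software", "{}"))
        except json.JSONDecodeError:
            return False
        return software.get("name") == "ai-toolkit"
    # Metadata got lost somewhere: check for the "diffusion_model" root key.
    return any("diffusion_model" == str(k).split(".", 1)[0] for k in state_dict)

assert likely_aitoolkit({}, {"software": '{"name": "ai-toolkit"}'})
assert likely_aitoolkit({"diffusion_model.x": None})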

View File

@@ -18,14 +18,16 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
all_keys_match = all(
re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str)
)
# Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
lora_a_weight = state_dict.get("img_in.lora_A.weight", None)

View File

@@ -9,14 +9,16 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_L
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
all_keys_in_peft_format = all(
k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str)
)
# Check if keys use transformer prefix
transformer_prefix_keys = [

View File

@@ -44,7 +44,7 @@ FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
if isinstance(k, str)
)

View File

@@ -40,7 +40,7 @@ FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = (
)
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
if isinstance(k, str)
)

View File

@@ -1,3 +1,5 @@
from typing import Any
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
@@ -14,7 +16,10 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
def flux_format_from_state_dict(
state_dict: dict[str | int, Any],
metadata: dict[str, Any] | None = None,
) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):

View File

@@ -4,7 +4,8 @@ import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
@@ -22,7 +23,7 @@ def main():
with log_time("Initialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
p = get_flux_transformers_params(ModelVariantType.FluxSchnell)
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():

View File

@@ -7,7 +7,8 @@ import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -35,7 +36,7 @@ def main():
# inference_dtype = torch.bfloat16
with log_time("Initialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
p = get_flux_transformers_params(ModelVariantType.FluxSchnell)
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():

View File

@@ -23,6 +23,7 @@ from diffusers.models.unets.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from torch import nn
from invokeai.backend.model_manager.taxonomy import BaseModelType, SchedulerPredictionType
from invokeai.backend.util.logging import InvokeAILogger
# TODO: create PR to diffusers
@@ -407,7 +408,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
upcast_attention=unet.config.base is BaseModelType.StableDiffusion2
and unet.config.prediction_type is SchedulerPredictionType.VPrediction,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,

View File

@@ -7,7 +7,8 @@ import torch
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@pytest.fixture(scope="session")

View File

@@ -914,6 +914,9 @@
"hfTokenReset": "HF Token Reset",
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"unidentifiedModelTitle": "Unable to identify model",
"unidentifiedModelMessage": "We were unable to identify the type, base and/or format of the installed model. Try editing the model and selecting the appropriate settings for the model.",
"unidentifiedModelMessage2": "If you don't see the correct settings, or the model doesn't work after changing them, ask for help on <DiscordLink /> or create an issue on <GitHubIssuesLink />.",
"imageEncoderModelId": "Image Encoder Model ID",
"installedModelsCount": "{{installed}} of {{total}} models installed.",
"includesNModels": "Includes {{n}} models and their dependencies.",
@@ -942,6 +945,7 @@
"modelConverted": "Model Converted",
"modelDeleted": "Model Deleted",
"modelDeleteFailed": "Failed to delete model",
"modelFormat": "Model Format",
"modelImageDeleted": "Model Image Deleted",
"modelImageDeleteFailed": "Model Image Delete Failed",
"modelImageUpdated": "Model Image Updated",
@@ -949,6 +953,7 @@
"modelManager": "Model Manager",
"modelName": "Model Name",
"modelSettings": "Model Settings",
"modelSettingsWarning": "These settings tell Invoke what kind of model this is and how to load it. If Invoke didn't detect these correctly when you installed the model, or if the model is classified as Unknown, you may need to edit them manually.",
"modelType": "Model Type",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",

View File

@@ -11,8 +11,8 @@ import {
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models';
import { modelSelected } from 'features/parameters/store/actions';
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/parameters/types/constants';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';

View File

@@ -37,7 +37,7 @@ import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isCLIPEmbedModelConfigOrSubmodel,
isControlLayerModelConfig,
isControlNetModelConfig,
isFluxReduxModelConfig,
@@ -48,7 +48,7 @@ import {
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT5EncoderModelConfig,
isT5EncoderModelConfigOrSubmodel,
isVideoModelConfig,
} from 'services/api/types';
import type { JsonObject } from 'type-fest';
@@ -418,7 +418,7 @@ const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) =
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfigOrSubmodel(m));
// If the currently selected model is available, we don't need to do anything
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
@@ -446,7 +446,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfigOrSubmodel(m));
// If the currently selected model is available, we don't need to do anything
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {

Some files were not shown because too many files have changed in this diff