Get probing of FLUX LoRA kohya models working.

This commit is contained in:
Ryan Dick
2024-09-04 16:03:55 +00:00
parent e4cca62a90
commit ceb5d50568
2 changed files with 14 additions and 3 deletions

View File

@@ -20,6 +20,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.peft.conversions.flux_lora_conversion_utils import convert_flux_kohya_state_dict_to_invoke_format
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.peft.lora import LoRAModelRaw
@@ -56,9 +57,16 @@ class LoRALoader(ModelLoader):
else:
state_dict = torch.load(model_path, map_location="cpu")
# TODO(ryand): Add conversions for other base models and raise if an unsupported base model is used.
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
elif self._model_base == BaseModelType.Flux:
state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
pass
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
return LoRAModelRaw.from_state_dict(state_dict=state_dict, dtype=self._torch_dtype)

View File

@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.peft.conversions.flux_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -557,9 +558,11 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return ModelFormat("lycoris")
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint):
return BaseModelType.Flux
# If we've gotten here, we assume that the model is a Stable Diffusion model.
token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024: