Compare commits

...

7 Commits

Author SHA1 Message Date
Ryan Dick
dcf11a01ce Add TODO comment about peformance bottleneck in LoRA loading code. 2024-04-04 11:33:27 -04:00
Ryan Dick
6e4de001f1 Remove line that was intended to save memory, but wasn't actually having any effect. 2024-04-04 11:29:32 -04:00
Ryan Dick
4af258615f Improve the robustness of the logic for determining the PEFT model type. In particular, so that it doesn't incorrectly detect DoRA models as LoRA models. 2024-04-04 11:10:09 -04:00
psychedelicious
132aadca15 fix(ui): cancel batch status button greyed out
Closes #6110
2024-04-03 08:23:31 -04:00
blessedcoolant
8584171a49 docs: fix broken link (#6116)
## Summary

Fix a broken link

## Related Issues / Discussions


https://discord.com/channels/1020123559063990373/1049495067846524939/1224970148058763376

## QA Instructions

n/a

## Merge Plan

n/a

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_ n/a
- [x] _Documentation added / updated (if applicable)_
2024-04-03 12:35:17 +05:30
psychedelicious
50951439bd docs: fix broken link 2024-04-03 17:36:15 +11:00
psychedelicious
7b93b554d7 fix(ui): add default coherence mode to generation slice migration
The valid values for this parameter changed when inpainting changed to gradient denoise. The generation slice's redux migration wasn't updated, resulting in a generation error until you change the setting or reset web UI.
2024-04-03 08:46:31 +11:00
4 changed files with 64 additions and 30 deletions

View File

@@ -32,5 +32,5 @@ As described in the [frontend dev toolchain] docs, you can run the UI using a de
[Fork and clone]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo
[InvokeAI repo]: https://github.com/invoke-ai/InvokeAI
[frontend dev toolchain]: ../contributing/frontend/OVERVIEW.md
[manual installation]: installation/020_INSTALL_MANUAL.md
[manual installation]: ./020_INSTALL_MANUAL.md
[editable install]: https://pip.pypa.io/en/latest/cli/pip_install/#cmdoption-e

View File

@@ -3,7 +3,7 @@
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from safetensors.torch import load_file
@@ -457,6 +457,55 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
return new_state_dict
@classmethod
def _keys_match(cls, keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
"""Check if the set of keys matches the required and optional keys."""
if len(required_keys - keys) > 0:
# missing required keys.
return False
non_required_keys = keys - required_keys
for k in non_required_keys:
if k not in optional_keys:
# unexpected key
return False
return True
@classmethod
def get_layer_type_from_state_dict_keys(cls, peft_layer_keys: set[str]) -> Type[AnyLoRALayer]:
"""Infer the parameter-efficient finetuning model type from the state dict keys."""
common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}
if cls._keys_match(
peft_layer_keys,
required_keys={"lora_down.weight", "lora_up.weight"},
optional_keys=common_optional_keys | {"lora_mid.weight"},
):
return LoRALayer
if cls._keys_match(
peft_layer_keys,
required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
):
return LoHALayer
if cls._keys_match(
peft_layer_keys,
required_keys=set(),
optional_keys=common_optional_keys
| {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
):
return LoKRLayer
if cls._keys_match(peft_layer_keys, required_keys={"diff"}, optional_keys=common_optional_keys):
return FullLayer
if cls._keys_match(peft_layer_keys, required_keys={"weight", "on_input"}, optional_keys=common_optional_keys):
return IA3Layer
raise ValueError(f"Unsupported PEFT model type with keys: {peft_layer_keys}")
@classmethod
def from_checkpoint(
cls,
@@ -486,37 +535,21 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
# We assume that all layers have the same PEFT layer type. This saves time by not having to infer the type for
# each layer.
first_module_key = next(iter(state_dict))
peft_layer_keys = set(state_dict[first_module_key].keys())
layer_cls = cls.get_layer_type_from_state_dict_keys(peft_layer_keys)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer = layer_cls(layer_key, values)
# TODO(ryand): This .to() call causes an implicit CUDA sync point in a tight loop. This is very slow (even
# slower than loading the weights from disk). We should ideally only be copying the weights once - right
# before they are used. Or, if we want to do this here, then setting non_blocking = True would probably
# help.
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod

View File

@@ -280,6 +280,7 @@ const migrateGenerationState = (state: any): GenerationState => {
// The signature of the model has changed, so we need to reset it
state._version = 2;
state.model = null;
state.canvasCoherenceMode = initialGenerationState.canvasCoherenceMode;
}
return state;
};

View File

@@ -192,7 +192,7 @@ export const queueApi = api.injectEndpoints({
{ batch_id: string }
>({
query: ({ batch_id }) => ({
url: buildQueueUrl(`/b/${batch_id}/status`),
url: buildQueueUrl(`b/${batch_id}/status`),
method: 'GET',
}),
providesTags: (result) => {