Compare commits

...

27 Commits

Author SHA1 Message Date
Ryan Dick
109cbb8532 Update the default Model Cache behavior to be more conservative with RAM usage. 2025-01-13 18:48:52 +00:00
psychedelicious
d88b59c5c4 Revert "feat(ui): rearrange canvas paste back nodes to save an image step"
This reverts commit 7cdda00a54.
2025-01-10 15:59:29 +11:00
Simon Fuhrmann
1c7adb5c70 Update communityNodes.md - Fix broken image
The image under https://invoke-ai.github.io/InvokeAI/nodes/communityNodes/#stereogram-nodes is broken. Changing img src to fix.
2025-01-09 07:29:02 -05:00
psychedelicious
8da9d3bc19 chore: bump version to v5.6.0rc2 2025-01-09 14:12:46 +11:00
psychedelicious
d9c099bd3a docs: fix incorrect macOS launcher fix command 2025-01-09 11:26:59 +11:00
psychedelicious
a329588e5a feat: add link to low vram guide to OOM toast (local only)
Needed to do a bit of refactoring to support this. Overall, the error toast components are easier to understand now.
2025-01-09 11:20:05 +11:00
psychedelicious
e09cf64779 feat: more updates to first run view 2025-01-09 11:20:05 +11:00
psychedelicious
fc8cf224ca docs: typo 2025-01-09 11:20:05 +11:00
psychedelicious
3e1ed18a1f Update docs/features/low-vram.md
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2025-01-09 11:20:05 +11:00
psychedelicious
9a84c85486 docs: add section about disabling the sysmem fallback 2025-01-09 11:20:05 +11:00
psychedelicious
e6deaa2d2f feat(ui): minor layout tweaks for first run screen 2025-01-09 11:20:05 +11:00
psychedelicious
5246b31347 feat(ui): add low vram link to first run page 2025-01-09 11:20:05 +11:00
psychedelicious
b15dd00840 docs: add docs for low vram mode 2025-01-09 11:20:05 +11:00
psychedelicious
8808c36028 docs: update example yaml file 2025-01-09 11:20:05 +11:00
psychedelicious
89b576f10d fix(ui): prevent canvas & main panel content from scrolling
Hopefully fixes issues where, when run via the launcher, the main panel kinda just scrolls out of bounds.
2025-01-09 09:14:22 +11:00
psychedelicious
d7893a52c3 tweak(ui): whats new copy 2025-01-08 15:26:26 +11:00
Mary Hipp
b9c45c3232 Whats new update 2025-01-08 15:26:26 +11:00
David Burnett
afc9d3b98f more ruff formating 2025-01-07 20:18:19 -05:00
David Burnett
7ddc757bdb ruff format changes 2025-01-07 20:18:19 -05:00
David Burnett
d8da9b45cc Fix for DEIS / DPM clash 2025-01-07 20:18:19 -05:00
Ryan Dick
607d19f4dd We should not trust the value of since the model could be partially-loaded. 2025-01-07 19:22:31 -05:00
psychedelicious
32286f321c docs: note that version is not req for editable install 2025-01-07 17:17:40 -05:00
psychedelicious
03f7bdc9f9 docs: fix manual install rocm pypi indices 2025-01-07 17:17:40 -05:00
Ryan Dick
4df3d0861b Deprecate ram/vram configs for smoother migration path to dynamic limits (#7526)
## Summary

Changes:
- Deprecate `ram` and `vram` configs. If these are set in invokeai.yaml,
they will be ignored.
- Create new `max_cache_ram_gb` and `max_cache_vram_gb` configs with the
same definitions as the old configs.

The main motivation of this change is to make the migration path
smoother for users who had previously added `ram`/`vram` to their
config files. Now, these users will be automatically migrated into the
new dynamic limit behavior (which is better in most cases). These users
will have to manually re-add `max_cache_ram_gb` and `max_cache_vram_gb`
to their configs if they wish to go back to specifying manual limits.
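
For illustration, a sketch of the migration for a config that previously pinned manual limits (the GB values here are hypothetical):

```yaml
# Old keys - now deprecated and ignored if present:
# ram: 16
# vram: 8

# New keys - only needed if you want to opt back out of the dynamic limits:
max_cache_ram_gb: 16
max_cache_vram_gb: 8
```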

## Related Issues / Discussions

See the release notes for RC v5.6.0rc1 for the old migration behavior
that we are trying to improve:
https://github.com/invoke-ai/InvokeAI/releases/tag/v5.6.0rc1

## QA Instructions

- [x] Test that if `ram` or `vram` are present in a user's
`invokeai.yaml`, these values are ignored.
- [x] Test that `max_cache_ram_gb` and `max_cache_vram_gb` are applied,
if set.

## Merge Plan

- Don't forget to update the RC release notes accordingly.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-01-07 17:03:11 -05:00
Ryan Dick
974b4671b1 Deprecate the ram and vram configs to make the migration to dynamic
memory limits smoother for users who had previously overridden these
values.
2025-01-07 16:45:29 +00:00
Ryan Dick
6b18f270dd Bugfix: Offload of GGML-quantized model in torch.inference_mode() cm (#7525)
## Summary

This PR contains a bugfix for an edge case with model unloading (from
VRAM to RAM). Thanks to @JPPhoto for finding it.

The bug was triggered under the following conditions:
- A GGML-quantized model is loaded in VRAM
- We run a Spandrel image-to-image invocation (which is wrapped in a
`torch.inference_mode()` context manager).
- The model cache attempts to unload the GGML-quantized model from VRAM
to RAM.
- Doing this inside the `torch.inference_mode()` context manager results in the
following error:
```
 [2025-01-07 15:48:17,744]::[InvokeAI]::ERROR --> Error while invoking session 98a07259-0c03-4111-a8d8-107041cb86f9, invocation d8daa90b-7e4c-4fc4-807c-50ba9be1a4ed (spandrel_image_to_image): Cannot set version_counter for inference tensor
[2025-01-07 15:48:17,744]::[InvokeAI]::ERROR --> Traceback (most recent call last):
  File "/home/ryan/src/InvokeAI/invokeai/app/services/session_processor/session_processor_default.py", line 129, in run_node
    output = invocation.invoke_internal(context=context, services=self._services)
  File "/home/ryan/src/InvokeAI/invokeai/app/invocations/baseinvocation.py", line 300, in invoke_internal
    output = self.invoke(context)
  File "/home/ryan/.pyenv/versions/3.10.14/envs/InvokeAI_3.10.14/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ryan/src/InvokeAI/invokeai/app/invocations/spandrel_image_to_image.py", line 167, in invoke
    with context.models.load(self.image_to_image_model) as spandrel_model:
  File "/home/ryan/src/InvokeAI/invokeai/backend/model_manager/load/load_base.py", line 60, in __enter__
    self._cache.lock(self._cache_record, None)
  File "/home/ryan/src/InvokeAI/invokeai/backend/model_manager/load/model_cache/model_cache.py", line 224, in lock
    self._load_locked_model(cache_entry, working_mem_bytes)
  File "/home/ryan/src/InvokeAI/invokeai/backend/model_manager/load/model_cache/model_cache.py", line 272, in _load_locked_model
    vram_bytes_freed = self._offload_unlocked_models(model_vram_needed, working_mem_bytes)
  File "/home/ryan/src/InvokeAI/invokeai/backend/model_manager/load/model_cache/model_cache.py", line 458, in _offload_unlocked_models
    cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free)
  File "/home/ryan/src/InvokeAI/invokeai/backend/model_manager/load/model_cache/model_cache.py", line 330, in _move_model_to_ram
    return cache_entry.cached_model.partial_unload_from_vram(
  File "/home/ryan/.pyenv/versions/3.10.14/envs/InvokeAI_3.10.14/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ryan/src/InvokeAI/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py", line 182, in partial_unload_from_vram
    cur_state_dict = self._model.state_dict()
  File "/home/ryan/.pyenv/versions/3.10.14/envs/InvokeAI_3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1939, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/home/ryan/.pyenv/versions/3.10.14/envs/InvokeAI_3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1936, in state_dict
    self._save_to_state_dict(destination, prefix, keep_vars)
  File "/home/ryan/.pyenv/versions/3.10.14/envs/InvokeAI_3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1843, in _save_to_state_dict
    destination[prefix + name] = param if keep_vars else param.detach()
RuntimeError: Cannot set version_counter for inference tensor
```

### Explanation

From the `torch.inference_mode()` docs:
> Code run under this mode gets better performance by disabling view
tracking and version counter bumps.

Disabling version counter bumps results in the aforementioned error when
saving `GGMLTensor`s to a state_dict.

This incompatibility between `GGMLTensor`s and `torch.inference_mode()`
is likely caused by the custom tensor type implementation. There may
very well be a way to get these to cooperate, but for now it is much
simpler to remove the `torch.inference_mode()` contexts.

Note that there are several other uses of `torch.inference_mode()` in
the Invoke codebase, but they are all tight wrappers around the
inference forward pass and do not contain the model load/unload process.
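
As a minimal sketch of the fix (the decorator swap mirrors this PR's diff; the toy model and input are illustrative only):

```python
import torch

# torch.no_grad() still disables gradient tracking for the forward pass, but
# unlike torch.inference_mode() it does not produce "inference tensors", so
# the model cache can later call state_dict() on the module during a
# VRAM -> RAM offload without hitting the version_counter error above.
@torch.no_grad()
def invoke(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    return model(x)

model = torch.nn.Linear(8, 8)
print(invoke(model, torch.randn(1, 8)).shape)  # torch.Size([1, 8])
```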

## Related Issues / Discussions

Original discussion:
https://discord.com/channels/1020123559063990373/1149506274971631688/1326180753159094303

## QA Instructions

Find a sequence of operations that triggers the condition. For me, this
was:
- Reserve VRAM in a separate process so that there was ~12GB left.
- Fresh start of Invoke
- Run FLUX inference with a GGML 8K model
- Run Spandrel upscaling

Tests:
- [x] Confirmed that I can reproduce the error and that it is no longer
hit after the change
- [x] Confirmed that there is no speed regression from switching from
`torch.inference_mode()` to `torch.no_grad()` (a rough sanity check is sketched below).
    - Before: `50.354s`, After: `51.536s`
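
A rough way to compare the two context managers (this micro-benchmark is illustrative, not the benchmark used for the timings above):

```python
import time

import torch

def bench(ctx_factory, iters: int = 100) -> float:
    """Time a forward-only loop under the given autograd-disabling context."""
    model = torch.nn.Linear(2048, 2048)
    x = torch.randn(32, 2048)
    start = time.perf_counter()
    with ctx_factory():
        for _ in range(iters):
            model(x)
    return time.perf_counter() - start

# Both contexts skip autograd bookkeeping; timings should be close.
print(f"inference_mode: {bench(torch.inference_mode):.3f}s")
print(f"no_grad:        {bench(torch.no_grad):.3f}s")
```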


## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-01-07 11:31:20 -05:00
Ryan Dick
85eb4f0312 Fix an edge case with model offloading from VRAM to RAM. If a GGML-quantized model is offloaded from VRAM inside of a torch.inference_mode() context manager, this will cause the following error: 'RuntimeError: Cannot set version_counter for inference tensor'. 2025-01-07 15:59:50 +00:00
31 changed files with 490 additions and 213 deletions


@@ -39,7 +39,7 @@ It has two sections - one for internal use and one for user settings:
```yaml
# Internal metadata - do not edit:
schema_version: 4
schema_version: 4.0.2
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
host: 0.0.0.0 # serve the app on your local network
@@ -83,6 +83,10 @@ A subset of settings may be specified using CLI args:
- `--root`: specify the root directory
- `--config`: override the default `invokeai.yaml` file location
### Low-VRAM Mode
See the [Low-VRAM mode docs][low-vram] for details on enabling this feature.
### All Settings
Following the table are additional explanations for certain settings.
@@ -185,3 +189,4 @@ The `log_format` option provides several alternative formats:
[basic guide to yaml files]: https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/
[Model Marketplace API Keys]: #model-marketplace-api-keys
[low-vram]: ./features/low-vram.md


@@ -22,7 +22,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
- Use `.` instead of `invokeai` to install from the current directory.
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
- Add `-e` after the `install` operation to make this an [editable install][editable install link]. That means your changes to the python code will be reflected when you restart the Invoke server.

(New binary image file added, 72 KiB; not shown.)

docs/features/low-vram.md (new file, 129 lines)
@@ -0,0 +1,129 @@
---
title: Low-VRAM mode
---
As of v5.6.0, Invoke has a low-VRAM mode. It works on systems with dedicated GPUs (Nvidia GPUs on Windows/Linux and AMD GPUs on Linux).
This allows you to generate even if your GPU doesn't have enough VRAM to hold full models. Most users should be able to run even the beefiest models - like the ~24GB unquantised FLUX dev model.
## Enabling Low-VRAM mode
To enable Low-VRAM mode, add this line to your `invokeai.yaml` configuration file, then restart Invoke:
```yaml
enable_partial_loading: true
```
**Windows users should also [disable the Nvidia sysmem fallback](#disabling-nvidia-sysmem-fallback-windows-only)**.
It is possible to fine-tune the settings for best performance or if you still get out-of-memory errors (OOMs).
!!! tip "How to find `invokeai.yaml`"
The `invokeai.yaml` configuration file lives in your install directory. To access it, run the **Invoke Community Edition** launcher and click the install location. This will open your install directory in a file explorer window.
You'll see `invokeai.yaml` there and can edit it with any text editor. After making changes, restart Invoke.
If you don't see `invokeai.yaml`, launch Invoke once. It will create the file on its first startup.
## Details and fine-tuning
Low-VRAM mode involves 3 features, each of which can be configured or fine-tuned:
- Partial model loading
- Dynamic RAM and VRAM cache sizes
- Working memory
Read on to learn about these features and understand how to fine-tune them for your system and use-cases.
### Partial model loading
Invoke's partial model loading works by streaming model "layers" between RAM and VRAM as they are needed.
When an operation needs layers that are not in VRAM, but there isn't enough room to load them, inactive layers are offloaded to RAM to make room.
#### Enabling partial model loading
As described above, you can enable partial model loading by adding this line to `invokeai.yaml`:
```yaml
enable_partial_loading: true
```
### Dynamic RAM and VRAM cache sizes
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
By default, Invoke manages these caches' sizes dynamically for best performance.
#### Fine-tuning cache sizes
Prior to v5.6.0, the cache sizes were static, and for best performance, many users needed to manually fine-tune the `ram` and `vram` settings in `invokeai.yaml`.
As of v5.6.0, the caches are dynamically sized. The `ram` and `vram` settings are no longer used, and new settings are added to configure the cache.
**Most users will not need to fine-tune the cache sizes.**
But, if your GPU has enough VRAM to hold models fully, you might get a perf boost by manually setting the cache sizes in `invokeai.yaml`:
```yaml
# Set the RAM cache size to as large as possible, leaving a few GB free for the rest of your system and Invoke.
# For example, if your system has 32GB RAM, 28GB is a good value.
max_cache_ram_gb: 28
# Set the VRAM cache size to be as large as possible while leaving enough room for the working memory of the tasks you will be doing.
# For example, on a 24GB GPU that will be running unquantized FLUX without any auxiliary models,
# 18GB is a good value.
max_cache_vram_gb: 18
```
!!! tip "Max safe value for `max_cache_vram_gb`"
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
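Putting that together, a sketch of a complete low-VRAM config for a hypothetical 12GB GPU with the default working memory:
```yaml
enable_partial_loading: true
# 12GB of VRAM minus the default 3GB of working memory:
max_cache_vram_gb: 9
```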
### Working memory
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
Invoke reserves 3GB VRAM as working memory by default, which is enough for most use-cases. However, it is possible to fine-tune this setting if you still get OOMs.
#### Fine-tuning working memory
You can increase the working memory size in `invokeai.yaml` to prevent OOMs:
```yaml
# The default is 3GB - bump it up to 4GB to prevent OOMs.
device_working_mem_gb: 4
```
!!! tip "Operations may request more working memory"
For some operations, we can determine VRAM requirements in advance and allocate additional working memory to prevent OOMs.
VAE decoding is one such operation. This operation converts the generation process's output into an image. For large image outputs, this might use more than the default working memory size of 3GB.
During this decoding step, Invoke calculates how much VRAM will be required to decode and requests that much VRAM from the model manager. If the amount exceeds the working memory size, the model manager will offload cached model layers from VRAM until there's enough VRAM to decode.
Once decoding completes, the model manager "reclaims" the extra VRAM allocated as working memory for future model loading operations.
### Disabling Nvidia sysmem fallback (Windows only)
On Windows, Nvidia GPUs are able to use system RAM when their VRAM fills up via **sysmem fallback**. While it sounds like a good idea on the surface, in practice it causes massive slowdowns during generation.
It is strongly suggested to disable this feature:
- Open the **NVIDIA Control Panel** app.
- Expand **3D Settings** on the left panel.
- Click **Manage 3D Settings** in the left panel.
- Find **CUDA - Sysmem Fallback Policy** in the right panel and set it to **Prefer No Sysmem Fallback**.
![cuda-sysmem-fallback](./cuda-sysmem-fallback.png)
!!! tip "Invoke does the same thing, but better"
If the sysmem fallback feature sounds familiar, that's because Invoke's partial model loading strategy is conceptually very similar - use VRAM when there's room, else fall back to RAM.
Unfortunately, the Nvidia implementation is not optimized for applications like Invoke and does more harm than good.


@@ -75,14 +75,14 @@ The following commands vary depending on the version of Invoke being installed a
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm62`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.1`.
- **In all other cases, do not use an index.**
=== "Invoke v4"
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm52`.
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm5.2`.
- **In all other cases, do not use an index.**
8. Install the `invokeai` package. Substitute the package specifier and version.


@@ -54,7 +54,7 @@ If you have an existing Invoke installation, you can select it and let the launc
- Open the **Invoke-Installer-mac-arm64.dmg** file.
- Drag the launcher to **Applications**.
- Open a terminal.
- Run `xattr -cr /Applications/Invoke-Installer.app`.
- Run `xattr -d 'com.apple.quarantine' /Applications/Invoke\ Community\ Edition.app`.
You should now be able to run the launcher.


@@ -535,7 +535,7 @@ View:
**Node Link:** https://github.com/simonfuhrmann/invokeai-stereo
**Example Workflow and Output**
</br><img src="https://github.com/simonfuhrmann/invokeai-stereo/blob/main/docs/example_promo_03.jpg" width="500" />
</br><img src="https://raw.githubusercontent.com/simonfuhrmann/invokeai-stereo/refs/heads/main/docs/example_promo_03.jpg" width="600" />
--------------------------------
### Simple Skin Detection


@@ -10,7 +10,9 @@ import torchvision.transforms as T
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.adapter import T2IAdapter
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image
@@ -89,6 +91,7 @@ def get_scheduler(
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -104,6 +107,10 @@ def get_scheduler(
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
if scheduler_class is DPMSolverMultistepScheduler or scheduler_class is DPMSolverSinglestepScheduler:
if scheduler_config["_class_name"] == "DEISMultistepScheduler" and scheduler_config["algorithm_type"] == "deis":
scheduler_config["algorithm_type"] = "dpmsolver++"
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
@@ -411,6 +418,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
device: torch.device,
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
@@ -452,7 +460,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
device=device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
@@ -605,6 +613,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
device: torch.device,
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
@@ -655,7 +664,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
width=control_width_resize,
height=control_height_resize,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
device=device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
@@ -946,6 +955,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
device = TorchDevice.choose_torch_device()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -960,6 +970,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context,
self.t2i_adapter,
latents.shape,
device=device,
do_classifier_free_guidance=True,
)
@@ -1006,13 +1017,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
noise = noise.to(device=device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype)
mask = mask.to(device=device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
masked_latents = masked_latents.to(device=device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
@@ -1028,7 +1039,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=unet.device,
device=device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
@@ -1041,6 +1052,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
control_input=self.control,
latents_shape=latents.shape,
device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -1058,7 +1070,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,


@@ -276,7 +276,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
ip_adapter_fields = self._normalize_ip_adapter_fields()
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
ip_adapter_fields, context
ip_adapter_fields, context, device=x.device
)
cfg_scale = self.prep_cfg_scale(
@@ -626,6 +626,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
self,
ip_adapter_fields: list[IPAdapterField],
context: InvocationContext,
device: torch.device,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
clip_image_processor = CLIPImageProcessor()
@@ -665,11 +666,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds
clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds
pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)


@@ -26,6 +26,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -98,7 +99,7 @@ class ImageToLatentsInvocation(BaseInvocation):
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)


@@ -16,6 +16,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -39,7 +40,7 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.


@@ -22,6 +22,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
from invokeai.backend.util.devices import TorchDevice
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
@@ -102,7 +103,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=spandrel_model.dtype)
# Run the model on each tile.
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
@@ -116,9 +117,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
input_tile = image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
@@ -151,7 +150,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return pil_image
@torch.inference_mode()
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
@@ -197,7 +196,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@torch.inference_mode()
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.


@@ -201,6 +201,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
yield (lora_info.model, lora.weight)
del lora_info
device = TorchDevice.choose_torch_device()
with (
ExitStack() as exit_stack,
context.models.load(self.unet.unet) as unet,
@@ -209,9 +210,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
noise = noise.to(device=device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@@ -225,7 +226,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=unet.device,
device=device,
dtype=unet.dtype,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
@@ -238,6 +239,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -263,7 +265,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,


@@ -82,12 +82,14 @@ class InvokeAIAppConfig(BaseSettings):
profile_graphs: Enable graph profiling using `cProfile`.
profile_prefix: An optional prefix for profile output files.
profiles_dir: Path to profiles output directory.
ram: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
vram: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behaviour is out of beta.
max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -155,12 +157,15 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
ram: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
vram: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behaviour is out of beta.")
max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")


@@ -84,8 +84,8 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
max_ram_cache_size_gb=app_config.ram,
max_vram_cache_size_gb=app_config.vram,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),
logger=logger,
)


@@ -8,6 +8,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
from invokeai.backend.util.devices import TorchDevice
class XLabsIPAdapterExtension:
@@ -45,7 +46,7 @@ class XLabsIPAdapterExtension:
) -> torch.Tensor:
clip_image_processor = CLIPImageProcessor()
clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
clip_image = clip_image.to(device=TorchDevice.choose_torch_device(), dtype=image_encoder.dtype)
clip_image_embeds = image_encoder(clip_image).image_embeds
return clip_image_embeds


@@ -339,17 +339,16 @@ class ModelCache:
self._delete_cache_entry(cache_entry)
raise
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the cache to use (takes into account the working
memory).
def _get_total_vram_available_to_cache(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the total amount of VRAM available for storing models. I.e. the amount of VRAM available to the
process minus the amount of VRAM to keep for working memory.
"""
# If self._max_vram_cache_size_gb is set, then it overrides the default logic.
if self._max_vram_cache_size_gb is not None:
vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB)
return vram_total_available_to_cache - self._get_vram_in_use()
return int(self._max_vram_cache_size_gb * GB)
working_mem_bytes_default = int(self._execution_device_working_mem_gb * GB)
working_mem_bytes = max(working_mem_bytes or working_mem_bytes_default, working_mem_bytes_default)
working_mem_bytes = max(working_mem_bytes or 0, working_mem_bytes_default)
if self._execution_device.type == "cuda":
# TODO(ryand): It is debatable whether we should use memory_reserved() or memory_allocated() here.
@@ -360,19 +359,28 @@ class ModelCache:
vram_free, _vram_total = torch.cuda.mem_get_info(self._execution_device)
vram_available_to_process = vram_free + vram_allocated
elif self._execution_device.type == "mps":
vram_reserved = torch.mps.driver_allocated_memory()
vram_allocated = torch.mps.driver_allocated_memory()
# TODO(ryand): Is it accurate that MPS shares memory with the CPU?
vram_free = psutil.virtual_memory().available
vram_available_to_process = vram_free + vram_reserved
vram_available_to_process = vram_free + vram_allocated
else:
raise ValueError(f"Unsupported execution device: {self._execution_device.type}")
vram_total_available_to_cache = vram_available_to_process - working_mem_bytes
vram_cur_available_to_cache = vram_total_available_to_cache - self._get_vram_in_use()
return vram_cur_available_to_cache
return vram_available_to_process - working_mem_bytes
def _get_vram_available(self, working_mem_bytes: Optional[int]) -> int:
"""Calculate the amount of additional VRAM available for the model cache to use (takes into account the working
memory).
"""
return self._get_total_vram_available_to_cache(working_mem_bytes) - self._get_vram_in_use()
def _get_vram_in_use(self) -> int:
"""Get the amount of VRAM currently in use by the cache."""
# NOTE(ryand): To be conservative, we are treating the amount of VRAM allocated by torch as entirely being used
# by the model cache. In reality, some of this allocated memory is being used as working memory. This is a
# reasonable conservative assumption, because this function is typically called before (not during)
# working-memory-intensive operations. This conservative definition also helps to handle models whose size
# increased after initial load (e.g. a model whose precision was upcast by application code).
if self._execution_device.type == "cuda":
return torch.cuda.memory_allocated()
elif self._execution_device.type == "mps":
@@ -389,29 +397,71 @@ class ModelCache:
ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB)
return ram_total_available_to_cache - self._get_ram_in_use()
# We have 3 strategies for calculating the amount of RAM available to the cache. We calculate all 3 options and
# then use a heuristic to decide which one to use.
# - Strategy 1: Match RAM cache size to VRAM cache size
# - Strategy 2: Aim to keep at least 10% of RAM free
# - Strategy 3: Use a minimum RAM cache size of 4GB
# ---------------------
# Calculate Strategy 1
# ---------------------
# Under Strategy 1, the RAM cache size is equal to the total VRAM available to the cache. The RAM cache size
# should **roughly** match the VRAM cache size for the following reasons:
# - Setting it much larger than the VRAM cache size means that we would accumulate mmap'ed model files for
# models that are 0% loaded onto the GPU. Accumulating a large amount of virtual memory causes issues -
# particularly on Windows. Instead, we should drop these extra models from the cache and rely on the OS's
# disk caching behavior to make reloading them fast (if there is enough RAM for disk caching to be possible).
# - Setting it much smaller than the VRAM cache size would increase the likelihood that we drop models from the
# cache even if they are partially loaded onto the GPU.
#
# TODO(ryand): In the future, we should re-think this strategy. Setting the RAM cache size like this doesn't
# really make sense, and is done primarily for consistency with legacy behavior. We should be relying on the
# OS's caching behavior more and make decisions about whether to drop models from the cache based primarily on
# how much of the model can be kept in VRAM.
cache_ram_used = self._get_ram_in_use()
if self._execution_device.type == "cpu":
# Strategy 1 is not applicable for CPU.
ram_available_based_on_default_ram_cache_size = 0
else:
default_ram_cache_size_bytes = self._get_total_vram_available_to_cache(None)
ram_available_based_on_default_ram_cache_size = default_ram_cache_size_bytes - cache_ram_used
# ---------------------
# Calculate Strategy 2
# ---------------------
# If RAM memory pressure is high, then we want to be more conservative with the RAM cache size.
virtual_memory = psutil.virtual_memory()
ram_total = virtual_memory.total
ram_available = virtual_memory.available
ram_used = ram_total - ram_available
# The total size of all the models in the cache will often be larger than the amount of RAM reported by psutil
# (due to lazy-loading and OS RAM caching behaviour). We could just rely on the psutil values, but it feels
# like a bad idea to over-fill the model cache. So, for now, we'll try to keep the total size of models in the
# cache under the total amount of system RAM.
cache_ram_used = self._get_ram_in_use()
ram_used = max(cache_ram_used, ram_used)
# Aim to keep 10% of RAM free.
# We aim to keep at least 10% of RAM free.
ram_available_based_on_memory_usage = int(ram_total * 0.9) - ram_used
# If we are running out of RAM, then there's an increased likelihood that we will run into this issue:
# ---------------------
# Calculate Strategy 3
# ---------------------
# If the RAM cache is very small, then there's an increased likelihood that we will run into this issue:
# https://github.com/invoke-ai/InvokeAI/issues/7513
# To keep things running smoothly, there's a minimum RAM cache size that we always allow (even if this means
# using swap).
min_ram_cache_size_bytes = 4 * GB
ram_available_based_on_min_cache_size = min_ram_cache_size_bytes - cache_ram_used
return max(ram_available_based_on_memory_usage, ram_available_based_on_min_cache_size)
# ----------------------------
# Decide which strategy to use
# ----------------------------
# First, take the minimum of strategies 1 and 2.
ram_available = min(ram_available_based_on_default_ram_cache_size, ram_available_based_on_memory_usage)
# Then, apply strategy 3 as the lower bound.
ram_available = max(ram_available, ram_available_based_on_min_cache_size)
self._logger.debug(
f"Calculated RAM available: {ram_available/MB:.2f} MB. Strategies considered (1,2,3): "
f"{ram_available_based_on_default_ram_cache_size/MB:.2f}, "
f"{ram_available_based_on_memory_usage/MB:.2f}, "
f"{ram_available_based_on_min_cache_size/MB:.2f}"
)
return ram_available
def _get_ram_in_use(self) -> int:
"""Get the amount of RAM currently in use."""


@@ -14,6 +14,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.devices import TorchDevice
class ModelPatcher:
@@ -122,7 +123,7 @@ class ModelPatcher:
)
model_embeddings.weight.data[token_id] = embedding.to(
device=text_encoder.device, dtype=text_encoder.dtype
device=TorchDevice.choose_torch_device(), dtype=text_encoder.dtype
)
ti_tokens.append(token_id)


@@ -12,6 +12,7 @@ from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@@ -89,7 +90,7 @@ class T2IAdapterExt(ExtensionBase):
width=input_width,
height=input_height,
num_channels=model.config["in_channels"],
device=model.device,
device=TorchDevice.choose_torch_device(),
dtype=model.dtype,
resize_mode=self._resize_mode,
)


@@ -1185,6 +1185,7 @@
"modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled",
"outOfMemoryError": "Out of Memory Error",
"outOfMemoryErrorDescLocal": "Follow our <LinkComponent>Low VRAM guide</LinkComponent> to reduce OOMs.",
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
"parameters": "Parameters",
"parameterSet": "Parameter Recalled",
@@ -2133,15 +2134,12 @@
"toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio.",
"downloadStarterModels": "Download Starter Models",
"importModels": "Import Models",
"noModelsInstalled": "It looks like you don't have any models installed"
"lowVRAMMode": "For best performance, follow our <LinkComponent>Low VRAM guide</LinkComponent>.",
"noModelsInstalled": "It looks like you don't have any models installed! You can <DownloadStarterModelsButton>download a starter model bundle</DownloadStarterModelsButton> or <ImportModelsButton>import models</ImportModelsButton>."
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"<StrongComponent>Flux Control Layers</StrongComponent>: New control models for edge detection and depth mapping are now supported for Flux dev models."
],
"items": ["Low-VRAM mode", "Dynamic memory management", "Faster model loading times", "Fewer memory errors"],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"


@@ -57,6 +57,7 @@ export const CanvasMainPanelContent = memo(() => {
gap={2}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
<CanvasManagerProviderGate>
<CanvasToolbar />
@@ -70,6 +71,7 @@ export const CanvasMainPanelContent = memo(() => {
h="full"
bg={dynamicGrid ? 'base.850' : 'base.900'}
borderRadius="base"
overflow="hidden"
>
<InvokeCanvasComponent />
<CanvasManagerProviderGate>


@@ -46,6 +46,7 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
left={0}
alignItems="center"
justifyContent="center"
overflow="hidden"
>
{hasImageToCompare && <CompareToolbar />}
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}


@@ -1,4 +1,5 @@
import { Button, Divider, Flex, Spinner, Text } from '@invoke-ai/ui-library';
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon';
@@ -7,9 +8,10 @@ import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallMode
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectIsLocal } from 'features/system/store/configSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiImageBold } from 'react-icons/pi';
import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi';
import { useMainModels } from 'services/api/hooks/modelsByType';
export const NoContentForViewer = memo(() => {
@@ -18,6 +20,105 @@ export const NoContentForViewer = memo(() => {
const isLocal = useAppSelector(selectIsLocal);
const isEnabled = useFeatureStatus('starterModels');
const { t } = useTranslation();
const showStarterBundles = useMemo(() => {
return isEnabled && data && mainModels.length === 0;
}, [mainModels.length, data, isEnabled]);
if (hasImages === LOADING_SYMBOL) {
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
// fetching.
return <LoadingSpinner />;
}
if (hasImages) {
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
}
return (
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center" maxW="600px">
<InvokeLogoIcon w={32} h={32} />
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center">
{isLocal ? <GetStartedLocal /> : <GetStartedCommercial />}
{showStarterBundles && <StarterBundlesCallout />}
<Divider />
<GettingStartedVideosCallout />
{isLocal && <LowVRAMAlert />}
</Flex>
</Flex>
);
});
NoContentForViewer.displayName = 'NoContentForViewer';
const LoadingSpinner = () => {
return (
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
</Flex>
);
};
export const ExternalLink = (props: ButtonProps & { href: string }) => {
return (
<Button
as={Link}
variant="unstyled"
isExternal
display="inline-flex"
alignItems="center"
rightIcon={<PiArrowSquareOutBold />}
color="base.50"
mt={-1}
{...props}
/>
);
};
const InlineButton = (props: PropsWithChildren<{ onClick: () => void }>) => {
return (
<Button variant="link" size="md" onClick={props.onClick} color="base.50">
{props.children}
</Button>
);
};
const StrongComponent = <Text as="span" color="base.50" fontSize="md" />;
const GetStartedLocal = () => {
return (
<Text fontSize="md" color="base.200">
<Trans i18nKey="newUserExperience.toGetStartedLocal" components={{ StrongComponent }} />
</Text>
);
};
const GetStartedCommercial = () => {
return (
<Text fontSize="md" color="base.200">
<Trans i18nKey="newUserExperience.toGetStarted" components={{ StrongComponent }} />
</Text>
);
};
const GettingStartedVideosCallout = () => {
return (
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.gettingStartedSeries"
components={{
LinkComponent: (
<ExternalLink href="https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO" />
),
}}
/>
</Text>
);
};
const StarterBundlesCallout = () => {
const dispatch = useAppDispatch();
const handleClickDownloadStarterModels = useCallback(() => {
@@ -30,89 +131,31 @@ export const NoContentForViewer = memo(() => {
$installModelsTab.set(0);
}, [dispatch]);
const showStarterBundles = useMemo(() => {
return isEnabled && data && mainModels.length === 0;
}, [mainModels.length, data, isEnabled]);
if (hasImages === LOADING_SYMBOL) {
return (
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
// If we show the logo while loading, there is an awkward layout shift where the invoke logo moves a bit. Less
// jarring to show a blank bg with a spinner - it will only be shown for a moment as we do the initial images
// fetching.
<Flex position="relative" width="full" height="full" alignItems="center" justifyContent="center">
<Spinner label="Loading" color="grey" position="absolute" size="sm" width={8} height={8} right={4} bottom={4} />
</Flex>
);
}
if (hasImages) {
return <IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />;
}
return (
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center" maxW="600px">
<InvokeLogoIcon w={40} h={40} />
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center">
<Text fontSize="md" color="base.200" pt={16}>
{isLocal ? (
<Trans
i18nKey="newUserExperience.toGetStartedLocal"
components={{
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
}}
/>
) : (
<Trans
i18nKey="newUserExperience.toGetStarted"
components={{
StrongComponent: <Text as="span" color="white" fontSize="md" fontWeight="semibold" />,
}}
/>
)}
</Text>
{showStarterBundles && (
<Flex flexDir="column" gap={2} alignItems="center">
<Text fontSize="md" color="base.200">
{t('newUserExperience.noModelsInstalled')}
</Text>
<Flex gap={3} alignItems="center">
<Button size="sm" onClick={handleClickDownloadStarterModels}>
{t('newUserExperience.downloadStarterModels')}
</Button>
<Text fontSize="sm" color="base.200">
{t('common.or')}
</Text>
<Button size="sm" onClick={handleClickImportModels}>
{t('newUserExperience.importModels')}
</Button>
</Flex>
</Flex>
)}
<Divider />
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.gettingStartedSeries"
components={{
LinkComponent: (
<Text
as="a"
color="white"
fontSize="md"
fontWeight="semibold"
href="https://www.youtube.com/playlist?list=PLvWK1Kc8iXGrQy8r9TYg6QdUuJ5MMx-ZO"
target="_blank"
/>
),
}}
/>
</Text>
</Flex>
</Flex>
<Text fontSize="md" color="base.200">
<Trans
i18nKey="newUserExperience.noModelsInstalled"
components={{
DownloadStarterModelsButton: <InlineButton onClick={handleClickDownloadStarterModels} />,
ImportModelsButton: <InlineButton onClick={handleClickImportModels} />,
}}
/>
</Text>
);
});
};
NoContentForViewer.displayName = 'NoContentForViewer';
const LowVRAMAlert = () => {
return (
<Alert status="warning" borderRadius="base" fontSize="md" shadow="md" w="fit-content">
<AlertIcon />
<AlertDescription>
<Trans
i18nKey="newUserExperience.lowVRAMMode"
components={{
LinkComponent: <ExternalLink href="https://invoke-ai.github.io/InvokeAI/features/low-vram/" />,
}}
/>
</AlertDescription>
</Alert>
);
};


@@ -86,6 +86,16 @@ export const addInpaint = async ({
type: 'img_resize',
...scaledSize,
});
const resizeImageToOriginalSize = g.addNode({
id: getPrefixedId('resize_image_to_original_size'),
type: 'img_resize',
...originalSize,
});
const resizeMaskToOriginalSize = g.addNode({
id: getPrefixedId('resize_mask_to_original_size'),
type: 'img_resize',
...originalSize,
});
const createGradientMask = g.addNode({
id: getPrefixedId('create_gradient_mask'),
type: 'create_gradient_mask',
@@ -99,11 +109,6 @@ export const addInpaint = async ({
type: 'canvas_v2_mask_and_crop',
mask_blur: params.maskBlur,
});
const resizeOutput = g.addNode({
id: getPrefixedId('resize_output'),
type: 'img_resize',
...originalSize,
});
// Resize initial image and mask to scaled size, feed into the gradient mask
g.addEdge(alphaToMask, 'image', resizeMaskToScaledSize, 'image');
@@ -120,20 +125,21 @@ export const addInpaint = async ({
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
// Paste the generated masked image back onto the original image
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
// After denoising, resize the image and mask back to original size
g.addEdge(l2i, 'image', resizeImageToOriginalSize, 'image');
g.addEdge(createGradientMask, 'expanded_mask_area', resizeMaskToOriginalSize, 'image');
// Finally, resize the output back to the original size
g.addEdge(canvasPasteBack, 'image', resizeOutput, 'image');
// Finally, paste the generated masked image back onto the original image
g.addEdge(resizeImageToOriginalSize, 'image', canvasPasteBack, 'generated_image');
g.addEdge(resizeMaskToOriginalSize, 'image', canvasPasteBack, 'mask');
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
// to canvas but not outputting only masked regions
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
g.addEdge(resizeImageToScaledSize, 'image', canvasPasteBack, 'source_image');
canvasPasteBack.source_image = { image_name: initialImage.image_name };
}
return resizeOutput;
return canvasPasteBack;
} else {
// No scale before processing, much simpler
const i2l = g.addNode({
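Net effect of the revert on the scaled inpaint path, sketched as plain data (node IDs are illustrative; the real builder uses getPrefixedId and typed graph nodes). The outpaint builder below is rewired the same way:

type Edge = { from: [node: string, field: string]; to: [node: string, field: string] };

// After denoising, image and mask are resized back to the original size
// first, and the paste-back node becomes the graph's final output again.
const restoredFlow: Edge[] = [
  { from: ['l2i', 'image'], to: ['resize_image_to_original_size', 'image'] },
  { from: ['create_gradient_mask', 'expanded_mask_area'], to: ['resize_mask_to_original_size', 'image'] },
  { from: ['resize_image_to_original_size', 'image'], to: ['canvas_v2_mask_and_crop', 'generated_image'] },
  { from: ['resize_mask_to_original_size', 'image'], to: ['canvas_v2_mask_and_crop', 'mask'] },
];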


@@ -131,33 +131,40 @@ export const addOutpaint = async ({
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(i2l, 'latents', denoise, 'latents');
// Resize the output image back to the original size
const resizeOutputImageToOriginalSize = g.addNode({
id: getPrefixedId('resize_image_to_original_size'),
type: 'img_resize',
...originalSize,
});
const resizeOutputMaskToOriginalSize = g.addNode({
id: getPrefixedId('resize_mask_to_original_size'),
type: 'img_resize',
...originalSize,
});
const canvasPasteBack = g.addNode({
id: getPrefixedId('canvas_v2_mask_and_crop'),
type: 'canvas_v2_mask_and_crop',
mask_blur: params.maskBlur,
});
const resizeOutput = g.addNode({
id: getPrefixedId('resize_output'),
type: 'img_resize',
...originalSize,
});
// Resize initial image and mask to scaled size, feed into the gradient mask
// Paste the generated masked image back onto the original image
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
// After denoising, resize the image and mask back to original size
g.addEdge(l2i, 'image', resizeOutputImageToOriginalSize, 'image');
g.addEdge(createGradientMask, 'expanded_mask_area', resizeOutputMaskToOriginalSize, 'image');
// Finally, resize the output back to the original size
g.addEdge(canvasPasteBack, 'image', resizeOutput, 'image');
// Finally, paste the generated masked image back onto the original image
g.addEdge(resizeOutputImageToOriginalSize, 'image', canvasPasteBack, 'generated_image');
g.addEdge(resizeOutputMaskToOriginalSize, 'image', canvasPasteBack, 'mask');
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
// to canvas but not outputting only masked regions
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
g.addEdge(resizeInputImageToScaledSize, 'image', canvasPasteBack, 'source_image');
canvasPasteBack.source_image = { image_name: initialImage.image_name };
}
return resizeOutput;
return canvasPasteBack;
} else {
infill.image = { image_name: initialImage.image_name };
// No scale before processing, much simpler


@@ -1,39 +1,45 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { t } from 'i18next';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { ExternalLink } from 'features/gallery/components/ImageViewer/NoContentForViewer';
import { useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
function onCopy(sessionId: string) {
navigator.clipboard.writeText(sessionId);
}
const ERROR_TYPE_TO_TITLE: Record<string, string> = {
OutOfMemoryError: 'toast.outOfMemoryError',
};
const COMMERCIAL_ERROR_TYPE_TO_DESC: Record<string, string> = {
OutOfMemoryError: 'toast.outOfMemoryErrorDesc',
};
export const getTitleFromErrorType = (errorType: string) => {
return t(ERROR_TYPE_TO_TITLE[errorType] ?? 'toast.serverError');
};
type Props = { errorType: string; errorMessage?: string | null; sessionId: string; isLocal: boolean };
export default function ErrorToastDescription({ errorType, errorMessage, sessionId, isLocal }: Props) {
export const ErrorToastTitle = ({ errorType }: Props) => {
const { t } = useTranslation();
if (errorType === 'OutOfMemoryError') {
return t('toast.outOfMemoryError');
}
return t('toast.serverError');
};
export default function ErrorToastDescription({ errorType, isLocal, sessionId, errorMessage }: Props) {
const { t } = useTranslation();
const description = useMemo(() => {
// Special handling for commercial error types
const descriptionTKey = isLocal ? null : COMMERCIAL_ERROR_TYPE_TO_DESC[errorType];
if (descriptionTKey) {
return t(descriptionTKey);
}
if (errorMessage) {
if (errorType === 'OutOfMemoryError') {
if (isLocal) {
return (
<Trans
i18nKey="toast.outOfMemoryErrorDescLocal"
components={{
LinkComponent: <ExternalLink href="https://invoke-ai.github.io/InvokeAI/features/low-vram/" />,
}}
/>
);
} else {
return t('toast.outOfMemoryErrorDesc');
}
} else if (errorMessage) {
return `${errorType}: ${errorMessage}`;
}
}, [errorMessage, errorType, isLocal, t]);
const copySessionId = useCallback(() => navigator.clipboard.writeText(sessionId), [sessionId]);
return (
<Flex flexDir="column">
{description && (
@@ -50,14 +56,12 @@ export default function ErrorToastDescription({ errorType, errorMessage, session
size="sm"
aria-label="Copy"
icon={<PiCopyBold />}
onClick={onCopy.bind(null, sessionId)}
onClick={copySessionId}
variant="ghost"
sx={sx}
sx={{ svg: { fill: 'base.50' } }}
/>
</Flex>
)}
</Flex>
);
}
const sx = { svg: { fill: 'base.50' } };
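The handler change above swaps onCopy.bind(null, sessionId) for a memoized callback: bind produces a new function identity on every render, while useCallback returns the same reference until sessionId changes, so memoized children don't re-render needlessly. A minimal sketch of the pattern, with the hook name invented for illustration:

import { useCallback } from 'react';

// Stable click handler: keeps the same reference across renders while
// `sessionId` is unchanged.
const useCopySessionId = (sessionId: string) =>
  useCallback(() => navigator.clipboard.writeText(sessionId), [sessionId]);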


@@ -117,7 +117,7 @@ export const AppContent = memo(() => {
});
return (
<Flex id="invoke-app-tabs" w="full" h="full" gap={4} p={4}>
<Flex id="invoke-app-tabs" w="full" h="full" gap={4} p={4} overflow="hidden">
<VerticalNavBar />
<PanelGroup
ref={imperativePanelGroupRef}


@@ -8,7 +8,7 @@ import type { AppStore } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import ErrorToastDescription, { ErrorToastTitle } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach, isNil, round } from 'lodash-es';
@@ -400,7 +400,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
title: (
<ErrorToastTitle
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
status: 'error',
duration: null,
updateDescription: isLocal,


@@ -1 +1 @@
__version__ = "5.6.0rc1"
__version__ = "5.6.0rc2"


@@ -137,6 +137,7 @@ nav:
- Invocation API: 'nodes/invocation-api.md'
- Configuration: 'configuration.md'
- Features:
- Low VRAM mode: 'features/low-vram.md'
- Database: 'features/database.md'
- New to InvokeAI?: 'help/gettingStartedWithAI.md'
- Contributing:


@@ -94,8 +94,8 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
ram_cache = ModelCache(
execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb,
enable_partial_loading=mm2_app_config.enable_partial_loading,
max_ram_cache_size_gb=mm2_app_config.ram,
max_vram_cache_size_gb=mm2_app_config.vram,
max_ram_cache_size_gb=mm2_app_config.max_cache_ram_gb,
max_vram_cache_size_gb=mm2_app_config.max_cache_vram_gb,
execution_device=TorchDevice.choose_torch_device(),
logger=InvokeAILogger.get_logger(),
)