Fix #1362 by improving VRAM usage patterns when doing .swap()

commit ef3f7a26e242b73c2beb0195c7fd8f654ef47f55
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 12:18:37 2022 +0100

    remove log spam

commit 7189d649622d4668b120b0dd278388ad672142c4
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 12:10:28 2022 +0100

    change the way saved slicing strategy is applied

commit 01c40f751ab72955140165c16f95ae411732265b
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 12:04:43 2022 +0100

    fix slicing_strategy_getter callsite

commit f8cfe25150a346958903316bc710737d99839923
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 11:56:22 2022 +0100

    cleanup, consistent dim=0 also tested

commit 5bf9b1e890d48e962afd4a668a219b68271e5dc1
Author: damian0815 <null@damianstewart.com>
Date:   Tue Nov 8 11:34:09 2022 +0100

    refactored context, tested with non-sliced cross attention control

commit d58a46e39bf562e7459290d2444256e8c08ad0b6
Author: damian0815 <null@damianstewart.com>
Date:   Sun Nov 6 00:41:52 2022 +0100

    cleanup

commit 7e2c658b4c06fe239311b65b9bb16fa3adec7fd7
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:57:31 2022 +0100

    disable logs

commit 20ee89d93841b070738b3d8a4385c93b097d92eb
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:36:58 2022 +0100

    slice saved attention if necessary

commit 0a7684a22c880ec0f48cc22bfed4526358f71546
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:32:38 2022 +0100

    raise instead of asserting

commit 7083104c7f3a0d8fd96e94a2f391de50a3c942e4
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:31:00 2022 +0100

    store dim when saving slices

commit f7c0808ed383ec1dc70645288a798ed2aa4fa85c
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:27:16 2022 +0100

    don't retry on exception

commit 749a721e939b3fe7c1741e7998dab6bd2c85a0cb
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:24:50 2022 +0100

    stuff

commit 032ab90e9533be8726301ec91b97137e2aadef9a
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:20:17 2022 +0100

    more logging

commit 3dc34b387f033482305360e605809d95a40bf6f8
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:16:47 2022 +0100

    logs

commit 901c4c1aa4b9bcef695a6551867ec8149e6e6a93
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:12:39 2022 +0100

    actually set save_slicing_strategy to True

commit f780e0a0a7c6b6a3db320891064da82589358c8a
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 22:10:35 2022 +0100

    store slicing strategy

commit 93bb6d566fd18c5c69ef7dacc8f74ba2cf671cb7
Author: damian <git@damianstewart.com>
Date:   Sat Nov 5 20:43:48 2022 +0100

    still not it

commit 5e3a9541f8ae00bde524046963910323e20c40b7
Author: damian <git@damianstewart.com>
Date:   Sat Nov 5 17:20:02 2022 +0100

    wip offloading attention slices on-demand

commit 4c2966aa856b6f3b446216da3619ae931552ef08
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 15:47:40 2022 +0100

    pre-emptive offloading, idk if it works

commit 572576755e9f0a878d38e8173e485126c0efbefb
Author: root <you@example.com>
Date:   Sat Nov 5 11:25:32 2022 +0000

    push attention slices to cpu. slow but saves memory.

commit b57c83a68f2ac03976ebc89ce2ff03812d6d185f
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 12:04:22 2022 +0100

    verbose logging

commit 3a5dae116f110a96585d9eb71d713b5ed2bc3d2b
Author: damian0815 <null@damianstewart.com>
Date:   Sat Nov 5 11:50:48 2022 +0100

    wip fixing mem strategy crash (4 test on runpod)

commit 3cf237db5fae0c7b0b4cc3c47c81830bdb2ae7de
Author: damian0815 <null@damianstewart.com>
Date:   Fri Nov 4 09:02:40 2022 +0100

    wip, only works on cuda
This commit is contained in:
damian0815
2022-11-08 12:59:34 +01:00
committed by Lincoln Stein
parent 5702271991
commit 71bbfe4a1a
3 changed files with 234 additions and 143 deletions

View File

@@ -1,9 +1,11 @@
import traceback
from math import ceil
from typing import Callable, Optional, Union
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
from ldm.modules.attention import get_mem_free_total
class InvokeAIDiffuserComponent:
@@ -34,7 +36,7 @@ class InvokeAIDiffuserComponent:
"""
self.model = model
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
@@ -42,11 +44,7 @@ class InvokeAIDiffuserComponent:
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model,
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count
CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context)
def remove_cross_attention_control(self):
self.conditioning = None
@@ -54,6 +52,7 @@ class InvokeAIDiffuserComponent:
CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor,dict],
conditioning: Union[torch.Tensor,dict],
@@ -70,12 +69,12 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = []
context: CrossAttentionControl.Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
wants_hybrid_conditioning = isinstance(conditioning, dict)
@@ -124,7 +123,7 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
@@ -134,31 +133,29 @@ class InvokeAIDiffuserComponent:
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
context:CrossAttentionControl.Context = self.cross_attention_control_context
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_save_attention_maps(self.model, type)
#print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False)
context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
#print("applying saved attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_apply_saved_attention_maps(ca_type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
CrossAttentionControl.clear_requests(self.model)
finally:
context.clear_requests(cleanup=True)
return unconditioned_next_x, conditioned_next_x
except RuntimeError:
# make sure we clean out the attention slices we're storing on the model
# TODO don't store things on the model
CrossAttentionControl.clear_requests(self.model)
raise
return unconditioned_next_x, conditioned_next_x
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None: