mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-07 02:14:56 -05:00
more wip sliced attention (.swap doesn't work yet)
This commit is contained in:
@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
if is_xformers_available() and not Globals.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
slice_size = 2
|
||||
self.enable_attention_slicing(slice_size=slice_size)
|
||||
|
||||
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
@@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps))
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
do_attention_map_saving=False):
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||
latents=latents)
|
||||
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||
latents=latents)
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = torch.full((batch_size,), timesteps[0],
|
||||
dtype=timesteps.dtype, device=self.unet.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
batch_size = latents.shape[0]
|
||||
batched_t = torch.full((batch_size,), timesteps[0],
|
||||
dtype=timesteps.dtype, device=self.unet.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
self.invokeai_diffuser.remove_attention_map_saving()
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
|
||||
if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
attention_map_token_ids = range(1, eos_token_index)
|
||||
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
# TODO resuscitate attention map saving
|
||||
#if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
# attention_map_token_ids = range(1, eos_token_index)
|
||||
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
|
||||
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
|
||||
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
|
||||
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
|
||||
|
||||
self.invokeai_diffuser.remove_attention_map_saving()
|
||||
return latents, attention_map_saver
|
||||
return latents, attention_map_saver
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user