more wip sliced attention (.swap doesn't work yet)

This commit is contained in:
Damian Stewart
2023-01-25 14:51:08 +01:00
parent 63c6019f92
commit a4aea1540b
3 changed files with 75 additions and 49 deletions

View File

@@ -306,6 +306,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if is_xformers_available() and not Globals.disable_xformers:
self.enable_xformers_memory_efficient_attention()
else:
slice_size = 2
self.enable_attention_slicing(slice_size=slice_size)
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
@@ -370,43 +373,40 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if additional_guidance is None:
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
step_count=len(self.scheduler.timesteps))
else:
self.invokeai_diffuser.remove_cross_attention_control()
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
do_attention_map_saving=False):
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
latents=latents)
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
latents=latents)
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
latents = self.scheduler.add_noise(latents, noise, batched_t)
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
self.invokeai_diffuser.remove_attention_map_saving()
attention_map_saver: Optional[AttentionMapSaver] = None
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(batched_t, latents, conditioning_data,
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample
predicted_original = getattr(step_output, 'pred_original_sample', None)
if i == len(timesteps)-1 and extra_conditioning_info is not None:
eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
attention_map_token_ids = range(1, eos_token_index)
attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
# TODO resuscitate attention map saving
#if i == len(timesteps)-1 and extra_conditioning_info is not None:
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
# attention_map_token_ids = range(1, eos_token_index)
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
predicted_original=predicted_original, attention_map_saver=attention_map_saver)
self.invokeai_diffuser.remove_attention_map_saving()
return latents, attention_map_saver
return latents, attention_map_saver
@torch.inference_mode()
def step(self, t: torch.Tensor, latents: torch.Tensor,