From ac08c31fbcb827668916f3fc73827880d18306cf Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 8 Oct 2024 11:11:22 -0400 Subject: [PATCH] Remove unnecessary hasattr checks for scaled_dot_product_attention. We pin the torch version, so there should be no concern that this function does not exist. --- .../stable_diffusion/diffusers_pipeline.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 311a44c2a1..8a16f90577 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -198,11 +198,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self.disable_attention_slicing() return elif config.attention_type == "torch-sdp": - if hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return - else: - raise Exception("torch-sdp attention slicing not available") + # torch-sdp is the default in diffusers. + return # See https://github.com/invoke-ai/InvokeAI/issues/7049 for context. # Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results. @@ -210,17 +207,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory # utilization. if torch.backends.mps.is_available(): - assert hasattr(torch.nn.functional, "scaled_dot_product_attention") return - # the remainder if this code is called when attention_type=='auto' + # The remainder if this code is called when attention_type=='auto'. if self.unet.device.type == "cuda": if is_xformers_available() and prefer_xformers: self.enable_xformers_memory_efficient_attention() return - elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return + # torch-sdp is the default in diffusers. + return if self.unet.device.type == "cpu" or self.unet.device.type == "mps": mem_free = psutil.virtual_memory().free