Add attention slicing support (#1087)

This commit is contained in:
gpetters94
2023-02-24 05:43:02 -05:00
committed by GitHub
parent e7eb116bd2
commit 694b1d43a8
2 changed files with 12 additions and 0 deletions

View File

@@ -226,6 +226,11 @@ class SharkifyStableDiffusionModel:
)
self.in_channels = self.unet.in_channels
self.train(False)
# Apply attention slicing to the UNet when requested via the
# --attention_slicing CLI flag (reduces peak memory at some speed cost).
# The flag arrives as a string: a digit string is converted to an int
# slice size, otherwise the string mode (e.g. "auto"/"max") is passed
# through unchanged — presumably matching diffusers'
# set_attention_slice() API; confirm against the installed version.
if(args.attention_slicing is not None and args.attention_slicing != "none"):
    if(args.attention_slicing.isdigit()):
        self.unet.set_attention_slice(int(args.attention_slicing))
    else:
        self.unet.set_attention_slice(args.attention_slicing)
def forward(
self, latent, timestep, text_embedding, guidance_scale

View File

@@ -263,6 +263,13 @@ p.add_argument(
help="Use the accelerate package to reduce cpu memory consumption",
)
# CLI flag consumed by SharkifyStableDiffusionModel, which forwards it to
# unet.set_attention_slice(). Deliberately typed as str so one flag can
# carry either a mode name ("max"/"auto"/"none") or a digit string; the
# consumer converts digit strings to int. "none" (the default) disables
# slicing entirely.
p.add_argument(
    "--attention_slicing",
    type=str,
    default="none",
    help="Amount of attention slicing to use (one of 'max', 'auto', 'none', or an integer)",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################