From 313b206ff8f2a2099cbc5dae7050bcd6b2576590 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sun, 22 Jan 2023 18:12:11 +0100
Subject: [PATCH] squash float16/float32 mismatch on linux

---
 ldm/models/diffusion/cross_attention_control.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 4b89b5bd56..45294ac993 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -329,7 +329,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
     # urgh. should this be hardcoded?
     max_length = 77
     # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-    mask = torch.zeros(max_length)
+    mask = torch.zeros(max_length, dtype=torch_dtype())
     indices_target = torch.arange(max_length, dtype=torch.long)
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
@@ -338,7 +338,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1
-
+
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
     if is_running_diffusers:
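
Note (not part of the patch): the change creates the cross-attention mask in the model's working dtype instead of the float32 default, so later code that combines the mask with half-precision attention slices no longer mixes float16 and float32. Below is a minimal, self-contained sketch of that idea. The torch_dtype() helper here is a hypothetical stand-in for InvokeAI's ldm.invoke.devices.torch_dtype; it takes the device explicitly, whereas the patched call resolves the dtype from the repository's own configuration, and the in-place multiply is only one illustration of where such a mismatch can surface.

    # Sketch only: not part of the patch. torch_dtype() below is a
    # hypothetical stand-in for InvokeAI's ldm.invoke.devices.torch_dtype.
    import torch

    def torch_dtype(device: torch.device) -> torch.dtype:
        # Assumed behaviour: half precision on CUDA, full precision elsewhere.
        return torch.float16 if device.type == "cuda" else torch.float32

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch_dtype(device)

    # Attention slices come out of the model in its working dtype (float16 on CUDA).
    attention_slice = torch.rand(8, 77, 77, device=device, dtype=dtype)

    # Before the patch: torch.zeros() defaults to float32 regardless of the model dtype.
    mask_f32 = torch.zeros(77, device=device)

    # After the patch: the mask is created in the model's dtype up front.
    mask = torch.zeros(77, device=device, dtype=dtype)

    # One place the mismatch surfaces: an in-place multiply cannot cast the
    # promoted float32 result back into a float16 tensor, so it raises a
    # RuntimeError on a CUDA/float16 setup. The dtype-matched mask works
    # identically on CPU and GPU.
    try:
        attention_slice.clone().mul_(mask_f32)
    except RuntimeError as err:
        print("dtype mismatch:", err)

    attention_slice.clone().mul_(mask)  # dtypes agree, no error

Creating the mask in the right dtype at construction time, as the patch does, avoids sprinkling casts at every point the mask is later applied.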