diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index a6cefc82ad..970f6aad8f 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -210,10 +210,7 @@ class AttnBlock(nn.Module): h_ = torch.zeros_like(k, device=q.device) device_type = 'mps' if q.device.type == 'mps' else 'cuda' - - if device_type == 'mps': - mem_free_total = psutil.virtual_memory().available - else: + if device_type == 'cuda': stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] @@ -221,14 +218,21 @@ class AttnBlock(nn.Module): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 - mem_required = tensor_size * 2.5 - steps = 1 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 + steps = 1 - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + else: + if psutil.virtual_memory().available / (1024**3) < 12: + slice_size = 1 + else: + slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1]))) + for i in range(0, q.shape[1], slice_size): end = i + slice_size