Merge branch 'development' of https://github.com/pbaylies/stable-diffusion into development

This commit is contained in:
Peter Baylies
2022-09-12 18:35:25 -04:00
4 changed files with 93 additions and 65 deletions

View File

@@ -13,8 +13,9 @@ def choose_torch_device() -> str:
def choose_autocast_device(device):
    '''Returns an autocast compatible device from a torch device

    `device` is only read through its `.type` attribute.

    Returns a `(device_type, context_factory)` tuple:
      - type 'cuda'       -> ('cuda', autocast)
      - type 'cpu'        -> ('cpu', nullcontext)
      - any other type    -> ('cpu', nullcontext)
    '''
    device_type = device.type # this returns 'mps' on M1
    # Only cuda is paired with autocast.  cpu and every other device
    # type (e.g. mps) are reported as 'cpu' with the no-op nullcontext.
    if device_type == 'cuda':
        return device_type,autocast
    return 'cpu',nullcontext

View File

@@ -111,7 +111,6 @@ class Generate:
height = 512,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
precision = 'autocast',
full_precision = False,
strength = 0.75, # default in scripts/img2img.py
seamless = False,
@@ -129,7 +128,6 @@ class Generate:
self.sampler_name = sampler_name
self.grid = grid
self.ddim_eta = ddim_eta
self.precision = precision
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
self.strength = strength
self.seamless = seamless

View File

@@ -121,30 +121,17 @@ class ResnetBlock(nn.Module):
padding=0)
def forward(self, x, temb):
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
h = self.norm1(x)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h5 = self.norm2(h4)
del h4
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@@ -152,7 +139,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h8
return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
@@ -209,8 +196,7 @@ class AttnBlock(nn.Module):
h_ = torch.zeros_like(k, device=q.device)
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
if device_type == 'cuda':
if q.device.type == 'cuda':
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
@@ -599,22 +585,16 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h1 = self.conv_in(z)
h = self.conv_in(z)
# middle
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# prepare for up sampling
device_type = 'mps' if h.device.type == 'mps' else 'cuda'
gc.collect()
if device_type == 'cuda':
if h.device.type == 'cuda':
torch.cuda.empty_cache()
# upsampling
@@ -622,33 +602,19 @@ class Decoder(nn.Module):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
t = h
h = self.up[i_level].attn[i_block](t)
del t
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
t = h
h = self.up[i_level].upsample(t)
del t
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h1 = self.norm_out(h)
del h
h2 = nonlinearity(h1)
del h1
h = self.conv_out(h2)
del h2
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
t = h
h = torch.tanh(t)
del t
h = torch.tanh(h)
return h