Merge branch 'development' into model-switching

This commit is contained in:
Lincoln Stein
2022-10-14 13:18:59 -04:00
committed by GitHub
32 changed files with 727 additions and 575 deletions

View File

@@ -49,9 +49,15 @@ class Upsample(nn.Module):
padding=1)
def forward(self, x):
cpu_m1_cond = True if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and \
x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0 else False
if cpu_m1_cond:
x = x.to('cpu') # send to cpu
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
if cpu_m1_cond:
x = x.to('mps') # return to mps
return x
@@ -117,6 +123,14 @@ class ResnetBlock(nn.Module):
padding=0)
def forward(self, x, temb):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
x_size = x.size()
if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0:
self.to('cpu')
x = x.to('cpu')
else:
self.to('mps')
x = x.to('mps')
h = self.norm1(x)
h = silu(h)
h = self.conv1(h)