fix last bug in unet probz

This commit is contained in:
George Hotz
2022-09-05 11:32:44 -07:00
parent 3df67aa0af
commit b8bd34b5d2

View File

@@ -287,7 +287,7 @@ class SpatialTransformer:
class Downsample:
def __init__(self, channels):
self.op = Conv2d(channels, channels, 3, stride=2, padding=(0,1,0,1))
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
def __call__(self, x):
return self.op(x)
@@ -371,9 +371,9 @@ class UNetModel:
print("input block", i)
for bb in b:
x = run(x, bb)
print(x.numpy())
if i == 1:
return None
#if i == 3:
# print(x.numpy())
# return None
saved_inputs.append(x)
for bb in self.middle_block:
x = run(x, bb)