mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix last bug in unet probz
This commit is contained in:
@@ -287,7 +287,7 @@ class SpatialTransformer:
|
||||
|
||||
class Downsample:
|
||||
def __init__(self, channels):
|
||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=(0,1,0,1))
|
||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.op(x)
|
||||
@@ -371,9 +371,9 @@ class UNetModel:
|
||||
print("input block", i)
|
||||
for bb in b:
|
||||
x = run(x, bb)
|
||||
print(x.numpy())
|
||||
if i == 1:
|
||||
return None
|
||||
#if i == 3:
|
||||
# print(x.numpy())
|
||||
# return None
|
||||
saved_inputs.append(x)
|
||||
for bb in self.middle_block:
|
||||
x = run(x, bb)
|
||||
|
||||
Reference in New Issue
Block a user