mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix max_unpool2d inf (#11784)
* start * add regression test for maxunpool2d
This commit is contained in:
@@ -2461,6 +2461,20 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True),
|
||||
kernel_size=(2,2), output_size=(99,99,7,6)), forward_only=True)
|
||||
|
||||
def test_max_unpool2d_inf(self):
|
||||
data = [[[[math.inf, -math.inf, math.nan], [1.0, 2.0, 3.0]]]]
|
||||
ksz = (2,2)
|
||||
helper_test_op((),
|
||||
lambda: torch.nn.functional.max_unpool2d(
|
||||
*torch.nn.functional.max_pool2d(torch.tensor(data), kernel_size=ksz, return_indices=True),
|
||||
kernel_size=ksz
|
||||
),
|
||||
lambda: Tensor.max_unpool2d(
|
||||
*Tensor.max_pool2d(Tensor(data), kernel_size=ksz, return_indices=True),
|
||||
kernel_size=ksz
|
||||
),
|
||||
forward_only=True)
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
shape = (32,2,111,28)
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
|
||||
@@ -1192,7 +1192,7 @@ class Tensor(MathTrait):
|
||||
x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
|
||||
|
||||
# dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
|
||||
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], (int, UOp))))
|
||||
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], sint)))
|
||||
|
||||
# tensor indexing
|
||||
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
|
||||
@@ -2437,7 +2437,7 @@ class Tensor(MathTrait):
|
||||
# https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1.
|
||||
output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_))
|
||||
else: output_size = output_size[-len(spatial_shape):]
|
||||
ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2) * self.reshape(bs,c,1,-1)).sum(3)
|
||||
ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
|
||||
return ret.reshape(bs,c,*output_size)
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
|
||||
|
||||
Reference in New Issue
Block a user