mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
bilinear interp uint8 fails (#6103)
* new test for e2e compile failures * fix bug * bilinear interp uint8 fails * better tests
This commit is contained in:
@@ -2024,6 +2024,21 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
|
||||
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
|
||||
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
@unittest.skip('this is broken for negative numbers')
|
||||
def test_cast(self):
|
||||
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True)
|
||||
|
||||
def test_cast_relu(self):
|
||||
helper_test_op([(2,3,64,64)], lambda x: x.relu().type(torch.uint8), lambda x: x.relu().cast('uint8'), forward_only=True)
|
||||
|
||||
@unittest.skip('this is wrong output')
|
||||
def test_interpolate_bilinear(self):
|
||||
out_sz = (10, 10)
|
||||
helper_test_op([(2,3,64,64)],
|
||||
lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="bilinear"),
|
||||
lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="linear"), forward_only=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -1970,7 +1970,7 @@ class Tensor:
|
||||
reshape[i] = expand[i] = size[i]
|
||||
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
|
||||
x = x.gather(i, low).lerp(x.gather(i, high), perc)
|
||||
return x
|
||||
return x.cast(self.dtype)
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user