bilinear interp uint8 fails (#6103)

* new test for e2e compile failures

* fix bug

* bilinear interp uint8 fails

* better tests
This commit is contained in:
George Hotz
2024-08-15 19:34:39 -07:00
committed by GitHub
parent c850e03758
commit 553ae9ebc0
2 changed files with 16 additions and 1 deletions

View File

@@ -2024,6 +2024,21 @@ class TestOps(unittest.TestCase):
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
class TestOpsUint8(unittest.TestCase):
@unittest.skip('this is broken for negative numbers')
def test_cast(self):
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True)
def test_cast_relu(self):
helper_test_op([(2,3,64,64)], lambda x: x.relu().type(torch.uint8), lambda x: x.relu().cast('uint8'), forward_only=True)
@unittest.skip('this is wrong output')
def test_interpolate_bilinear(self):
out_sz = (10, 10)
helper_test_op([(2,3,64,64)],
lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="bilinear"),
lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="linear"), forward_only=True)
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)

View File

@@ -1970,7 +1970,7 @@ class Tensor:
reshape[i] = expand[i] = size[i]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
x = x.gather(i, low).lerp(x.gather(i, high), perc)
return x
return x.cast(self.dtype)
# ***** unary ops *****