Support weird loads in Image (#2498)

* image support weird loads

* umm, that was always wrong

* openpilot compile fails with a weird error

* image test passes

* we have valids now

* clean that up

* no more required opts

* add fastvits test, fix bug

* minor cleanups
This commit is contained in:
George Hotz
2023-11-29 08:30:46 -08:00
committed by GitHub
parent e333672675
commit 889acefe85
11 changed files with 79 additions and 105 deletions

29
test/test_image_dtype.py Normal file
View File

@@ -0,0 +1,29 @@
import unittest
import numpy as np
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import ImageDType
@unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")
class TestImageDType(unittest.TestCase):
def test_shrink_load_float(self):
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
imgv = it.numpy()
np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
def test_mul_stays_image(self):
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
out = (it*2).realize()
assert isinstance(out.lazydata.realized.dtype, ImageDType)
def test_shrink_max(self):
it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
imgv = it.numpy()
np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy())
def test_shrink_to_float(self):
it = Tensor.randn(4, 4).cast(dtypes.imagef((1,4,4))).realize()
imgv = it.numpy()
np.testing.assert_equal(np.maximum(imgv[:, 0], 0), it[:, 0].relu().realize())
if __name__ == '__main__':
unittest.main()