mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Support weird loads in Image (#2498)
* image support weird loads * umm, that was always wrong * openpilot compile fails with a weird error * image test passes * we have valids now * clean that up * no more required opts * add fastvits test, fix bug * minor cleanups
This commit is contained in:
29
test/test_image_dtype.py
Normal file
29
test/test_image_dtype.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.helpers import ImageDType
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")
|
||||
class TestImageDType(unittest.TestCase):
|
||||
def test_shrink_load_float(self):
|
||||
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
|
||||
imgv = it.numpy()
|
||||
np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
|
||||
|
||||
def test_mul_stays_image(self):
|
||||
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
|
||||
out = (it*2).realize()
|
||||
assert isinstance(out.lazydata.realized.dtype, ImageDType)
|
||||
|
||||
def test_shrink_max(self):
|
||||
it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
|
||||
imgv = it.numpy()
|
||||
np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy())
|
||||
|
||||
def test_shrink_to_float(self):
|
||||
it = Tensor.randn(4, 4).cast(dtypes.imagef((1,4,4))).realize()
|
||||
imgv = it.numpy()
|
||||
np.testing.assert_equal(np.maximum(imgv[:, 0], 0), it[:, 0].relu().realize())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user