mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
tests from lowerer branch (#5339)
* tests from lowerer branch * Update test_image_dtype.py * Update test_image_dtype.py * Update test_image_dtype.py
This commit is contained in:
@@ -28,7 +28,12 @@ if __name__ == "__main__":
|
||||
# confirm linearize can be called twice
|
||||
uops1 = lin.linearize().uops
|
||||
uops2 = lin.linearize().uops
|
||||
assert tuple(uops1) == tuple(uops2), f"uops mismatch {lin.colored_shape()}"
|
||||
for x,y in zip(uops1.uops, uops2.uops):
|
||||
# for some reason DEFINE_ACC is changing the arg
|
||||
if x.op != y.op or x.dtype != y.dtype: # or x.arg != y.arg:
|
||||
uops1.print()
|
||||
uops2.print()
|
||||
raise Exception(f"UOPS MISMATCH {x} {y}")
|
||||
|
||||
print(len(tactions), len(actions))
|
||||
print(sorted(list(tactions)))
|
||||
|
||||
@@ -42,6 +42,13 @@ class TestConv(unittest.TestCase):
|
||||
|
||||
print(ret.numpy())
|
||||
|
||||
def test_two_binops_no_rerun_small(self):
|
||||
Tensor.no_grad = True
|
||||
x = Tensor.rand(1,1,32,32)
|
||||
w = Tensor.rand(1,1,3,3)
|
||||
out = x.conv2d(w, padding=(1,1))
|
||||
np.testing.assert_allclose(out.relu().numpy(), np.maximum(out.numpy(), 0))
|
||||
|
||||
def test_two_binops_no_rerun(self):
|
||||
Tensor.no_grad = True
|
||||
x = Tensor.randn(1,12,128,256)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Device, dtypes, Tensor, Variable
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.codegen.linearizer import to_image_idx
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")
|
||||
class TestImageDType(unittest.TestCase):
|
||||
@@ -30,6 +29,11 @@ class TestImageDType(unittest.TestCase):
|
||||
out = (it*2).realize()
|
||||
assert isinstance(out.lazydata.base.realized.dtype, ImageDType)
|
||||
|
||||
def test_sum(self):
|
||||
it = Tensor.rand(8).cast(dtypes.imagef((1,2,4))).realize()
|
||||
itn = it.numpy()
|
||||
np.testing.assert_allclose(np.sum(itn), it.sum().numpy(), rtol=1e-6)
|
||||
|
||||
def test_shrink_max(self):
|
||||
it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
|
||||
imgv = it.numpy()
|
||||
@@ -64,14 +68,5 @@ class TestImageDType(unittest.TestCase):
|
||||
it = data.cast(dtypes.imageh((9,27,4))).realize()
|
||||
assert it.lazydata.base.realized._buf != b1
|
||||
|
||||
class TestImageIdx(unittest.TestCase):
|
||||
def test_to_image_idx_real1(self):
|
||||
gidx0 = Variable('gidx0', 0, 511)
|
||||
base_idx = (((gidx0*4)%32)*32)+((gidx0//8)%32)
|
||||
base_valid = gidx0<256
|
||||
(idx, idy), valid = to_image_idx((4, 64, 4), base_idx, base_valid)
|
||||
print(idx, idy, idx.min, idx.max, idy.min, idy.max, valid)
|
||||
assert valid.min == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -114,6 +114,10 @@ class TestFloatUOps(TestUOps):
|
||||
def test_where(self):
|
||||
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
|
||||
|
||||
@unittest.skipUnless(getenv("PYTHON"), "only python supports MULACC")
|
||||
def test_mulacc(self):
|
||||
self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
|
||||
|
||||
class TestNonFloatUOps(TestUOps):
|
||||
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (dtypes.int32, ))
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
|
||||
|
||||
@@ -127,6 +127,7 @@ python_alu = {
|
||||
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
|
||||
TernaryOps.MULACC: lambda x,y,z: (x*y)+z,
|
||||
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
def truncate_fp16(x):
|
||||
|
||||
Reference in New Issue
Block a user