"""Tests that build small UOp graphs by hand, render/compile them for Device.DEFAULT, and check the results.

NOTE(review): this copy of the file arrived whitespace-collapsed and partly corrupted. In at least two places
(both inside TestFloatUOps, flagged below) a span of text that started at a `<` and ended at a later `>`
appears to have been stripped, as if by an HTML-tag filter. The file will not parse until those regions are
restored from version control; nothing here attempts to reconstruct the lost code.
"""
from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu # noqa F401
from tinygrad.renderer import Program
# NOTE(review): create_schedule, lower_schedule_item and is_dtype_supported are not referenced in the surviving
# text; they may have been used in the lost region, so they are kept.
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uops import UOps, NOp, UOp
from tinygrad.codegen.uopgraph import UOpGraph
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps

def _uops_to_prg(uops_list, print_uops=False):
  """Linearize and render the UOp graph rooted at `uops_list` for Device.DEFAULT and wrap it in a CompiledRunner.

  The rendered program is named "test". When the renderer reports `has_local`, global and local sizes are both
  [1,1,1] (a single work item); otherwise both are None. If `print_uops` is set, the linearized graph is printed
  after rendering.
  """
  uops = UOpGraph(uops_list)
  # the renderer's extra_matcher is applied during linearization (TestAssembly below relies on its rewrites)
  uops.linearize(Device[Device.DEFAULT].renderer.extra_matcher)
  src = Device[Device.DEFAULT].renderer.render("test", uops.uops)
  if print_uops: uops.print()
  has_local = Device[Device.DEFAULT].renderer.has_local
  return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops.uops, global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))

def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
  """Create a UOp, append it to `uops`, and return it.

  `src` may be any iterable (callers below pass tuples, lists and generators); it is materialized with tuple().
  NOTE: the parameter `uop` shadows this function's own name inside the body.
  """
  uops.append(UOp(uop, dtype, tuple(src), arg))
  return uops[-1]

def _test_single_value(vals, op, dts):
  """Evaluate `op` once on the device with its operands loaded from global buffers; return the scalar result.

  vals: one Python value per operand. dts: the matching tinygrad dtypes.
  Buffer 0 is the output; buffers 1..len(dts) each hold one input element, read at index 0.
  """
  uops = []
  # result dtype: WHERE uses its last operand's dtype, CMPLT yields bool, everything else uses the first operand's
  output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
  buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
  buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)]
  # a generator: these LOADs are only created when uop() calls tuple(src) while building the ALU below
  loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
  alu = uop(uops, UOps.ALU, output_dtype, loads, op)
  out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
  buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
  # one single-element input buffer per operand, filled from a host numpy array
  buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
  prg = _uops_to_prg([out])
  prg.exec([buf]+buf2)
  ret = np.empty(1, _to_np_dtype(output_dtype))
  buf.copyout(ret.data)
  return ret[0]

def _test_single_value_const(vals, op, dts):
  """Like _test_single_value, but the operands are CONST uops baked into the program.

  Only the output buffer is passed to the kernel; the scalar stored at index 0 is returned.
  """
  uops = []
  output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
  buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
  loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
  alu = uop(uops, UOps.ALU, output_dtype, loads, op)
  out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
  buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
  prg = _uops_to_prg([out])
  prg.exec([buf])
  ret = np.empty(1, _to_np_dtype(output_dtype))
  buf.copyout(ret.data)
  return ret[0]

def _test_uops_result(output_dtype, uops, res):
  """Store the uop `res` into a fresh single-element output buffer, run the program, and return that element.

  `uops` is the caller's bookkeeping list; the DEFINE_GLOBAL/STORE created here are appended to it.
  NOTE(review): print_uops=True below prints the linearized graph on every call -- looks like leftover debugging.
  """
  buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
  # res is stored at index 0 of the output buffer
  out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
  buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
  prg = _uops_to_prg([out], print_uops=True)
  prg.exec([buf])
  ret = np.empty(1, _to_np_dtype(output_dtype))
  buf.copyout(ret.data)
  return ret[0]

# Base class for op tests: each helper evaluates the op through both _test_single_value (operands loaded from
# buffers) and _test_single_value_const (operands as CONST uops) and compares against the Python reference `fxn`.
class TestUOps(unittest.TestCase):
  def _equal(self, v1, v2):
    # v2 is the Python reference value: floats are compared with rtol=2e-7, ints/bools exactly
    assert isinstance(v2, (float, int, bool))
    if isinstance(v2, float): np.testing.assert_allclose(v1, v2, rtol=2e-7)
    else: np.testing.assert_equal(v1, v2)

  def _test_uop_fxn(self, op, fxn, dts=(dtypes.float32, )):
    # unary op: a in {-2, 0, 1}, converted to dts[0] with dtypes.as_const
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 0.0, 1.0]:
        a = dtypes.as_const(a, dts[0])
        self._equal(f([a], op, dts), fxn(a))

  def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
    # binary op: a in {-2, 0, 1}, b in {-3, 1} plus 0 unless no_b_zero; no_b_neg replaces b with |b|
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 0.0, 1.0]:
        for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
          a = dtypes.as_const(a, dts[0])
          b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
          self._equal(f([a,b], op, dts), fxn(a,b))

  def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
    # ternary op: every combination of a in {-2, 0, 1}, b in {-3, 3}, c in {-4, 4}
    for f in [_test_single_value, _test_single_value_const]:
      for a in [-2.0, 0, 1]:
        for b in [-3.0, 3.0]:
          for c in [-4.0, 4.0]:
            a = dtypes.as_const(a, dts[0])
            b = dtypes.as_const(b, dts[1])
            c = dtypes.as_const(c, dts[2])
            self._equal(f([a,b,c], op, dts), fxn(a,b,c))

class TestFloatUOps(TestUOps):
  def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a)
  @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
  @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
  @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
  def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf'))
  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))

  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
  # NOTE(review): corrupted line. A CMPLT test would presumably compare `a<b` on the default float dtypes; instead
  # the lambda reads `a>int(b)` with int32 dtypes and no_b_neg=True. This looks like the head of test_cmplt fused
  # onto the tail of a later integer test (possibly a `>>`/SHR test) after the text between them was stripped.
  # Whatever tests sat between them are gone. Restore this region from version control.
  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)

  @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
  # NOTE(review): corrupted line. Text after `int(a)<` was lost up to `= 4:`. The statements that follow
  # (if_uop / endif / gated_uops and the two assertions) are the tail of a different test whose start is missing;
  # they match the shape of test_gate_some_stores below, which reads `if DEBUG >= 4: print(...)`. The class that
  # encloses test_gate_some_stores / test_merge_ifs_alt was also in the lost span, so their placement under
  # TestFloatUOps here is only an artifact of the damage.
  def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
    if_uop = next(u for u in uops if u.op is UOps.IF)
    endif = next(u for u in uops if u.op is UOps.ENDIF)
    assert endif.src[0] is if_uop
    gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
    self.assertEqual(len(gated_uops), 1)
    self.assertIs(gated_uops[-1].op, UOps.STORE)

  # Only the store to gmem0 carries a gate, so exactly one uop -- a STORE -- is expected between the IF and its
  # ENDIF. expectedFailure: these assertions are currently expected not to hold.
  @unittest.expectedFailure
  def test_gate_some_stores(self):
    gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
    idx = gidx0*UOp.const(dtypes.int, 2)
    val = UOp.const(dtypes.float, 42.0)
    gate = gidx0.lt(UOp.const(dtypes.int, 1))
    stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
    uops = UOpGraph(stores)
    # NOTE(review): render() is given the UOpGraph here, while _uops_to_prg passes uops.uops -- confirm both are accepted
    if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
    if_uop = next(u for u in uops if u.op is UOps.IF)
    endif = next(u for u in uops if u.op is UOps.ENDIF)
    assert endif.src[0] is if_uop
    gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
    self.assertEqual(len(gated_uops), 1)
    self.assertIs(gated_uops[-1].op, UOps.STORE)

  # scaled down version of TestLinearizerDumb.test_unmerged_ifs
  # Both stores share the same gate, so a single IF/ENDIF pair containing exactly the two STOREs is expected.
  # expectedFailure: these assertions are currently expected not to hold.
  @unittest.expectedFailure
  def test_merge_ifs_alt(self):
    gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
    gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
    idx = gidx0*UOp.const(dtypes.int, 2)
    val = UOp.const(dtypes.float, 42.0)
    gate = gidx0.lt(UOp.const(dtypes.int, 1))
    stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
    uops = UOpGraph(stores)
    if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
    ifs = [u for u in uops if u.op is UOps.IF]
    endifs = [u for u in uops if u.op is UOps.ENDIF]
    self.assertEqual(len(ifs), 1)
    self.assertEqual(len(endifs), 1)
    gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
    self.assertEqual(len(gated_uops), 2)
    for x in gated_uops: self.assertIs(x.op, UOps.STORE)

class TestLocalAccess(unittest.TestCase):
  # NOTE: this is failing on METAL CI, no idea why. Works locally.
  @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "failing only in CI")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
  def test_local_basic(self):
    # write 42.0 to smem[0], BARRIER, read smem[0] back and store it to the global output
    uops = []
    smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
    st = uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
    barr = uop(uops, UOps.BARRIER, None, (st,))
    # the BARRIER is an extra src of the LOAD, presumably to order the load after the store
    sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
    self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
  def test_local_indirect(self):
    # smem[1] = 2, smem[2] = 42; after the barrier, ofs = smem[1] (== 2) and the result is smem[ofs] == 42,
    # i.e. a load whose index is itself a loaded value
    uops = []
    smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
    st1 = uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
    st2 = uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
    barr = uop(uops, UOps.BARRIER, None, (st1,st2))
    ofs = uop(uops, UOps.LOAD, dtypes.int32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), barr))
    sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
    self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)

# Runs only when PTX is set. Each test linearizes two integer ops on the same load (by constants 2 and 3) with the
# renderer's extra_matcher and expects the last linearized uop to be a shift and the one before it to keep the
# original op -- presumably the power-of-two constant is rewritten into a shift. The render() result is discarded;
# the call only checks that rendering does not raise.
@unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
  def test_bitshift_left(self):
    g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
    c1 = UOp(UOps.CONST, dtypes.int, (), 2)
    c2 = UOp(UOps.CONST, dtypes.int, (), 3)
    l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
    a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
    a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
    uops = UOpGraph([a1,a2])
    uops.linearize(Device[Device.DEFAULT].renderer.extra_matcher)
    Device[Device.DEFAULT].renderer.render("test", uops)
    self.assertEqual(uops.uops[-1].arg, BinaryOps.SHL)
    self.assertEqual(uops.uops[-2].arg, BinaryOps.MUL)

  def test_bitshift_right(self):
    g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
    c1 = UOp(UOps.CONST, dtypes.int, (), 2)
    c2 = UOp(UOps.CONST, dtypes.int, (), 3)
    l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
    a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
    a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
    uops = UOpGraph([a1,a2])
    uops.linearize(Device[Device.DEFAULT].renderer.extra_matcher)
    Device[Device.DEFAULT].renderer.render("test", uops)
    self.assertEqual(uops.uops[-1].arg, BinaryOps.SHR)
    self.assertEqual(uops.uops[-2].arg, BinaryOps.IDIV)

class TestUOpCompare(unittest.TestCase):
  def test_alu_same_src_different_arg(self):
    # two ALU uops that differ only in their arg must still be strictly ordered by `<`
    a = UOp(UOps.CONST, dtypes.float, (), 2.0)
    b = UOp(UOps.CONST, dtypes.float, (), 3.0)

    add = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)
    mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
    assert (add < mul) or (mul < add), "add and mul with same src should have an order"

# assert_equiv_uops is inherited from test.helpers.TestUOps (imported as TestEqUOps)
class TestUOpStr(TestEqUOps):
  def test_uop_str(self):
    # 20 doublings of `a + a` share subtrees heavily; str() must stay small and eval(str(x)) must round-trip.
    # eval() is only applied to strings produced by this test itself.
    a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
    for _ in range(20): a = a + a
    assert len(str(a)) < 10_000, "exponential string growth"
    assert str(eval(str(a))) == str(a)

    t = Tensor.arange(10)
    t = t + t * Tensor.rand(10)
    # nice big complicated uop
    with Context(NOOPT=1):
      sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink
    self.assert_equiv_uops(sink, eval(str(sink)))

  def test_nop_str(self):
    # same str()/eval() round-trip check for named NOps
    a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")
    assert str(eval(str(a))) == str(a)

if __name__ == '__main__':
  unittest.main(verbosity=2)