from typing import Optional, Tuple, Any, List import unittest, math import numpy as np from tinygrad.tensor import Tensor from tinygrad.helpers import getenv from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.device import Buffer, Device, CompiledRunner from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.engine.schedule import create_schedule from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.codegen.uops import exec_alu, UOpGraph from test.helpers import is_dtype_supported def _uops_to_prg(uops): src = Device[Device.DEFAULT].compiler.render("test", uops) has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local return CompiledRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None) def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp: uops.append(UOp(uop, dtype, tuple(vin), arg)) return uops[-1] def _test_single_value(vals, op, dts): uops = [] output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True)) buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, f'data{i+1}',False)) for i,dtype in enumerate(dts)] loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)] prg = _uops_to_prg(UOpGraph(uops)) prg.exec([buf]+buf2) ret = np.empty(1, output_dtype.np) buf.copyout(ret.data) return ret[0] def _test_single_value_const(vals, op, dts): uops = [] output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True)) loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg(UOpGraph(uops)) prg.exec([buf]) ret = np.empty(1, output_dtype.np) buf.copyout(ret.data) return ret[0] def _test_uops_result(output_dtype, uops, res): # uops = [] buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, 'data0',True)) # res = output_fn(uops) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() prg = _uops_to_prg(UOpGraph(uops)) prg.exec([buf]) ret = np.empty(1, output_dtype.np) buf.copyout(ret.data) return ret[0] class TestUOps(unittest.TestCase): def _equal(self, v1, v2): assert isinstance(v2, (float, int, bool)) if isinstance(v2, float): np.testing.assert_allclose(v1, v2, rtol=2e-7) else: np.testing.assert_equal(v1, v2) def _test_uop_fxn(self, op, fxn, dts=(PtrDType(dtypes.float32), )): for f in [_test_single_value, _test_single_value_const]: for a in [-2.0, 0.0, 1.0]: a = dtypes.as_const(a, dts[0]) self._equal(f([a], op, dts), fxn(a)) def _test_bop_fxn(self, op, fxn, dts=(PtrDType(dtypes.float32), )*2, no_b_zero=False): for f in [_test_single_value, _test_single_value_const]: for a in [-2.0, 0.0, 1.0]: for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]): a = dtypes.as_const(a, dts[0]) b = dtypes.as_const(b, dts[1]) self._equal(f([a,b], op, dts), fxn(a,b)) def _test_top_fxn(self, op, fxn, dts=(PtrDType(dtypes.float32), )*3): for f in [_test_single_value, _test_single_value_const]: for a in [-2.0, 0, 1]: for b in [-3.0, 3.0]: for c in [-4.0, 4.0]: a = dtypes.as_const(a, dts[0]) b = dtypes.as_const(b, dts[1]) c = dtypes.as_const(c, dts[2]) self._equal(f([a,b,c], op, dts), fxn(a,b,c)) class TestFloatUOps(TestUOps): def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a) def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a)) def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a)) def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b) def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b) def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b) def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf')) def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b)) def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a