diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 9fdf54849f..d614f0530b 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -1,7 +1,7 @@ import unittest, operator, math from tinygrad import Context, Tensor, dtypes, Device from tinygrad.dtype import DType, truncate -from tinygrad.helpers import CI, getenv +from tinygrad.helpers import CI, EMULATED_DTYPES, getenv from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported from tinygrad.runtime.ops_python import from_storage_scalar @@ -64,7 +64,10 @@ def universal_test(a, b, dtype, op): numpy_value = op[1](ta.numpy(), tb.numpy()) if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item()) if dtype in dtypes.floats: - atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7)) + if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero + fe, fm = dtypes.finfo(dtype) + atol, rtol = 2 ** (2 - (1 << (fe - 1))), 2 ** (-fm) + else: atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype, (1e-10, 1e-7)) np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol) else: np.testing.assert_equal(tensor_value, numpy_value) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 2cc4ce9c27..e304c27c6d 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -380,33 +380,21 @@ def l2i(op: Ops, dt: DType, *uops:UOp): # ***** floats ***** f2f_dt = { dtypes.half: dtypes.ushort, dtypes.float: dtypes.uint } -# a modification of https://graphics.stanford.edu/~seander/bithacks.html#IntegerLog -def clz(bits: int, v: UOp) -> UOp: - r = v.const_like(0) - for s in [1 << i for i in range((bits - 1).bit_length() - 1, -1, -1)]: - r |= (shift := (v >> (bits - s)).eq(0).cast(v.dtype) * s) - v <<= shift - return r | ((v >> (bits - 1)) ^ 1) - -def rne(v: UOp, s) -> UOp: return (v >> s) + (((v >> (s - 1)) & 1) & ((v & ((v.const_like(1) << (s - 1)) - 1)).ne(0).cast(v.dtype) | ((v >> s) & 1))) +def rne(v: UOp, s) -> UOp: return (v >> s) + (((v >> (s - 1)) & 1) & ((v & ((1 << (s - 1)) - 1)).ne(0).cast(v.dtype) | ((v >> s) & 1))) def f2f(v, fr:DType, to:DType): fs, fb, (fe, fm), ts, tb, (te, tm) = fr.bitsize, exponent_bias(fr), dtypes.finfo(fr), to.bitsize, exponent_bias(to), dtypes.finfo(to) - # TODO: "denormals are zero" could make this much simpler + # NB: denormals are zero! if fs < ts: sign, nosign = (v & (1 << (fs-1))).cast(f2f_dt[to]) << (ts - fs), (v & ((1 << (fs-1)) - 1)).cast(f2f_dt[to]) - norm, exp, mantissa = (nosign << (tm - fm)) + ((tb - fb) << tm), nosign >> fm, nosign & ((1 << fm) - 1) + exp, norm = nosign >> fm, (nosign << (tm - fm)) + ((tb - fb) << tm) inf_or_nan = (nosign << (tm - fm)) | (((1 << te) - 1) << tm) - shift = clz(ts, mantissa) - (ts - 1 - tm) - subnorm = ((mantissa << shift) & ((1 << tm) - 1)) | ((1 + tb - fb - fm + tm - shift) << tm) - return (sign | exp.eq(0).where(mantissa.eq(0).where(nosign, subnorm), exp.eq((1 << fe) - 1).where(inf_or_nan, norm))).bitcast(to) + return (sign | exp.eq(0).where(0, exp.eq((1 << fe) - 1).where(inf_or_nan, norm))).bitcast(to) else: - sign, nosign = (v >> (fs - ts)) & (1 << (ts - 1)), v & ((1 << (fs - 1)) - 1) - norm, exp = (rne(nosign, fm - tm) - ((fb - tb) << tm)).cast(f2f_dt[to]), (v >> fm) & ((1 << fe) - 1) - infnan = (sign | ((nosign >> (fm - tm)) & ((1 << tm) - 1)) | (((1 << te) - 1) << tm)).cast(f2f_dt[to]) - subnorm = rne((1 << fm) | (nosign & ((1 << fm) - 1)), (fm + 1 + fb - tb - tm) - exp).cast(f2f_dt[to]) - uf, sn, of = exp < (fb - tb - tm), exp < (1 + fb - tb), exp > ((1 << te) - 2 + (fb - tb)) - return exp.eq((1 << fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | uf.where(0, sn.where(subnorm, of.where(((1 << te) - 1) << tm, norm)))) + sign, nosign, exp = (v >> (fs - ts)) & (1 << (ts - 1)), v & ((1 << (fs - 1)) - 1), (v >> fm) & ((1 << fe) - 1) + norm, infnan = (rne(nosign, fm - tm) - ((fb - tb) << tm)).cast(f2f_dt[to]), (sign | ((nosign >> (fm - tm)) & ((1 << tm) - 1)) | (((1 << te) - 1) << tm)).cast(f2f_dt[to]) + underflow, overflow = exp < (1 + fb - tb), exp > ((1 << te) - 2 + (fb - tb)) + return exp.eq((1 << fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | underflow.where(0, overflow.where(((1 << te) - 1) << tm, norm))) # ***** decomposition patterns *****