diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 04669ae669..2cc4ce9c27 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -380,24 +380,33 @@ def l2i(op: Ops, dt: DType, *uops:UOp): # ***** floats ***** f2f_dt = { dtypes.half: dtypes.ushort, dtypes.float: dtypes.uint } +# a modification of https://graphics.stanford.edu/~seander/bithacks.html#IntegerLog +def clz(bits: int, v: UOp) -> UOp: + r = v.const_like(0) + for s in [1 << i for i in range((bits - 1).bit_length() - 1, -1, -1)]: + r |= (shift := (v >> (bits - s)).eq(0).cast(v.dtype) * s) + v <<= shift + return r | ((v >> (bits - 1)) ^ 1) + +def rne(v: UOp, s) -> UOp: return (v >> s) + (((v >> (s - 1)) & 1) & ((v & ((v.const_like(1) << (s - 1)) - 1)).ne(0).cast(v.dtype) | ((v >> s) & 1))) + def f2f(v, fr:DType, to:DType): fs, fb, (fe, fm), ts, tb, (te, tm) = fr.bitsize, exponent_bias(fr), dtypes.finfo(fr), to.bitsize, exponent_bias(to), dtypes.finfo(to) + # TODO: "denormals are zero" could make this much simpler if fs < ts: sign, nosign = (v & (1 << (fs-1))).cast(f2f_dt[to]) << (ts - fs), (v & ((1 << (fs-1)) - 1)).cast(f2f_dt[to]) - norm = (nosign << (tm - fm)) + ((tb - fb) << tm) - exp = nosign >> fm - # TODO: subnormals + norm, exp, mantissa = (nosign << (tm - fm)) + ((tb - fb) << tm), nosign >> fm, nosign & ((1 << fm) - 1) inf_or_nan = (nosign << (tm - fm)) | (((1 << te) - 1) << tm) - return (sign | exp.eq(0).where(nosign, exp.eq((1 << fe) - 1).where(inf_or_nan, norm))).bitcast(to) + shift = clz(ts, mantissa) - (ts - 1 - tm) + subnorm = ((mantissa << shift) & ((1 << tm) - 1)) | ((1 + tb - fb - fm + tm - shift) << tm) + return (sign | exp.eq(0).where(mantissa.eq(0).where(nosign, subnorm), exp.eq((1 << fe) - 1).where(inf_or_nan, norm))).bitcast(to) else: sign, nosign = (v >> (fs - ts)) & (1 << (ts - 1)), v & ((1 << (fs - 1)) - 1) - exp = (v >> fm) & ((1 << fe) - 1) - nosign_rounded = nosign + (1 << (fm - tm - 1)) # round to nearest - norm = ((nosign_rounded >> (fm - tm)) - ((fb - tb) << tm)).cast(f2f_dt[to]) - inf_or_nan = (sign | ((nosign >> (fm - tm)) & ((1 << tm) - 1)) | (((1 << te) - 1) << tm)).cast(f2f_dt[to]) - underflow, overflow = exp < (1 + (fb - tb)), exp > ((1 << te) - 2 + (fb - tb)) - sign = sign.cast(f2f_dt[to]) - return exp.eq((1 << fe) - 1).where(inf_or_nan, sign | underflow.where(UOp.const(f2f_dt[to], 0), overflow.where(UOp.const(f2f_dt[to], ((1 << te) - 1) << tm), norm))) + norm, exp = (rne(nosign, fm - tm) - ((fb - tb) << tm)).cast(f2f_dt[to]), (v >> fm) & ((1 << fe) - 1) + infnan = (sign | ((nosign >> (fm - tm)) & ((1 << tm) - 1)) | (((1 << te) - 1) << tm)).cast(f2f_dt[to]) + subnorm = rne((1 << fm) | (nosign & ((1 << fm) - 1)), (fm + 1 + fb - tb - tm) - exp).cast(f2f_dt[to]) + uf, sn, of = exp < (fb - tb - tm), exp < (1 + fb - tb), exp > ((1 << te) - 2 + (fb - tb)) + return exp.eq((1 << fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | uf.where(0, sn.where(subnorm, of.where(((1 << te) - 1) << tm, norm)))) # ***** decomposition patterns ***** @@ -452,10 +461,11 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, force_transcenden pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))] if not is_dtype_supported(dtypes.half, device) or dtypes.half in emulated_dtypes: pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x: - x.replace(dtype=dtypes.uint16.ptr(x.dtype.size)) if x.dtype.base == dtypes.half else None)] + x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)] pat += [(UPat(Ops.LOAD, dtypes.half, name="x"), lambda x: f2f(x.replace(dtype=dtypes.ushort), dtypes.half, dtypes.float))] pat += [(UPat(GroupOp.ALU, dtypes.half, name="x"), lambda x: x.replace(dtype=dtypes.float))] - pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val: st.replace(src=(idx,f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half).bitcast(dtypes.ushort))))] + pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val: + st.replace(src=(idx,f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half).bitcast(dtypes.ushort))) if idx.tag == dtypes.half else None)] if not is_dtype_supported(dtypes.long, device) or dtypes.long in emulated_dtypes: pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x: x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None)]