diff --git a/test/test_dtype.py b/test/test_dtype.py index 6f5022a7a3..3fb4b1b5cc 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -381,8 +381,29 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool class TestBFloat16Type(TestDType): DTYPE = dtypes.bfloat16 class TestFp8e4m3(TestDType): DTYPE = dtypes.fp8e4m3 + +class TestEmulatedFp8e4m3(TestFp8e4m3): + @classmethod + def setUpClass(cls): + cls.stack = contextlib.ExitStack() + cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3")) + cls.DATA = rand_for_dtype(cls.DTYPE, 10) + + @classmethod + def tearDownClass(cls): cls.stack.close() + class TestFp8e5m2(TestDType): DTYPE = dtypes.fp8e5m2 +class TestEmulatedFp8e5m2(TestFp8e5m2): + @classmethod + def setUpClass(cls): + cls.stack = contextlib.ExitStack() + cls.stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2")) + cls.DATA = rand_for_dtype(cls.DTYPE, 10) + + @classmethod + def tearDownClass(cls): cls.stack.close() + class TestPtrDType(unittest.TestCase): def test_vec_double(self): dt1 = dtypes.float.vec(4).ptr().vec(4) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 813948ccdd..50a273d6fd 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -1,6 +1,6 @@ import unittest, operator, math from tinygrad import Context, Tensor, dtypes, Device -from tinygrad.dtype import DType, truncate +from tinygrad.dtype import DType, truncate, fp8_to_float from tinygrad.helpers import CI, EMULATED_DTYPES, getenv from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported @@ -59,9 +59,10 @@ def universal_test(a, b, dtype, op): # lt and max with nan is undefined in tinygrad if op[0] in (operator.lt, Tensor.maximum) and (math.isnan(a) or math.isnan(b)): return ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype) - tensor_value = (op[0](ta, tb)).numpy() - numpy_value = op[1](ta.numpy(), tb.numpy()) - if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item()) + if dtype in dtypes.fp8s and 
op[0] not in (operator.lt, operator.eq): + tensor_value = fp8_to_float((op[0](ta.realize(), tb.realize())).bitcast(dtypes.uint8).item(), dtype) + numpy_value = truncate[dtype](op[1](ta.numpy(), tb.numpy()).item()) + else: tensor_value, numpy_value = (op[0](ta, tb)).numpy(), op[1](ta.numpy(), tb.numpy()) if dtype in dtypes.floats: if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero fe, fm = dtypes.finfo(dtype) @@ -76,13 +77,14 @@ def universal_test_unary(a, dtype, op): # TODO: cos does not match for large input if op[0] == Tensor.cos and abs(a) > 30: return if op[0] == Tensor.log and a <= 0: return - out: Tensor = op[0](ta) - tensor_value = out.numpy() - numpy_value = op[1](ta.numpy()) if dtype in dtypes.fp8s: + # denormals are zero + if dtype in EMULATED_DTYPES.tolist(dtypes) and abs(ta.numpy().item()) < 0.015625: return + tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype) + numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item()) # cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2) - if math.isinf(numpy_value.item()): return - numpy_value = truncate[dtype](numpy_value.item()) + if math.isinf(v): return + else: tensor_value, numpy_value = op[0](ta).numpy(), op[1](ta.numpy()) if dtype in dtypes.floats: atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5)) @@ -133,11 +135,21 @@ class TestDTypeALU(unittest.TestCase): def test_fp8e4m3(self, a, b, op): universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op) + @given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations)) + @Context(EMULATED_DTYPES="fp8e4m3") + def test_emulated_fp8e4m3(self, a, b, op): + universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op) + 
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}") @given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations)) def test_fp8e5m2(self, a, b, op): universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op) + @given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations)) + @Context(EMULATED_DTYPES="fp8e5m2") + def test_emulated_fp8e5m2(self, a, b, op): + universal_test(from_storage_scalar(a, dtypes.fp8e5m2), from_storage_scalar(b, dtypes.fp8e5m2), dtypes.fp8e5m2, op) + @given(ht.float32, strat.sampled_from(unary_operations)) def test_float32_unary(self, a, op): universal_test_unary(a, dtypes.float32, op) @@ -159,12 +171,24 @@ class TestDTypeALU(unittest.TestCase): if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0) universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op) + @given(ht.fp8e4m3, strat.sampled_from(unary_operations)) + @Context(EMULATED_DTYPES="fp8e4m3") + def test_emulated_fp8e4m3_unary(self, a, op): + if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0) + universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op) + @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}") @given(ht.fp8e5m2, strat.sampled_from(unary_operations)) def test_fp8e5m2_unary(self, a, op): if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0) universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op) + @given(ht.fp8e5m2, strat.sampled_from(unary_operations)) + @Context(EMULATED_DTYPES="fp8e5m2") + def test_emulated_fp8e5m2_unary(self, a, op): + if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e5m2) != 0.0) + universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e5m2), dtypes.fp8e5m2, op) + @given(ht.uint8, ht.uint8, 
strat.sampled_from(integer_binary_operations)) def test_uint8(self, a, b, op): universal_test(a, b, dtypes.uint8, op) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 707cfc24ba..37d64a616b 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -5,14 +5,15 @@ from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TR from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer, ProgramSpec -from tinygrad.dtype import dtypes +from tinygrad.dtype import dtypes, promo_lattice +from tinygrad.device import is_dtype_supported from tinygrad.helpers import panic from tinygrad.codegen.opt import Opt # import all pattern matchers here from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load -from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_unsupported_dtypes_patterns, get_transcendental_patterns +from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads @@ -88,10 +89,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # decompositions supported_ops = tuple(ren.code_for_op.keys()) pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV)) - pm_unsupported = get_unsupported_dtypes_patterns(ren.device, tuple(EMULATED_DTYPES.tolist(dtypes))) pm_transcendental = 
symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2) sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions") - sink = graph_rewrite(sink, pm_unsupported, ctx=ren.device, name="unsupported dtypes", bottom_up=True) + if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes): + sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True) + for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float)) + for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]: + sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True) sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental") # final rules for the renderer (without sym) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 9629c9eae4..1689f45ab8 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -14,8 +14,8 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp): # *** helper functions for bit manipulation *** def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1] -def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()] -def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()] +def exponent_bias(d:DType) -> int: return (1 << (dtypes.finfo(d.scalar())[0] - 1)) - 1 +def exponent_mask(d:DType) -> int: return (1 << dtypes.finfo(d.scalar())[0]) - 1 # **** utils **** def shr(x:UOp|int, y:UOp|int) -> UOp: return x // (2**(y.simplify().arg) if isinstance(y, UOp) else 2**y) @@ -378,33 +378,46 @@ def l2i(op: Ops, dt: DType, *uops:UOp): case _: raise NotImplementedError(f"long decomposition of {op} unsupported") # ***** floats ***** -f2f_dt = { dtypes.half: dtypes.ushort, 
dtypes.float: dtypes.uint } +f2f_dt = { f:getattr(dtypes, f"uint{f.bitsize}") for f in dtypes.floats } def rne(v: UOp, s) -> UOp: return shr(v, s) + ((shr(v, s - 1) & 1) & ((v & ((1 << (s - 1)) - 1)).ne(0).cast(v.dtype) | (shr(v, s) & 1))) def f2f(v, fr:DType, to:DType): fs, fb, (fe, fm), ts, tb, (te, tm) = fr.bitsize, exponent_bias(fr), dtypes.finfo(fr), to.bitsize, exponent_bias(to), dtypes.finfo(to) # NB: denormals are zero! - if fe < te and fm < tm: + if fe <= te and fm < tm: sign, nosign = shl((v & shl(1, fs-1)).cast(f2f_dt[to]), ts - fs), (v & (shl(1, fs-1) - 1)).cast(f2f_dt[to]) exp, norm = shr(nosign, fm), shl(nosign, tm - fm) + shl(tb - fb, tm) - inf_or_nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm) - return (sign | exp.eq(0).where(0, exp.eq(shl(1, fe) - 1).where(inf_or_nan, norm))).bitcast(to) - elif fe > te and fm > tm: - sign, nosign, exp = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1), shr(v, fm) & (shl(1, fe) - 1) + nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm) + # fp8e4m3 has only one nan + is_nan = (nosign.eq(shl(1, fm + fe) - 1) if fr == dtypes.fp8e4m3 else exp.eq(shl(1, fe) - 1)) + return (sign | exp.eq(0).where(0, is_nan.where(nan, norm))).bitcast(to) + elif fe >= te and fm > tm: + v = f2f_clamp(v.bitcast(fr), to).bitcast(f2f_dt[fr]) + sign, nosign = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1) norm = (rne(nosign, fm - tm) - shl(fb - tb, tm)).cast(f2f_dt[to]) - infnan = (sign | (shr(nosign, fm - tm) & (shl(1, tm) - 1)) | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to]) - underflow, overflow = exp < (1 + fb - tb), exp > (shl(1, te) - 2 + (fb - tb)) - return exp.eq(shl(1, fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | underflow.where(0, overflow.where(shl(shl(1, te) - 1, tm), norm))) + underflow = (shr(v, fm) & (shl(1, fe) - 1)) < (1 + fb - tb) + nan_mantissa = (shl(1, tm) - 1) if to == dtypes.fp8e4m3 else (shr(nosign, fm - tm) & (shl(1, tm) - 1)) + nan = (sign | nan_mantissa | shl(shl(1, te) - 1, 
tm)).cast(f2f_dt[to]) + is_nan = (shr(v, fm) & (shl(1, fe) - 1)).eq(shl(1, fe) - 1) + return is_nan.where(nan, sign.cast(f2f_dt[to]) | underflow.where(0, norm)) else: raise NotImplementedError(f"unsupported decomp {fr} -> {to}") -def f2f_load(x: UOp) -> UOp: - if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=dtypes.ushort), dtypes.half, dtypes.float) - return UOp.vectorize(*(f2f(x.replace(dtype=dtypes.ushort, src=(reindex(x.src[0].src[0], i, 1),)), dtypes.half, dtypes.float) for i in range(n))) +def f2f_clamp(val:UOp, dt:DType) -> UOp: + e, m = dtypes.finfo(dt) + max_exp, max_man = ((1 << e) - 1, (1 << m) - 2) if dt == dtypes.fp8e4m3 else ((1 << e) - 2, (1 << m) - 1) + mx = val.const_like(2.0**(max_exp - exponent_bias(dt)) * (1.0 + max_man / (1 << m))) + sat = mx if dt in dtypes.fp8s else val.const_like(float('inf')) + # FIXME: CMPLT of nan is undefined + return val.ne(val).where(val, (val < -mx).where(-sat, (mx < val).where(sat, val))) -def f2f_store(st, idx, val): - if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half))) - return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(dtypes.uint), dtypes.float, dtypes.half))) for i in range(n))) +def f2f_load(x: UOp, fr:DType, to:DType) -> UOp: + if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to) + return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n))) + +def f2f_store(st, idx, val, fr:DType, to:DType): + if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr))) + return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(f2f_dt[to]), to, fr))) for i in range(n))) # ***** decomposition patterns ***** @@ -463,40 +476,44 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, disable_fast_idiv pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, 
UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))] return PatternMatcher(pat) -@functools.cache -def get_unsupported_dtypes_patterns(device:str, emulated_dtypes:tuple[DType, ...]) -> PatternMatcher: - pat: list[tuple[UPat, Callable]] = [] - if not is_dtype_supported(dtypes.long, device) or dtypes.long in emulated_dtypes: - pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x: - x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None)] - pat += [(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype]))] - pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val: - st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None)] - pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x: - l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))] - pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x: - l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None)] - pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x: - (a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag])] - pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x: - l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None)] - pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x: - l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt)) 
- if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])] - pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: - x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),)))] - pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x: - UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))] - if dtypes.half in emulated_dtypes: - pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x: - x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)] - pat += [(UPat(Ops.LOAD, dtypes.half, name="x"), f2f_load)] - pat += [(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, dtypes.half, name="ld"),), name="bc"), lambda bc,ld: - ld.replace(dtype=dtypes.ushort).bitcast(bc.dtype))] - pat += [(UPat(Ops.BITCAST, (dtypes.ushort, dtypes.short, dtypes.bfloat16), src=(UPat.var("x", dtypes.float),), name="bc"), lambda bc,x: - bc.replace(src=(f2f(x.bitcast(dtypes.uint), dtypes.float, dtypes.half),)))] - pat += [(UPat(GroupOp.All, dtypes.half, name="x"), lambda x: - x.replace(dtype=dtypes.float.vec(x.dtype.count), src=tuple(s.cast(dtypes.float) if s.dtype == dtypes.half else s for s in x.src)))] - pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val: - f2f_store(st, idx, val) if (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == dtypes.half else None)] - return PatternMatcher(pat) +pm_long_decomp = PatternMatcher([ + (UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x: + x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None), + (UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])), + (UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val: + st.replace(src=(reindex(idx, 0), 
val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None), + (UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x: + l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt))), + (UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x: + l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None), + (UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x: + (a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag]), + (UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x: + l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None), + (UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x: + l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt)) + if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]), + (UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))), + (UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x: + UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF)))) +]) + +# float decomposition patterns - ctx is (fr, to) tuple +pm_float_decomp = PatternMatcher([ + (UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x: + x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None), + (UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None), + (UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"), lambda 
ctx,bc,ld: + ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype.bitsize == ctx[0].bitsize else None), + (UPat(Ops.BITCAST, src=(UPat.var("x", dtypes.floats),), name="bc"), lambda ctx,bc,x: + bc.replace(src=(f2f(x.bitcast(f2f_dt[ctx[1]]), ctx[1], ctx[0]),)) if x.dtype == ctx[1] and bc.dtype.bitsize == ctx[0].bitsize else None), + (UPat(Ops.CAST, dtypes.floats, src=(UPat.var("val"),), name="x"), lambda ctx,x,val: + f2f_clamp(val.cast(ctx[1]), ctx[0]) if x.dtype.scalar() == ctx[0] else None), + (UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x: + x.replace(dtype=ctx[1].vec(x.dtype.count), src=tuple(s.cast(ctx[1]) if s.dtype == ctx[0] else s for s in x.src)) + if x.dtype.scalar() == ctx[0] else None), + (UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), name='st'), lambda ctx,st,idx,val: + st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))) if val.dtype == ctx[0] and idx.tag == ctx[0] else None), + (UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val: + f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None), +])