mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
decompose fp8 to bigger floats [skip_process_replay] (#14554)
* decompose fp8 also * it works * cleanup * no shift required * default to float * cleanup * fixes * fp8e5m2 * don't rely on behavior comparing nans * cleanup
This commit is contained in:
committed by
GitHub
parent
81f6cdb4ab
commit
7bb45e7df0
@@ -381,8 +381,29 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
|
||||
class TestBFloat16Type(TestDType):
  # TestDType suite parameterized with the bfloat16 dtype
  DTYPE = dtypes.bfloat16
|
||||
|
||||
class TestFp8e4m3(TestDType):
  # TestDType suite parameterized with the fp8e4m3 dtype
  DTYPE = dtypes.fp8e4m3
|
||||
|
||||
class TestEmulatedFp8e4m3(TestFp8e4m3):
  # Runs the inherited TestFp8e4m3 cases with EMULATED_DTYPES="fp8e4m3" active for the lifetime of the class.
  # setUpClass here replaces the inherited one and fills cls.DATA itself.
  @classmethod
  def setUpClass(cls):
    # unittest does not call tearDownClass when setUpClass raises, so the Context is entered inside a `with`
    # and ownership is only handed to the class (pop_all) once setup has fully succeeded; on failure the
    # `with` block exits the Context and the EMULATED_DTYPES override does not leak into other test classes.
    with contextlib.ExitStack() as stack:
      stack.enter_context(Context(EMULATED_DTYPES="fp8e4m3"))
      cls.DATA = rand_for_dtype(cls.DTYPE, 10)
      cls.stack = stack.pop_all()

  @classmethod
  def tearDownClass(cls):
    # leaves the Context entered in setUpClass
    cls.stack.close()
|
||||
|
||||
class TestFp8e5m2(TestDType):
  # TestDType suite parameterized with the fp8e5m2 dtype
  DTYPE = dtypes.fp8e5m2
|
||||
|
||||
class TestEmulatedFp8e5m2(TestFp8e5m2):
  # Runs the inherited TestFp8e5m2 cases with EMULATED_DTYPES="fp8e5m2" active for the lifetime of the class.
  # setUpClass here replaces the inherited one and fills cls.DATA itself.
  @classmethod
  def setUpClass(cls):
    # unittest does not call tearDownClass when setUpClass raises, so the Context is entered inside a `with`
    # and ownership is only handed to the class (pop_all) once setup has fully succeeded; on failure the
    # `with` block exits the Context and the EMULATED_DTYPES override does not leak into other test classes.
    with contextlib.ExitStack() as stack:
      stack.enter_context(Context(EMULATED_DTYPES="fp8e5m2"))
      cls.DATA = rand_for_dtype(cls.DTYPE, 10)
      cls.stack = stack.pop_all()

  @classmethod
  def tearDownClass(cls):
    # leaves the Context entered in setUpClass
    cls.stack.close()
|
||||
|
||||
class TestPtrDType(unittest.TestCase):
|
||||
def test_vec_double(self):
|
||||
dt1 = dtypes.float.vec(4).ptr().vec(4)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest, operator, math
|
||||
from tinygrad import Context, Tensor, dtypes, Device
|
||||
from tinygrad.dtype import DType, truncate
|
||||
from tinygrad.dtype import DType, truncate, fp8_to_float
|
||||
from tinygrad.helpers import CI, EMULATED_DTYPES, getenv
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -59,9 +59,10 @@ def universal_test(a, b, dtype, op):
|
||||
# lt and max with nan is undefined in tinygrad
|
||||
if op[0] in (operator.lt, Tensor.maximum) and (math.isnan(a) or math.isnan(b)): return
|
||||
ta, tb = Tensor([a], dtype=dtype), Tensor([b], dtype=dtype)
|
||||
tensor_value = (op[0](ta, tb)).numpy()
|
||||
numpy_value = op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value.item())
|
||||
if dtype in dtypes.fp8s and op[0] not in (operator.lt, operator.eq):
|
||||
tensor_value = fp8_to_float((op[0](ta.realize(), tb.realize())).bitcast(dtypes.uint8).item(), dtype)
|
||||
numpy_value = truncate[dtype](op[1](ta.numpy(), tb.numpy()).item())
|
||||
else: tensor_value, numpy_value = (op[0](ta, tb)).numpy(), op[1](ta.numpy(), tb.numpy())
|
||||
if dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype) or dtype in EMULATED_DTYPES.tolist(dtypes): # denormals are zero
|
||||
fe, fm = dtypes.finfo(dtype)
|
||||
@@ -76,13 +77,14 @@ def universal_test_unary(a, dtype, op):
|
||||
# TODO: cos does not match for large input
|
||||
if op[0] == Tensor.cos and abs(a) > 30: return
|
||||
if op[0] == Tensor.log and a <= 0: return
|
||||
out: Tensor = op[0](ta)
|
||||
tensor_value = out.numpy()
|
||||
numpy_value = op[1](ta.numpy())
|
||||
if dtype in dtypes.fp8s:
|
||||
# denormals are zero
|
||||
if dtype in EMULATED_DTYPES.tolist(dtypes) and abs(ta.numpy().item()) < 0.015625: return
|
||||
tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype)
|
||||
numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item())
|
||||
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
|
||||
if math.isinf(numpy_value.item()): return
|
||||
numpy_value = truncate[dtype](numpy_value.item())
|
||||
if math.isinf(v): return
|
||||
else: tensor_value, numpy_value = op[0](ta).numpy(), op[1](ta.numpy())
|
||||
if dtype in dtypes.floats:
|
||||
atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
|
||||
dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
|
||||
@@ -133,11 +135,21 @@ class TestDTypeALU(unittest.TestCase):
|
||||
def test_fp8e4m3(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.fp8e4m3), from_storage_scalar(b, dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="fp8e4m3")
def test_emulated_fp8e4m3(self, a, b, op):
  # binary-op property for fp8e4m3, run with EMULATED_DTYPES="fp8e4m3" active
  dt = dtypes.fp8e4m3
  lhs = from_storage_scalar(a, dt)
  rhs = from_storage_scalar(b, dt)
  universal_test(lhs, rhs, dt, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
def test_fp8e5m2(self, a, b, op):
  # binary-op property for fp8e5m2; skipped when the default device lacks the dtype
  dt = dtypes.fp8e5m2
  lhs = from_storage_scalar(a, dt)
  rhs = from_storage_scalar(b, dt)
  universal_test(lhs, rhs, dt, op)
|
||||
|
||||
@given(ht.fp8e5m2, ht.fp8e5m2, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="fp8e5m2")
def test_emulated_fp8e5m2(self, a, b, op):
  # binary-op property for fp8e5m2, run with EMULATED_DTYPES="fp8e5m2" active
  dt = dtypes.fp8e5m2
  lhs = from_storage_scalar(a, dt)
  rhs = from_storage_scalar(b, dt)
  universal_test(lhs, rhs, dt, op)
|
||||
|
||||
@given(ht.float32, strat.sampled_from(unary_operations))
def test_float32_unary(self, a, op):
  # unary-op property for float32
  universal_test_unary(a, dtypes.float32, op)
|
||||
|
||||
@@ -159,12 +171,24 @@ class TestDTypeALU(unittest.TestCase):
|
||||
if op[1] == np.reciprocal: assume(from_storage_scalar(a, dtype=dtypes.fp8e4m3) != 0.0)
|
||||
universal_test_unary(from_storage_scalar(a, dtype=dtypes.fp8e4m3), dtypes.fp8e4m3, op)
|
||||
|
||||
@given(ht.fp8e4m3, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="fp8e4m3")
def test_emulated_fp8e4m3_unary(self, a, op):
  # unary-op property for fp8e4m3, run with EMULATED_DTYPES="fp8e4m3" active
  dt = dtypes.fp8e4m3
  # zero inputs are rejected for reciprocal
  if op[1] == np.reciprocal:
    assume(from_storage_scalar(a, dtype=dt) != 0.0)
  value = from_storage_scalar(a, dtype=dt)
  universal_test_unary(value, dt, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), f"no fp8e5m2 on {Device.DEFAULT}")
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
def test_fp8e5m2_unary(self, a, op):
  # unary-op property for fp8e5m2; skipped when the default device lacks the dtype
  dt = dtypes.fp8e5m2
  # zero inputs are rejected for reciprocal
  if op[1] == np.reciprocal:
    assume(from_storage_scalar(a, dtype=dt) != 0.0)
  value = from_storage_scalar(a, dtype=dt)
  universal_test_unary(value, dt, op)
|
||||
|
||||
@given(ht.fp8e5m2, strat.sampled_from(unary_operations))
@Context(EMULATED_DTYPES="fp8e5m2")
def test_emulated_fp8e5m2_unary(self, a, op):
  # unary-op property for fp8e5m2, run with EMULATED_DTYPES="fp8e5m2" active
  dt = dtypes.fp8e5m2
  # zero inputs are rejected for reciprocal
  if op[1] == np.reciprocal:
    assume(from_storage_scalar(a, dtype=dt) != 0.0)
  value = from_storage_scalar(a, dtype=dt)
  universal_test_unary(value, dt, op)
|
||||
|
||||
@given(ht.uint8, ht.uint8, strat.sampled_from(integer_binary_operations))
def test_uint8(self, a, b, op):
  # integer binary-op property for uint8
  universal_test(a, b, dtypes.uint8, op)
|
||||
|
||||
|
||||
@@ -5,14 +5,15 @@ from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TR
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, promo_lattice
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import panic
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_unsupported_dtypes_patterns, get_transcendental_patterns
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads
|
||||
@@ -88,10 +89,13 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
# decompositions
|
||||
supported_ops = tuple(ren.code_for_op.keys())
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
|
||||
pm_unsupported = get_unsupported_dtypes_patterns(ren.device, tuple(EMULATED_DTYPES.tolist(dtypes)))
|
||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
|
||||
sink = graph_rewrite(sink, pm_unsupported, ctx=ren.device, name="unsupported dtypes", bottom_up=True)
|
||||
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
|
||||
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
|
||||
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
|
||||
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
|
||||
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
|
||||
sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
|
||||
@@ -14,8 +14,8 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
||||
|
||||
# *** helper functions for bit manipulation ***
|
||||
def mantissa_bits(d:DType) -> int:
  """Return the second entry of `dtypes.finfo` for the scalar type of `d` (its mantissa bit count)."""
  _, man_bits = dtypes.finfo(d.scalar())
  return man_bits
|
||||
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
|
||||
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
|
||||
def exponent_bias(d:DType) -> int:
  """Return 2**(e-1) - 1, where e is the first entry of `dtypes.finfo` for the scalar type of `d`."""
  exp_bits, _ = dtypes.finfo(d.scalar())
  half_range = 1 << (exp_bits - 1)
  return half_range - 1
|
||||
def exponent_mask(d:DType) -> int:
  """Return 2**e - 1 (an all-ones field of e bits), where e is the first entry of `dtypes.finfo` for `d`'s scalar type."""
  exp_bits = dtypes.finfo(d.scalar())[0]
  return (1 << exp_bits) - 1
|
||||
|
||||
# **** utils ****
|
||||
def shr(x:UOp|int, y:UOp|int) -> UOp:
  """Floor-divide `x` by 2**y; a UOp `y` is simplified first and its `arg` is used as the shift amount."""
  amount = y.simplify().arg if isinstance(y, UOp) else y
  return x // 2**amount
|
||||
@@ -378,33 +378,46 @@ def l2i(op: Ops, dt: DType, *uops:UOp):
|
||||
case _: raise NotImplementedError(f"long decomposition of {op} unsupported")
|
||||
|
||||
# ***** floats *****
|
||||
f2f_dt = { dtypes.half: dtypes.ushort, dtypes.float: dtypes.uint }
|
||||
# maps every float dtype to the unsigned integer dtype of the same bit width (looked up as dtypes.uint<bitsize>)
f2f_dt = { f:getattr(dtypes, f"uint{f.bitsize}") for f in dtypes.floats }
|
||||
|
||||
def rne(v: UOp, s) -> UOp:
  """Drop the low `s` bits of `v`, rounding to nearest with ties going to the even result."""
  kept = shr(v, s)
  # highest of the dropped bits
  round_bit = shr(v, s - 1) & 1
  # whether any dropped bit below the round bit is set
  sticky = (v & ((1 << (s - 1)) - 1)).ne(0).cast(v.dtype)
  # lowest kept bit, decides exact ties
  kept_lsb = shr(v, s) & 1
  return kept + (round_bit & (sticky | kept_lsb))
|
||||
|
||||
def f2f(v, fr:DType, to:DType):
|
||||
fs, fb, (fe, fm), ts, tb, (te, tm) = fr.bitsize, exponent_bias(fr), dtypes.finfo(fr), to.bitsize, exponent_bias(to), dtypes.finfo(to)
|
||||
# NB: denormals are zero!
|
||||
if fe < te and fm < tm:
|
||||
if fe <= te and fm < tm:
|
||||
sign, nosign = shl((v & shl(1, fs-1)).cast(f2f_dt[to]), ts - fs), (v & (shl(1, fs-1) - 1)).cast(f2f_dt[to])
|
||||
exp, norm = shr(nosign, fm), shl(nosign, tm - fm) + shl(tb - fb, tm)
|
||||
inf_or_nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
|
||||
return (sign | exp.eq(0).where(0, exp.eq(shl(1, fe) - 1).where(inf_or_nan, norm))).bitcast(to)
|
||||
elif fe > te and fm > tm:
|
||||
sign, nosign, exp = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1), shr(v, fm) & (shl(1, fe) - 1)
|
||||
nan = shl(nosign, tm - fm) | shl((shl(1, te) - 1), tm)
|
||||
# fp8e4m3 has only one nan
|
||||
is_nan = (nosign.eq(shl(1, fm + fe) - 1) if fr == dtypes.fp8e4m3 else exp.eq(shl(1, fe) - 1))
|
||||
return (sign | exp.eq(0).where(0, is_nan.where(nan, norm))).bitcast(to)
|
||||
elif fe >= te and fm > tm:
|
||||
v = f2f_clamp(v.bitcast(fr), to).bitcast(f2f_dt[fr])
|
||||
sign, nosign = shr(v, fs - ts) & shl(1, ts - 1), v & (shl(1, fs - 1) - 1)
|
||||
norm = (rne(nosign, fm - tm) - shl(fb - tb, tm)).cast(f2f_dt[to])
|
||||
infnan = (sign | (shr(nosign, fm - tm) & (shl(1, tm) - 1)) | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
|
||||
underflow, overflow = exp < (1 + fb - tb), exp > (shl(1, te) - 2 + (fb - tb))
|
||||
return exp.eq(shl(1, fe) - 1).where(infnan, sign.cast(f2f_dt[to]) | underflow.where(0, overflow.where(shl(shl(1, te) - 1, tm), norm)))
|
||||
underflow = (shr(v, fm) & (shl(1, fe) - 1)) < (1 + fb - tb)
|
||||
nan_mantissa = (shl(1, tm) - 1) if to == dtypes.fp8e4m3 else (shr(nosign, fm - tm) & (shl(1, tm) - 1))
|
||||
nan = (sign | nan_mantissa | shl(shl(1, te) - 1, tm)).cast(f2f_dt[to])
|
||||
is_nan = (shr(v, fm) & (shl(1, fe) - 1)).eq(shl(1, fe) - 1)
|
||||
return is_nan.where(nan, sign.cast(f2f_dt[to]) | underflow.where(0, norm))
|
||||
else: raise NotImplementedError(f"unsupported decomp {fr} -> {to}")
|
||||
|
||||
def f2f_load(x: UOp) -> UOp:
|
||||
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=dtypes.ushort), dtypes.half, dtypes.float)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=dtypes.ushort, src=(reindex(x.src[0].src[0], i, 1),)), dtypes.half, dtypes.float) for i in range(n)))
|
||||
def f2f_clamp(val:UOp, dt:DType) -> UOp:
  """Clamp `val` against the largest finite magnitude of `dt`.

  Values whose magnitude exceeds that maximum become +/-max when `dt` is an fp8 dtype and +/-inf otherwise.
  A `val` that compares unequal to itself (NaN) is returned unchanged.
  """
  exp_bits, man_bits = dtypes.finfo(dt)
  if dt == dtypes.fp8e4m3:
    # all-ones exponent is still finite here; only the all-ones mantissa is excluded
    top_exp, top_man = (1 << exp_bits) - 1, (1 << man_bits) - 2
  else:
    top_exp, top_man = (1 << exp_bits) - 2, (1 << man_bits) - 1
  limit = val.const_like(2.0**(top_exp - exponent_bias(dt)) * (1.0 + top_man / (1 << man_bits)))
  overflow_to = limit if dt in dtypes.fp8s else val.const_like(float('inf'))
  # FIXME: CMPLT of nan is undefined
  is_nan = val.ne(val)
  upper_clamped = (limit < val).where(overflow_to, val)
  clamped = (val < -limit).where(-overflow_to, upper_clamped)
  return is_nan.where(val, clamped)
|
||||
|
||||
def f2f_store(st, idx, val):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(dtypes.uint), dtypes.float, dtypes.half)))
|
||||
return UOp.group(*(st.replace(src=(reindex(idx, i, 1), f2f(val.gep(i).bitcast(dtypes.uint), dtypes.float, dtypes.half))) for i in range(n)))
|
||||
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
  """Retype the load `x` to the integer dtype `f2f_dt[fr]` and convert the loaded bits from `fr` to `to` with `f2f`.

  A load with count > 1 is split into one reindexed load per element, and the converted elements are vectorized.
  """
  lanes = x.dtype.count
  if lanes == 1:
    return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
  converted = []
  for lane in range(lanes):
    lane_load = x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], lane, 1),))
    converted.append(f2f(lane_load, fr, to))
  return UOp.vectorize(*converted)
|
||||
|
||||
def f2f_store(st, idx, val, fr:DType, to:DType):
  """Rebuild the store `st` so the stored value is `val` converted from `to` back to `fr` with `f2f`.

  `val` is bitcast to `f2f_dt[to]` before conversion. A value with count > 1 becomes a group of per-element
  stores, each to a reindexed `idx`.
  """
  lanes = val.dtype.count
  if lanes == 1:
    return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
  lane_stores = []
  for lane in range(lanes):
    lane_idx = reindex(idx, lane, 1)
    lane_val = f2f(val.gep(lane).bitcast(f2f_dt[to]), to, fr)
    lane_stores.append(st.replace(src=(lane_idx, lane_val)))
  return UOp.group(*lane_stores)
|
||||
|
||||
# ***** decomposition patterns *****
|
||||
|
||||
@@ -463,40 +476,44 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device:str, disable_fast_idiv
|
||||
pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
@functools.cache
|
||||
def get_unsupported_dtypes_patterns(device:str, emulated_dtypes:tuple[DType, ...]) -> PatternMatcher:
|
||||
pat: list[tuple[UPat, Callable]] = []
|
||||
if not is_dtype_supported(dtypes.long, device) or dtypes.long in emulated_dtypes:
|
||||
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None)]
|
||||
pat += [(UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype]))]
|
||||
pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
|
||||
st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None)]
|
||||
pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
|
||||
l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))]
|
||||
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None)]
|
||||
pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
(a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag])]
|
||||
pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
|
||||
l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None)]
|
||||
pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
|
||||
l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
|
||||
if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])]
|
||||
pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
|
||||
x.replace(dtype=l2i_dt[x.dtype], src=(reindex(idx, x.tag),)))]
|
||||
pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
|
||||
UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))]
|
||||
if dtypes.half in emulated_dtypes:
|
||||
pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
|
||||
x.replace(dtype=dtypes.uint16.ptr(x.dtype.size), tag=dtypes.half) if x.dtype.base == dtypes.half else None)]
|
||||
pat += [(UPat(Ops.LOAD, dtypes.half, name="x"), f2f_load)]
|
||||
pat += [(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, dtypes.half, name="ld"),), name="bc"), lambda bc,ld:
|
||||
ld.replace(dtype=dtypes.ushort).bitcast(bc.dtype))]
|
||||
pat += [(UPat(Ops.BITCAST, (dtypes.ushort, dtypes.short, dtypes.bfloat16), src=(UPat.var("x", dtypes.float),), name="bc"), lambda bc,x:
|
||||
bc.replace(src=(f2f(x.bitcast(dtypes.uint), dtypes.float, dtypes.half),)))]
|
||||
pat += [(UPat(GroupOp.All, dtypes.half, name="x"), lambda x:
|
||||
x.replace(dtype=dtypes.float.vec(x.dtype.count), src=tuple(s.cast(dtypes.float) if s.dtype == dtypes.half else s for s in x.src)))]
|
||||
pat += [(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.float)), name='st'), lambda st,idx,val:
|
||||
f2f_store(st, idx, val) if (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == dtypes.half else None)]
|
||||
return PatternMatcher(pat)
|
||||
# Rewrite rules that replace dtypes listed in l2i_dt by their mapped dtype (l2i_dt[...]).
# NOTE(review): x.tag / rtag(0|1) appear to select which half of the original value a node stands for
# (the CONST rule below uses tag == 1 for bits 32 and up, otherwise the low 32 bits) - confirm against rtag/l2i.
pm_long_decomp = PatternMatcher([
  # pointer-typed defines/indexes whose base dtype is in l2i_dt: switch to the mapped dtype with twice the size
  (UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
    x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if hasattr(x.dtype, 'size') and x.dtype.base in l2i_dt else None),
  # INDEX with an l2i_dt dtype: reindex by the node's tag and retype to the mapped dtype
  (UPat(Ops.INDEX, tuple(l2i_dt.keys()), name='x'), lambda x: reindex(x, x.tag).replace(dtype=l2i_dt[x.dtype])),
  # STORE of an untagged l2i_dt value: split into a group of two stores (reindexed with 0 and 1, value tagged 0 and 1)
  (UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
    st.replace(src=(reindex(idx, 0), val.rtag(0))).group(st.replace(src=(reindex(idx, 1), val.rtag(1)))) if val.tag is None else None),
  # comparisons of two l2i_dt operands: handed to l2i with both tagged parts of each operand
  (UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
    l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt))),
  # tagged CAST into an l2i_dt dtype from a source that is not in l2i_dt: take the tagged element of l2i's result
  (UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
    l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None and a.dtype not in l2i_dt else None),
  # CAST between two l2i_dt dtypes: cast each part to the source's mapped dtype, bitcast to the target's, pick by tag
  (UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
    (a.rtag(0).cast(dt:=l2i_dt[a.dtype]).bitcast(xdt:=l2i_dt[x.dtype]), a.rtag(1).cast(dt).bitcast(xdt))[x.tag]),
  # CAST from an untagged l2i_dt source to a dtype outside l2i_dt: handed to l2i with both parts of the source
  (UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
    l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)) if x.dtype not in l2i_dt and a.tag is None else None),
  # non-comparison ALU ops and BITCAST producing an l2i_dt dtype: l2i_dt sources are expanded into their two parts,
  # other sources are passed through unchanged; the tagged element of l2i's result is used
  (UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
    l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
                                         if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag]),
  # LOAD of an l2i_dt dtype: retype to the mapped dtype and reindex by the node's tag
  (UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx: x.replace(dtype=l2i_dt[x.dtype],src=(reindex(idx, x.tag),))),
  # CONST of an l2i_dt dtype: tag 1 takes arg >> 32, any other tag takes the low 32 bits; truncated to the mapped dtype
  (UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
    UOp.const(dt:=l2i_dt[x.dtype], truncate[dt]((x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF))))
])
|
||||
|
||||
# float decomposition patterns - ctx is (fr, to) tuple
|
||||
# Rewrite rules for one (fr, to) float pair: ctx[0] is the float dtype being replaced, ctx[1] the float dtype used instead.
pm_float_decomp = PatternMatcher([
  # defines/indexes whose base dtype is fr: point at the same-width integer dtype instead and record fr in the tag
  (UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
    x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
  # LOAD of fr: load the raw bits and convert them to `to` (see f2f_load)
  (UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
  # BITCAST of a LOAD whose dtype has fr's bit width: retype the load to f2f_dt[fr] and bitcast that directly
  (UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"), lambda ctx,bc,ld:
    ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype.bitsize == ctx[0].bitsize else None),
  # BITCAST of a `to` value into a dtype of fr's bit width: convert the `to` bits back to fr with f2f first
  (UPat(Ops.BITCAST, src=(UPat.var("x", dtypes.floats),), name="bc"), lambda ctx,bc,x:
    bc.replace(src=(f2f(x.bitcast(f2f_dt[ctx[1]]), ctx[1], ctx[0]),)) if x.dtype == ctx[1] and bc.dtype.bitsize == ctx[0].bitsize else None),
  # CAST producing fr: cast the source to `to` and clamp it against fr's range (see f2f_clamp)
  (UPat(Ops.CAST, dtypes.floats, src=(UPat.var("val"),), name="x"), lambda ctx,x,val:
    f2f_clamp(val.cast(ctx[1]), ctx[0]) if x.dtype.scalar() == ctx[0] else None),
  # any other non-BITCAST op producing fr: produce `to` instead (same count), casting sources whose dtype is fr to `to`
  (UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x:
    x.replace(dtype=ctx[1].vec(x.dtype.count), src=tuple(s.cast(ctx[1]) if s.dtype == ctx[0] else s for s in x.src))
    if x.dtype.scalar() == ctx[0] else None),
  # STORE of a BITCAST-to-fr value into an index tagged fr: retype the bitcast to f2f_dt[fr], no conversion
  (UPat(Ops.STORE, src=(UPat.var("idx"), UPat(Ops.BITCAST, dtypes.floats, name="val")), name='st'), lambda ctx,st,idx,val:
    st.replace(src=(idx, val.replace(dtype=f2f_dt[ctx[0]]))) if val.dtype == ctx[0] and idx.tag == ctx[0] else None),
  # STORE of a `to` value into an index tagged fr (looking through a CAST on the index): convert via f2f_store
  (UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val:
    f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
])
|
||||
|
||||
Reference in New Issue
Block a user