mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
automatically emulate unsupported (tiny) floats [skip_process_replay] (#15366)
This commit is contained in:
committed by
GitHub
parent
78ad089817
commit
0c89340a1e
@@ -18,11 +18,10 @@ settings.register_profile("my_profile", max_examples=200, deadline=None, derando
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
# dont cast internal dtypes
|
||||
dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
|
||||
if not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
|
||||
if dtype in (dtypes.long, dtypes.ulong): return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support
|
||||
else: return []
|
||||
dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) or v in dtypes.fp8s+(dtypes.half,dtypes.bfloat16,dtypes.long)]
|
||||
if dtype in (dtypes.long, dtypes.ulong) and (not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes)):
|
||||
return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support
|
||||
if not is_dtype_supported(dtype) and dtype not in dtypes.fp8s+(dtypes.half,dtypes.bfloat16): return []
|
||||
return dts
|
||||
|
||||
def _to_torch_storage_type(dtype:DType):
|
||||
@@ -60,10 +59,8 @@ class TestDType(unittest.TestCase):
|
||||
DATA: Any = None
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
def setUp(self):
|
||||
if self.DTYPE is None: raise unittest.SkipTest("base class")
|
||||
if cls.DTYPE is None: raise unittest.SkipTest("base class")
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 0x10, allow_subnormal=is_dtype_supported(cls.DTYPE))
|
||||
|
||||
def test_to_np(self):
|
||||
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
|
||||
@@ -132,7 +129,6 @@ class TestDType(unittest.TestCase):
|
||||
|
||||
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
||||
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
|
||||
if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
|
||||
if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
|
||||
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
|
||||
_assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2])
|
||||
@@ -195,7 +191,6 @@ class TestFp8sConversions(unittest.TestCase):
|
||||
def test_fp8e5m2fnuz_to_float(self, x):
|
||||
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2fnuz), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2fnuz).float().item())
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
|
||||
class TestBFloat16(unittest.TestCase):
|
||||
def test_bf16_creation_numpy(self):
|
||||
data = [-1, 1, 2]
|
||||
@@ -215,7 +210,6 @@ class TestBFloat16(unittest.TestCase):
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
np.testing.assert_allclose(t.numpy(), np.eye(3))
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
|
||||
class TestBFloat16DType(unittest.TestCase):
|
||||
def test_bf16_to_float(self):
|
||||
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
|
||||
@@ -229,7 +223,6 @@ class TestBFloat16DType(unittest.TestCase):
|
||||
back = t.cast(dtypes.float32)
|
||||
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) and is_dtype_supported(dtypes.float16), "bfloat16 or float16 not supported")
|
||||
class TestBFloat16DTypeCast(unittest.TestCase):
|
||||
def test_f16_to_bf16_conversion(self):
|
||||
original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
|
||||
@@ -278,6 +271,7 @@ class TestFloatDType(TestDType):
|
||||
_test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32,
|
||||
[0, 0, 1])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.double), f"no double on {Device.DEFAULT}")
|
||||
class TestDoubleDType(TestDType):
|
||||
DTYPE = dtypes.double
|
||||
@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \
|
||||
|
||||
@@ -81,8 +81,8 @@ def universal_test_unary(a, dtype, op):
|
||||
if op[0] == Tensor.cos and abs(a) > 30: return
|
||||
if op[0] == Tensor.log and a <= 0: return
|
||||
if dtype in dtypes.fp8s:
|
||||
# normals are zero
|
||||
if dtype in EMULATED_DTYPES.tolist(dtypes) and abs(ta.numpy().item()) < 0.015625: return
|
||||
# denormals are zero
|
||||
if dtype in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dtype) and abs(ta.numpy().item()) < 0.015625: return
|
||||
tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype)
|
||||
numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item())
|
||||
# cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
|
||||
|
||||
@@ -410,7 +410,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
d0 = UOp(Ops.PARAM, dt.ptr(), arg=0)
|
||||
v = d0.index(UOp.const(dtypes.int, 0))
|
||||
uops = to_uops_list([v.bitcast(dt)])
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}")
|
||||
|
||||
def test_sub_with_cast_folds(self):
|
||||
a = Variable("a", 0, 5)
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest, math, subprocess
|
||||
from tinygrad.tensor import Tensor, dtypes, Device
|
||||
from tinygrad.dtype import DType, DTYPES_DICT
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from tinygrad.helpers import getenv, DEBUG, EMULATED_DTYPES
|
||||
from test.helpers import slow
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
import numpy as np
|
||||
@@ -24,8 +24,13 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
|
||||
if DEBUG >= 2: print(tensor.numpy())
|
||||
try:
|
||||
assert tensor.dtype == target_dtype
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
|
||||
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1, dtypes.fp8e4m3fnuz:1e-1, dtypes.fp8e5m2fnuz:5e-1}.get(target_dtype, tol_target_dtype))
|
||||
# denormals are zero
|
||||
if target_dtype in dtypes.floats and (not is_dtype_supported(target_dtype) or target_dtype in EMULATED_DTYPES.tolist(dtypes)):
|
||||
fe, fm = dtypes.finfo(target_dtype)
|
||||
kwargs = {"atol":2 ** (2 - (1 << (fe - 1))), "rtol": 2 ** (-fm)}
|
||||
else: kwargs = {"rtol": {dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1,
|
||||
dtypes.fp8e4m3fnuz:1e-1, dtypes.fp8e5m2fnuz:5e-1}.get(target_dtype, tol_target_dtype)}
|
||||
np.testing.assert_allclose(tensor.numpy(), target, **kwargs)
|
||||
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
|
||||
|
||||
@@ -148,7 +148,8 @@ class TestGGUFGEMV(unittest.TestCase):
|
||||
|
||||
x = rng.standard_normal(cols).astype(np.float32)
|
||||
np.testing.assert_allclose((tensors["weight"] @ Tensor(x)).numpy(), ref @ x, atol=1e-2, rtol=1e-2)
|
||||
np.testing.assert_equal(tensors["weight"].numpy(), ref)
|
||||
# can only expect the weights to be identical if we really support float16 (ie. not decompositions)
|
||||
if is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref)
|
||||
assert np.isfinite(ref).all() and np.isfinite(tensors["weight"].numpy()).all(), f"{qtype.name} has NaN/Inf"
|
||||
|
||||
def test_gguf_gemv_q8_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q8_0)
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from typing import cast
|
||||
from dataclasses import replace
|
||||
import itertools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, TracingKey, Context
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, TracingKey, Context
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.dtype import dtypes, promo_lattice
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import panic
|
||||
from tinygrad.codegen.opt import Opt
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
|
||||
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
ReduceContext, correct_load_store, pm_render, pm_add_loads
|
||||
@@ -92,11 +91,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV))
|
||||
pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)
|
||||
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
|
||||
if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes):
|
||||
sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
|
||||
for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float))
|
||||
for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]:
|
||||
sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True)
|
||||
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.device), name="decomp dtypes")
|
||||
sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
|
||||
@@ -3,10 +3,10 @@ from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import Any, Generic, TypeVar, Iterator, Generator, TYPE_CHECKING
|
||||
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
|
||||
from tinygrad.helpers import BENCHMARKS, CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
|
||||
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar
|
||||
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK
|
||||
from tinygrad.helpers import EMULATED_DTYPES, NULL_IR3, NULL_QCOMCL, TracingKey, size_to_str
|
||||
from tinygrad.helpers import EMULATE, EMULATED_DTYPES, NULL_IR3, NULL_QCOMCL, TracingKey, size_to_str
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -333,15 +333,15 @@ class Compiled:
|
||||
def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
|
||||
if device is None: device = Device.DEFAULT
|
||||
if dtype == dtypes.bfloat16:
|
||||
if device == "METAL": return not CI
|
||||
if device == "CUDA": return not CI and not CUDA_PTX
|
||||
if device == "NV": return not CI and not NV_PTX and not NV_NAK
|
||||
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP
|
||||
if device == "METAL": return not CI or BENCHMARKS
|
||||
if device == "CUDA": return (not CI or BENCHMARKS) and not CUDA_PTX
|
||||
if device == "NV": return (not CI or BENCHMARKS) and not NV_PTX and not NV_NAK
|
||||
if device in {"CPU"}: return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP
|
||||
return device in {"AMD", "CL", "PYTHON", "NULL"}
|
||||
if dtype in dtypes.fp8_ocp:
|
||||
if device == "CUDA": return not CI and not CUDA_PTX
|
||||
if device == "NV": return not CI and not NV_PTX and not NV_NAK
|
||||
if device == "AMD": return not CI and getattr(Device["AMD"], "target") == (9,5,0)
|
||||
if device == "CUDA": return (not CI or BENCHMARKS) and not CUDA_PTX
|
||||
if device == "NV": return (not CI or BENCHMARKS) and not NV_PTX and not NV_NAK
|
||||
if device == "AMD": return (not CI or BENCHMARKS) and getattr(Device["AMD"], "target") == (9,5,0)
|
||||
return device in {"PYTHON", "NULL"}
|
||||
if dtype in dtypes.fp8_fnuz: return device in {"PYTHON", "NULL"}
|
||||
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
|
||||
@@ -352,9 +352,9 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
|
||||
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
|
||||
# double can't be bitcast to anything without long support
|
||||
if dtype == dtypes.half:
|
||||
if device == "CL": return not CI and not OSX
|
||||
if device == "CL": return (not CI or BENCHMARKS) and not OSX
|
||||
if device == "QCOM": return False # QCOM compiler is flaky with half
|
||||
if device in ["CUDA", "NV"]: return not CI
|
||||
if device in ["CUDA", "NV"]: return (not CI or BENCHMARKS) or "CUDA" in EMULATE.value
|
||||
if device == "CPU" and CPU_LLVM: return OSX
|
||||
if device == "PYTHON": return sys.version_info >= (3, 12)
|
||||
if dtype == dtypes.float64: return (device not in {"METAL", "QCOM"} and not (OSX and device == "CL") and not NULL_IR3 and not NULL_QCOMCL
|
||||
|
||||
@@ -14,7 +14,7 @@ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
|
||||
|
||||
# NOTE: helpers is not allowed to import from anything else in tinygrad
|
||||
OSX, WIN = platform.system() == "Darwin", sys.platform == "win32"
|
||||
CI = os.getenv("CI", "") != ""
|
||||
CI, BENCHMARKS = os.getenv("CI", "") != "", os.getenv("RUNNER_ENVIRONMENT", "") == "self-hosted"
|
||||
ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64"))
|
||||
BASEDIR = pathlib.Path(__file__).parent
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ from typing import Callable
|
||||
import math, functools
|
||||
from tinygrad.dtype import dtypes, DType, promo_lattice, truncate
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import flatten, polyN
|
||||
from tinygrad.helpers import flatten, polyN, EMULATED_DTYPES
|
||||
from tinygrad.uop import GroupOp
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
|
||||
|
||||
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
|
||||
@@ -319,7 +319,7 @@ def threefry2x32(x: UOp, key: UOp):
|
||||
|
||||
l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
|
||||
def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, shr(v.bitcast(dtypes.uint), 16)
|
||||
def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off))
|
||||
def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off, *idx.src[2:]))
|
||||
|
||||
# 4.3.1 is the relevant section in TAOCP
|
||||
def l2i(op: Ops, dt: DType, *uops:UOp):
|
||||
@@ -511,10 +511,15 @@ pm_float_decomp = PatternMatcher([
|
||||
(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x:
|
||||
x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None),
|
||||
(UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None),
|
||||
# bitcasted load should just replace load
|
||||
(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"), lambda ctx,bc,ld:
|
||||
ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype.bitsize == ctx[0].bitsize else None),
|
||||
ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype == ctx[0] else None),
|
||||
# bitcast from
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("x", dtypes.floats),), name="bc"), lambda ctx,bc,x:
|
||||
bc.replace(src=(f2f(x.bitcast(f2f_dt[ctx[1]]), ctx[1], ctx[0]),)) if x.dtype == ctx[1] and bc.dtype.bitsize == ctx[0].bitsize else None),
|
||||
# bitcast to
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("x"),), name="bc"), lambda ctx,bc,x:
|
||||
f2f(x.bitcast(f2f_dt[ctx[0]]), ctx[0], ctx[1]) if bc.dtype == ctx[0] else None),
|
||||
(UPat(Ops.CAST, dtypes.floats, src=(UPat.var("val"),), name="x"), lambda ctx,x,val:
|
||||
f2f_clamp(val.cast(ctx[1]), ctx[0]) if x.dtype.scalar() == ctx[0] else None),
|
||||
(UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x:
|
||||
@@ -525,3 +530,20 @@ pm_float_decomp = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val:
|
||||
f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None),
|
||||
])
|
||||
|
||||
def do_dtype_decomps(sink:UOp, ctx:tuple[set[DType], str]) -> UOp:
  """Rewrite `sink` so that dtypes needing emulation are decomposed into other dtypes.

  `ctx[0]` is the set of candidate dtypes to consider and `ctx[1]` is the device name
  used for the `is_dtype_supported` queries. A candidate is decomposed when it is
  listed in EMULATED_DTYPES or is not supported on that device; other candidates are
  left untouched. Returns the (possibly rewritten) sink.
  """
  device = ctx[1]
  for src_dt in ctx[0]:
    # skip dtypes that are neither force-emulated nor unsupported on this device
    if src_dt not in EMULATED_DTYPES.tolist(dtypes) and is_dtype_supported(src_dt, device): continue
    if src_dt not in dtypes.floats:
      # non-float candidates go through the long -> int decomposition
      sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True)
      continue
    # fp8 is emulated through half when the device supports half; every other float goes through float
    if is_dtype_supported(dtypes.half, device) and src_dt in dtypes.fp8s: dst_dt = dtypes.half
    else: dst_dt = dtypes.float
    sink = graph_rewrite(sink, pm_float_decomp, name=f"decomp {src_dt} -> {dst_dt}", ctx=(src_dt, dst_dt), bottom_up=True)
  return sink
|
||||
|
||||
pm_dtype_decomps = PatternMatcher([
  # collect the dtypes that may need decomposing into ctx[0]; ulong is recorded as long.
  # set.add returns None, so this rule only records and (presumably) never rewrites the matched UOp - TODO confirm
  (UPat(GroupOp.All, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half, dtypes.long, dtypes.ulong), name="x"),
   lambda x,ctx: ctx[0].add(dtypes.long if (dt:=x.dtype.base.scalar()) == dtypes.ulong else dt)),
  # at the SINK, run the decompositions for the collected dtypes
  # NOTE(review): relies on the SINK being matched after the UOps feeding it, so the set is populated - confirm
  (UPat(Ops.SINK, name="sink"), do_dtype_decomps),
])
|
||||
|
||||
Reference in New Issue
Block a user