From 0c89340a1ed30fa5e76f5fead40931f37520391c Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Thu, 19 Mar 2026 23:31:44 -0700 Subject: [PATCH] automatically emulate unsupported (tiny) floats [skip_process_replay] (#15366) --- test/backend/test_dtype.py | 20 +++++++------------- test/backend/test_dtype_alu.py | 4 ++-- test/null/test_uop_graph.py | 2 +- test/unit/test_dtype_spec.py | 11 ++++++++--- test/unit/test_gguf.py | 3 ++- tinygrad/codegen/__init__.py | 13 ++++--------- tinygrad/device.py | 22 +++++++++++----------- tinygrad/helpers.py | 2 +- tinygrad/uop/decompositions.py | 30 ++++++++++++++++++++++++++---- 9 files changed, 62 insertions(+), 45 deletions(-) diff --git a/test/backend/test_dtype.py b/test/backend/test_dtype.py index c5fe197163..1b206d091f 100644 --- a/test/backend/test_dtype.py +++ b/test/backend/test_dtype.py @@ -18,11 +18,10 @@ settings.register_profile("my_profile", max_examples=200, deadline=None, derando settings.load_profile("my_profile") def get_available_cast_dtypes(dtype: DType) -> List[DType]: - # dont cast internal dtypes - dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] - if not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes): - if dtype in (dtypes.long, dtypes.ulong): return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support - else: return [] + dts = [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) or v in dtypes.fp8s+(dtypes.half,dtypes.bfloat16,dtypes.long)] + if dtype in (dtypes.long, dtypes.ulong) and (not is_dtype_supported(dtype) or dtypes.long in EMULATED_DTYPES.tolist(dtypes)): + return [dt for dt in dts if dt != dtypes.double] # can't bitcast with no 64-bit support + if not is_dtype_supported(dtype) and dtype not in dtypes.fp8s+(dtypes.half,dtypes.bfloat16): return [] return dts def _to_torch_storage_type(dtype:DType): @@ -60,10 +59,8 @@ class TestDType(unittest.TestCase): DATA: Any = None @classmethod def setUpClass(cls): - if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") - cls.DATA = rand_for_dtype(cls.DTYPE, 10) - def setUp(self): - if self.DTYPE is None: raise unittest.SkipTest("base class") + if cls.DTYPE is None: raise unittest.SkipTest("base class") + cls.DATA = rand_for_dtype(cls.DTYPE, 0x10, allow_subnormal=is_dtype_supported(cls.DTYPE)) def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE))) @@ -132,7 +129,6 @@ class TestDType(unittest.TestCase): def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype) - if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8]) _assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2]) @@ -195,7 +191,6 @@ class TestFp8sConversions(unittest.TestCase): def test_fp8e5m2fnuz_to_float(self, x): np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2fnuz), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2fnuz).float().item()) -@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported") class TestBFloat16(unittest.TestCase): def test_bf16_creation_numpy(self): data = [-1, 1, 2] @@ -215,7 +210,6 @@ class TestBFloat16(unittest.TestCase): assert t.dtype == dtypes.bfloat16 np.testing.assert_allclose(t.numpy(), np.eye(3)) -@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported") class TestBFloat16DType(unittest.TestCase): def test_bf16_to_float(self): _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32) @@ -229,7 +223,6 @@ class TestBFloat16DType(unittest.TestCase): back = t.cast(dtypes.float32) assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) -@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) and is_dtype_supported(dtypes.float16), "bfloat16 or float16 not supported") class TestBFloat16DTypeCast(unittest.TestCase): def test_f16_to_bf16_conversion(self): original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16) @@ -278,6 +271,7 @@ class TestFloatDType(TestDType): _test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32, [0, 0, 1]) +@unittest.skipUnless(is_dtype_supported(dtypes.double), f"no double on {Device.DEFAULT}") class TestDoubleDType(TestDType): DTYPE = dtypes.double @unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or \ diff --git a/test/backend/test_dtype_alu.py b/test/backend/test_dtype_alu.py index 9e4a51e88a..aeb10b6401 100644 --- a/test/backend/test_dtype_alu.py +++ b/test/backend/test_dtype_alu.py @@ -81,8 +81,8 @@ def universal_test_unary(a, dtype, op): if op[0] == Tensor.cos and abs(a) > 30: return if op[0] == Tensor.log and a <= 0: return if dtype in dtypes.fp8s: - # normals are zero - if dtype in EMULATED_DTYPES.tolist(dtypes) and abs(ta.numpy().item()) < 0.015625: return + # denormals are zero + if dtype in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dtype) and abs(ta.numpy().item()) < 0.015625: return tensor_value = fp8_to_float(op[0](ta.realize()).bitcast(dtypes.uint8).item(), dtype) numpy_value = truncate[dtype](v:=op[1](ta.numpy()).item()) # cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2) diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index b6aa0116f0..0071fd8617 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -410,7 +410,7 @@ class TestUOpGraph(unittest.TestCase): d0 = UOp(Ops.PARAM, dt.ptr(), arg=0) v = d0.index(UOp.const(dtypes.int, 0)) uops = to_uops_list([v.bitcast(dt)]) - self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}") + self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST and x.dtype is dt]), 0, f"dtype = {dt}") def test_sub_with_cast_folds(self): a = Variable("a", 0, 5) diff --git a/test/unit/test_dtype_spec.py b/test/unit/test_dtype_spec.py index 00b3ddd3ab..065a61b7d0 100644 --- a/test/unit/test_dtype_spec.py +++ b/test/unit/test_dtype_spec.py @@ -2,7 +2,7 @@ import unittest, math, subprocess from tinygrad.tensor import Tensor, dtypes, Device from tinygrad.dtype import DType, DTYPES_DICT from tinygrad.device import is_dtype_supported -from tinygrad.helpers import getenv, DEBUG +from tinygrad.helpers import getenv, DEBUG, EMULATED_DTYPES from test.helpers import slow from hypothesis import given, settings, strategies as strat import numpy as np @@ -24,8 +24,13 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float if DEBUG >= 2: print(tensor.numpy()) try: assert tensor.dtype == target_dtype - np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, - dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1, dtypes.fp8e4m3fnuz:1e-1, dtypes.fp8e5m2fnuz:5e-1}.get(target_dtype, tol_target_dtype)) + # denormals are zero + if target_dtype in dtypes.floats and (not is_dtype_supported(target_dtype) or target_dtype in EMULATED_DTYPES.tolist(dtypes)): + fe, fm = dtypes.finfo(target_dtype) + kwargs = {"atol":2 ** (2 - (1 << (fe - 1))), "rtol": 2 ** (-fm)} + else: kwargs = {"rtol": {dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1, + dtypes.fp8e4m3fnuz:1e-1, dtypes.fp8e5m2fnuz:5e-1}.get(target_dtype, tol_target_dtype)} + np.testing.assert_allclose(tensor.numpy(), target, **kwargs) except AssertionError as e: raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e diff --git a/test/unit/test_gguf.py b/test/unit/test_gguf.py index 6c22143125..5b7f2f4de5 100644 --- a/test/unit/test_gguf.py +++ b/test/unit/test_gguf.py @@ -148,7 +148,8 @@ class TestGGUFGEMV(unittest.TestCase): x = rng.standard_normal(cols).astype(np.float32) np.testing.assert_allclose((tensors["weight"] @ Tensor(x)).numpy(), ref @ x, atol=1e-2, rtol=1e-2) - np.testing.assert_equal(tensors["weight"].numpy(), ref) + # can only expect the weights to be identical if we really support float16 (ie. not decompositions) + if is_dtype_supported(dtypes.half): np.testing.assert_equal(tensors["weight"].numpy(), ref) assert np.isfinite(ref).all() and np.isfinite(tensors["weight"].numpy()).all(), f"{qtype.name} has NaN/Inf" def test_gguf_gemv_q8_0(self): self._test_gguf_gemv(GGMLQuantizationType.Q8_0) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index aaf23c8bcd..6670ce98c3 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,19 +1,18 @@ from typing import cast from dataclasses import replace import itertools -from tinygrad.helpers import DISABLE_FAST_IDIV, EMULATED_DTYPES, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, TracingKey, Context +from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, TracingKey, Context from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer, ProgramSpec, Estimates -from tinygrad.dtype import dtypes, promo_lattice -from tinygrad.device import is_dtype_supported +from tinygrad.dtype import dtypes from tinygrad.helpers import panic from tinygrad.codegen.opt import Opt # import all pattern matchers here from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load -from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_float_decomp, pm_long_decomp +from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads @@ -92,11 +91,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, bool(DISABLE_FAST_IDIV)) pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2) sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions") - if not is_dtype_supported(dtypes.long, ren.device) or dtypes.long in EMULATED_DTYPES.tolist(dtypes): - sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True) - for fr, to in [(fr, next((to for to in promo_lattice[fr] if is_dtype_supported(to, ren.device)), dtypes.float)) - for fr in EMULATED_DTYPES.tolist(dtypes) if fr in dtypes.floats]: - sink = graph_rewrite(sink, pm_float_decomp, ctx=(fr, to), name=f"decomp {fr} -> {to}", bottom_up=True) + sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren.device), name="decomp dtypes") sink = graph_rewrite(sink, pm_transcendental, ctx=ren.device, name="transcendental") # final rules for the renderer (without sym) diff --git a/tinygrad/device.py b/tinygrad/device.py index 97190fc528..d511839ea0 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -3,10 +3,10 @@ from dataclasses import dataclass, replace from collections import defaultdict from typing import Any, Generic, TypeVar, Iterator, Generator, TYPE_CHECKING import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal -from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored +from tinygrad.helpers import BENCHMARKS, CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK -from tinygrad.helpers import EMULATED_DTYPES, NULL_IR3, NULL_QCOMCL, TracingKey, size_to_str +from tinygrad.helpers import EMULATE, EMULATED_DTYPES, NULL_IR3, NULL_QCOMCL, TracingKey, size_to_str from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -333,15 +333,15 @@ class Compiled: def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: if device is None: device = Device.DEFAULT if dtype == dtypes.bfloat16: - if device == "METAL": return not CI - if device == "CUDA": return not CI and not CUDA_PTX - if device == "NV": return not CI and not NV_PTX and not NV_NAK - if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP + if device == "METAL": return not CI or BENCHMARKS + if device == "CUDA": return (not CI or BENCHMARKS) and not CUDA_PTX + if device == "NV": return (not CI or BENCHMARKS) and not NV_PTX and not NV_NAK + if device in {"CPU"}: return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP return device in {"AMD", "CL", "PYTHON", "NULL"} if dtype in dtypes.fp8_ocp: - if device == "CUDA": return not CI and not CUDA_PTX - if device == "NV": return not CI and not NV_PTX and not NV_NAK - if device == "AMD": return not CI and getattr(Device["AMD"], "target") == (9,5,0) + if device == "CUDA": return (not CI or BENCHMARKS) and not CUDA_PTX + if device == "NV": return (not CI or BENCHMARKS) and not NV_PTX and not NV_NAK + if device == "AMD": return (not CI or BENCHMARKS) and getattr(Device["AMD"], "target") == (9,5,0) return device in {"PYTHON", "NULL"} if dtype in dtypes.fp8_fnuz: return device in {"PYTHON", "NULL"} if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, @@ -352,9 +352,9 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751 # double can't be bitcast to anything without long support if dtype == dtypes.half: - if device == "CL": return not CI and not OSX + if device == "CL": return (not CI or BENCHMARKS) and not OSX if device == "QCOM": return False # QCOM compiler is flaky with half - if device in ["CUDA", "NV"]: return not CI + if device in ["CUDA", "NV"]: return (not CI or BENCHMARKS) or "CUDA" in EMULATE.value if device == "CPU" and CPU_LLVM: return OSX if device == "PYTHON": return sys.version_info >= (3, 12) if dtype == dtypes.float64: return (device not in {"METAL", "QCOM"} and not (OSX and device == "CL") and not NULL_IR3 and not NULL_QCOMCL diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index b43ef6666a..e92f4d8b11 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -14,7 +14,7 @@ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1) # NOTE: helpers is not allowed to import from anything else in tinygrad OSX, WIN = platform.system() == "Darwin", sys.platform == "win32" -CI = os.getenv("CI", "") != "" +CI, BENCHMARKS = os.getenv("CI", "") != "", os.getenv("RUNNER_ENVIRONMENT", "") == "self-hosted" ARCH_X86 = any(x in platform.processor() for x in ("Intel", "i386", "x86_64")) BASEDIR = pathlib.Path(__file__).parent diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 4ef918f749..def170ed2e 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -2,9 +2,9 @@ from typing import Callable import math, functools from tinygrad.dtype import dtypes, DType, promo_lattice, truncate from tinygrad.device import is_dtype_supported -from tinygrad.helpers import flatten, polyN +from tinygrad.helpers import flatten, polyN, EMULATED_DTYPES from tinygrad.uop import GroupOp -from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher +from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64) @@ -319,7 +319,7 @@ def threefry2x32(x: UOp, key: UOp): l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint} def unpack32(v:UOp) -> tuple[UOp, UOp]: return v.bitcast(dtypes.uint) & 0xFFFF, shr(v.bitcast(dtypes.uint), 16) -def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off)) +def reindex(idx:UOp, off:int, mul=2) -> UOp: return idx.replace(src=(idx.src[0], idx.src[1]*mul+off, *idx.src[2:])) # 4.3.1 is the relevant section in TAOCP def l2i(op: Ops, dt: DType, *uops:UOp): @@ -511,10 +511,15 @@ pm_float_decomp = PatternMatcher([ (UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda ctx,x: x.replace(dtype=f2f_dt[ctx[0]].ptr(x.dtype.size), tag=ctx[0]) if x.dtype.base == ctx[0] else None), (UPat(Ops.LOAD, dtypes.floats, name="x"), lambda ctx,x: f2f_load(x, *ctx) if x.dtype.scalar() == ctx[0] else None), + # bitcasted load should just replace load (UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"), lambda ctx,bc,ld: - ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype.bitsize == ctx[0].bitsize else None), + ld.replace(dtype=f2f_dt[ctx[0]]).bitcast(bc.dtype) if ld.dtype == ctx[0] else None), + # bitcast from (UPat(Ops.BITCAST, src=(UPat.var("x", dtypes.floats),), name="bc"), lambda ctx,bc,x: bc.replace(src=(f2f(x.bitcast(f2f_dt[ctx[1]]), ctx[1], ctx[0]),)) if x.dtype == ctx[1] and bc.dtype.bitsize == ctx[0].bitsize else None), + # bitcast to + (UPat(Ops.BITCAST, src=(UPat.var("x"),), name="bc"), lambda ctx,bc,x: + f2f(x.bitcast(f2f_dt[ctx[0]]), ctx[0], ctx[1]) if bc.dtype == ctx[0] else None), (UPat(Ops.CAST, dtypes.floats, src=(UPat.var("val"),), name="x"), lambda ctx,x,val: f2f_clamp(val.cast(ctx[1]), ctx[0]) if x.dtype.scalar() == ctx[0] else None), (UPat(GroupOp.All-{Ops.BITCAST}, dtypes.floats, name="x"), lambda ctx,x: @@ -525,3 +530,20 @@ pm_float_decomp = PatternMatcher([ (UPat(Ops.STORE, src=(UPat.var("idx"), UPat.var("val", dtypes.floats)), name='st'), lambda ctx,st,idx,val: f2f_store(st, idx, val, *ctx) if val.dtype.scalar() == ctx[1] and (idx:=idx.src[0] if idx.op == Ops.CAST else idx).tag == ctx[0] else None), ]) + +def do_dtype_decomps(sink:UOp, ctx:tuple[set[DType], str]) -> UOp: + def _should_emulate(dt): return dt in EMULATED_DTYPES.tolist(dtypes) or not is_dtype_supported(dt, ctx[1]) + for fr in filter(_should_emulate, ctx[0]): + if fr in dtypes.floats: + to = dtypes.half if is_dtype_supported(dtypes.half, ctx[1]) and fr in dtypes.fp8s else dtypes.float + sink = graph_rewrite(sink, pm_float_decomp, name=f"decomp {fr} -> {to}", ctx=(fr, to), bottom_up=True) + else: sink = graph_rewrite(sink, pm_long_decomp, name="decomp long -> int", bottom_up=True) + return sink + +pm_dtype_decomps = PatternMatcher([ + # detect dtypes to decompose + (UPat(GroupOp.All, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half, dtypes.long, dtypes.ulong), name="x"), lambda x,ctx: + ctx[0].add({dtypes.ulong:dtypes.long}.get(dt:=x.dtype.base.scalar(), dt))), + # do the rewrites + (UPat(Ops.SINK, name="sink"), do_dtype_decomps), +])