support bfloat16 for CL (#14073)

chenyu authored 2026-01-08 14:14:29 -05:00, committed by GitHub
parent 16b652302e
commit b34c637767
3 changed files with 8 additions and 5 deletions


@@ -213,7 +213,7 @@ class TestBFloat16DType(unittest.TestCase):
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) and is_dtype_supported(dtypes.float16), "bfloat16 or float16 not supported")
class TestBFloat16DTypeCast(unittest.TestCase):
def test_f16_to_bf16_conversion(self):
original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
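
(Sketch, not part of the diff.) The tightened skip above is because test_f16_to_bf16_conversion needs float16 as well as bfloat16 on the device. For intuition, an f16 value reaches bf16 by widening to float32 and keeping only the top 16 bits of the bit pattern; a minimal numpy illustration (truncation shown here, real casts may round):

import numpy as np, struct

def f16_to_bf16_bits(x):
  # widen float16 -> float32, then keep the high 16 bits of the float32 pattern
  bits32 = struct.unpack("I", struct.pack("f", float(np.float16(x))))[0]
  return bits32 >> 16

print(hex(f16_to_bf16_bits(1.0)))  # 0x3f80
print(hex(f16_to_bf16_bits(3.0)))  # 0x4040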
@@ -423,7 +423,6 @@ class TestDtypeUsage(unittest.TestCase):
class TestOpsBFloat16(unittest.TestCase):
def test_cast(self):
# TODO: helper_test_op breaks in unrelated part
# TODO: wrong output with CL=1 on mac
data = [60000.0, 70000.0, 80000.0]
np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
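
(Sketch, not part of the diff.) For reference, the torch side of this comparison rounds each input to the nearest bf16-representable value, so the cast is expected to be lossy:

import torch
data = [60000.0, 70000.0, 80000.0]
print(torch.tensor(data).type(torch.bfloat16).float())  # tensor([59904., 70144., 79872.]), assuming round-to-nearest-even

Dropping the mac TODO suggests the CL=1 output is now expected to match this as well.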


@@ -351,7 +351,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if device == "CUDA": return not CI and not CUDA_PTX
if device == "NV": return not CI and not NV_PTX and not NV_NAK
if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP
return device in {"AMD", "PYTHON", "NULL"}
return device in {"AMD", "CL", "PYTHON", "NULL"}
if dtype in dtypes.fp8s:
if device == "CUDA": return not CI and not CUDA_PTX
if device == "NV": return not CI and not NV_PTX and not NV_NAK


@@ -1,10 +1,10 @@
from typing import Literal, Callable, cast
import os, math, sys
import os, math, sys, struct
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
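
(Sketch, not part of the diff.) The new struct and float_to_bf16 imports feed the constant rule added below. As an illustrative re-implementation (not tinygrad's definition; round-to-nearest-even and no NaN/inf handling are assumptions), a float-to-bf16 rounding helper of that shape looks like:

import struct

def bf16_value(x: float) -> float:
  # round the float32 bit pattern to the nearest value whose low 16 bits are zero (ties to even)
  bits = struct.unpack("I", struct.pack("f", x))[0]
  bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
  return struct.unpack("f", struct.pack("I", bits))[0]

print(bf16_value(3.14159))  # 3.140625, the nearest bf16-representable value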
@@ -298,9 +298,13 @@ class OpenCLRenderer(CStyleLanguage):
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
type_map = { dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
dtypes.bfloat16: "ushort" }
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast + extra_pm
string_rewrite = PatternMatcher([
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
# bfloat16 constants need to be rendered as their bit pattern since bf16 is stored as ushort
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
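
(Sketch, not part of the diff.) OpenCL has no native bf16 literal, so the new CONST rule emits the upper 16 bits of the float32 bit pattern as a ushort literal; per the diff line, the renderer applies float_to_bf16 first so the value is already rounded to a bf16-representable float before the shift. A standalone version of the encoding:

import struct

def render_bf16_const(value: float) -> str:
  # keep the top 16 bits of the float32 pattern (the bfloat16 encoding)
  # and emit it as an OpenCL ushort literal; rounding is assumed to happen beforehand
  bits32 = struct.unpack("I", struct.pack("f", value))[0]
  return f"{bits32 >> 16}u"

print(render_bf16_const(1.0))   # 16256u  (0x3F80)
print(render_bf16_const(-2.5))  # 49184u  (0xC020)

Together with the dtypes.bfloat16: "ushort" type_map entry, this keeps bf16 data as raw 16-bit values in memory, while the extra_matcher line suggests arithmetic is expanded through float32 by create_non_native_float_pats and pm_manual_bf16_cast.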