Mirror of https://github.com/tinygrad/tinygrad.git
support bfloat16 for CL (#14073)
@@ -213,7 +213,7 @@ class TestBFloat16DType(unittest.TestCase):
     back = t.cast(dtypes.float32)
     assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)

-@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) and is_dtype_supported(dtypes.float16), "bfloat16 or float16 not supported")
 class TestBFloat16DTypeCast(unittest.TestCase):
   def test_f16_to_bf16_conversion(self):
     original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
@@ -423,7 +423,6 @@ class TestDtypeUsage(unittest.TestCase):
 class TestOpsBFloat16(unittest.TestCase):
   def test_cast(self):
     # TODO: helper_test_op breaks in unrelated part
-    # TODO: wrong output with CL=1 on mac
     data = [60000.0, 70000.0, 80000.0]
     np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())

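(Not part of the diff: the torch side of this assert rounds each float32 to the nearest bfloat16, i.e. onto the upper 16 bits of the float32 bit pattern. A minimal sketch of that rounding, assuming the usual round-to-nearest, ties-to-even behaviour, applied to the data above:)

import struct

def bf16_rne(x: float) -> float:
  # round a float32 value to the nearest bfloat16 (keep the top 16 bits, ties to even)
  u = struct.unpack('I', struct.pack('f', x))[0]
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFFFFFF
  return struct.unpack('f', struct.pack('I', u & 0xFFFF0000))[0]

print([bf16_rne(v) for v in [60000.0, 70000.0, 80000.0]])  # -> [59904.0, 70144.0, 79872.0]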
@@ -351,7 +351,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
     if device == "CUDA": return not CI and not CUDA_PTX
     if device == "NV": return not CI and not NV_PTX and not NV_NAK
     if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP
-    return device in {"AMD", "PYTHON", "NULL"}
+    return device in {"AMD", "CL", "PYTHON", "NULL"}
   if dtype in dtypes.fp8s:
     if device == "CUDA": return not CI and not CUDA_PTX
     if device == "NV": return not CI and not NV_PTX and not NV_NAK
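(Not part of the diff: with "CL" added to that fallback set, the helper now reports bfloat16 as supported on the OpenCL backend, which is what un-gates the skipUnless tests above. A minimal usage sketch; the test.helpers import path is an assumption:)

from tinygrad.dtype import dtypes
from test.helpers import is_dtype_supported  # assumed location of the helper shown above

assert is_dtype_supported(dtypes.bfloat16, "CL")  # True after this change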
@@ -1,10 +1,10 @@
 from typing import Literal, Callable, cast
-import os, math, sys
+import os, math, sys, struct
 from collections import defaultdict, Counter
 from tinygrad.codegen.opt import tc
 from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
 from tinygrad.renderer import Renderer
 from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -298,9 +298,13 @@ class OpenCLRenderer(CStyleLanguage):
  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
  type_map = { dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
               dtypes.bfloat16: "ushort" }
  extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast + extra_pm

  string_rewrite = PatternMatcher([
    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
    # bfloat16 constants need to be rendered as their bit pattern since bf16 is stored as ushort
    (UPat(Ops.CONST, dtypes.bfloat16, name="x"),
     lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
    # load/store image (OpenCL)
    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
      lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
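(Not part of the diff: a minimal sketch of what the new bfloat16 CONST rule emits. It skips the float_to_bf16 rounding step, so it assumes the constant is already bf16-representable; the upper 16 bits of the float32 bit pattern are printed as an unsigned literal, matching the ushort storage picked in type_map:)

import struct

def render_bf16_const(val: float) -> str:
  # the upper 16 bits of the float32 pattern are the bf16 encoding for a bf16-representable value
  bits = struct.unpack('I', struct.pack('f', val))[0] >> 16
  return f"{bits}u"

print(render_bf16_const(1.0))   # 16256u (0x3F80)
print(render_bf16_const(-2.0))  # 49152u (0xC000)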