python bfloat16 (#11912)

* python bf16

* _to_torch_storage_type

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg
2025-08-30 03:18:02 +08:00
committed by GitHub
parent afad7d0cd1
commit b2cc06218a
4 changed files with 48 additions and 27 deletions

View File

@@ -4,25 +4,35 @@
# this is the (living) definition of uops
from typing import Any, TYPE_CHECKING
import pickle, base64, itertools, time, struct, sys
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
def _load(m, i):
def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else dtype.fmt
def to_storage_scalar(x, dtype: DType):
if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
return x
def from_storage_scalar(x, dtype: DType):
if dtype == dtypes.bfloat16: return struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))[0]
return x
def _load(m, i, dtype: DType):
if i is None: return 0.0
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
return m[i]
return from_storage_scalar(m[i], dtype)
def load(inp, j=0):
if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]]
def load(inp, j, dtype: DType):
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
def _store(m, i, v):
def _store(m, i, v, dtype: DType):
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
m[i] = v
m[i] = to_storage_scalar(v, dtype)
class PythonProgram:
def __init__(self, name:str, lib:bytes):
@@ -57,19 +67,20 @@ class PythonProgram:
if uop is Ops.STORE:
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
for (m,o,g),v in zip(inp[0], val):
if g: _store(m, o+j, v)
if g: _store(m, o+j, v, dtp[1].scalar())
i += 1
continue
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
if dtype.fmt is None: raise RuntimeError(f"{dtype=} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if uop is Ops.DEFINE_REG:
# REGs are per thread
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)]
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else:
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
ul[i] = [buf.cast(dtype.fmt)] * warp_size
ul[i] = [buf.cast(storage_fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:
@@ -98,16 +109,17 @@ class PythonProgram:
continue
elif uop is Ops.VECTORIZE: ul[i] = inp
elif uop is Ops.BITCAST:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]])
ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]]
elif uop is Ops.CAST:
ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \
for j in range(dtype.count)]
else:
ul[i] = load(inp)
ul[i] = load(inp, 0, dtype)
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware