mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-09 14:15:22 -05:00
python float8 support (#11960)
* basic support * alu * nan in exec_alu * rand_for_dtype * inf + 0.0 * finfo * revert rand_for_dtype * clean * truncate fp8s inf * spec ok * float_to_fp8 nan/inf * least_upper_dtype * clean up --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -4,21 +4,23 @@
|
||||
# this is the (living) definition of uops
|
||||
from typing import Any, TYPE_CHECKING, cast
|
||||
import pickle, base64, itertools, time, struct, sys
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
|
||||
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.codegen.opt import tc
|
||||
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else dtype.fmt
|
||||
def storage_fmt_for_dtype(dtype: DType): return 'H' if dtype == dtypes.bfloat16 else 'B' if dtype in dtypes.fp8s else dtype.fmt
|
||||
|
||||
def to_storage_scalar(x, dtype: DType):
|
||||
if dtype == dtypes.bfloat16: return (struct.unpack('I', struct.pack('f', float_to_bf16(x)))[0] >> 16) & 0xFFFF
|
||||
if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
|
||||
return x
|
||||
|
||||
def from_storage_scalar(x, dtype: DType):
|
||||
if dtype == dtypes.bfloat16: return struct.unpack('f', struct.pack('I', (x & 0xFFFF) << 16))[0]
|
||||
if dtype in dtypes.fp8s: return fp8_to_float(int(x), dtype)
|
||||
return x
|
||||
|
||||
def _load(m, i, dtype: DType):
|
||||
|
||||
Reference in New Issue
Block a user