tinygrad/extra/fp8/fp8_linear.py

from typing import Callable, Any
from tinygrad import Tensor, dtypes, nn, UOp
from tinygrad.uop.ops import KernelInfo, AxisType, Ops

def quantize_to_fp8(x: Tensor, dtype=dtypes.fp8e4m3):
  # dynamic per-tensor scaling: map the absolute max of x onto the fp8 representable range
  fp8_min = -448.0 if dtype == dtypes.fp8e4m3 else -57344.0
  fp8_max = 448.0 if dtype == dtypes.fp8e4m3 else 57344.0
  x_abs_max = x.abs().max().detach()
  scale = fp8_max / (x_abs_max + 1e-8)
  x_scaled = x * scale
  # straight-through estimator: clamp in the forward pass, but let gradients flow through x_scaled
  x_det = x_scaled.detach()
  x_clamped = x_det.clamp(fp8_min, fp8_max)
  x_clamped_ste = x_scaled + (x_clamped - x_det)
  res = x_clamped_ste.cast(dtype)
  # return the quantized tensor and its dequantization scale (1/scale)
  return res, scale.float().reciprocal()
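
# A worked example of the scaling above (illustrative comment, not part of the module):
# for x = [1.0, -2.0, 4.0] with dtype fp8e4m3, the abs-max is 4.0, so scale ~= 448/4 = 112 and
# the cast sees [112, -224, 448]; multiplying by the returned reciprocal (~1/112) recovers the
# inputs up to fp8 rounding error:
#   q, inv_scale = quantize_to_fp8(Tensor([1.0, -2.0, 4.0]))
#   approx = q.float() * inv_scale  # ~[1.0, -2.0, 4.0]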

def custom_matmul(output: UOp, inp: UOp, weight: UOp) -> UOp:
  # hand-built batched matmul kernel: output[b,s,o] = sum_i inp[b,s,i] * weight[o,i]
  SEQ = inp.shape[1]
  OUT = weight.shape[0]
  IN = weight.shape[-1]
  seq_idx = UOp.range(SEQ, 2, AxisType.LOOP)
  out_idx = UOp.range(OUT, 3, AxisType.LOOP)
  batch_idx = UOp.range(output.size//SEQ//OUT, 1, AxisType.LOOP)
  reduce_idx = UOp.range(IN, 0, AxisType.REDUCE)
  # multiply the fp8 inputs, accumulate in float32
  product = (inp.index((seq_idx*IN+reduce_idx+batch_idx*IN*SEQ)) * weight.index((out_idx*IN+reduce_idx))).cast(dtypes.float)
  reduced = product.reduce(reduce_idx, arg=Ops.ADD)
  store_op = output.index((seq_idx*OUT+out_idx+batch_idx*OUT*SEQ), ptr=True).store(reduced).end(batch_idx, seq_idx, out_idx)
  return store_op.sink(arg=KernelInfo(name=f"fp8_matmul_{inp.shape}x{weight.shape}"))
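
# Index arithmetic sketch for the kernel above (assumes row-major contiguous buffers):
#   inp    is [batch, seq, in]:  element (b,s,i) at flat offset b*SEQ*IN + s*IN + i
#   weight is [out, in]:         element (o,i)   at flat offset o*IN + i
#   output is [batch, seq, out]: element (b,s,o) at flat offset b*SEQ*OUT + s*OUT + o
# so the kernel computes the same thing as Tensor.einsum("bsi,oi->bso", x, w) with a float32 accumulator.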

def custom_matmul_backward(gradient: UOp, kernel: UOp) -> tuple[None, UOp, UOp]:
  # unpack the kernel's buffer sources: output (needs no gradient), input, weight
  _, input_uop, weight_uop = kernel.src[1:]
  input_tensor = Tensor(input_uop, device=input_uop.device)
  grad_tensor = Tensor(gradient, device=gradient.device)
  weight_tensor = Tensor(weight_uop, device=weight_uop.device)
  # quantize the incoming gradient so both backward matmuls also run in fp8
  grad_quantized, scale = quantize_to_fp8(grad_tensor)
  scale_scalar = scale.reshape(())
  grad_weight = Tensor.einsum("bso,bsi->oi", grad_quantized, input_tensor, dtype=dtypes.float)
  grad_weight = grad_weight * scale_scalar
  grad_2d = grad_quantized.reshape(grad_tensor.shape[0] * grad_tensor.shape[1], grad_tensor.shape[-1])
  grad_input = (grad_2d.dot(weight_tensor, dtype=dtypes.float)).contiguous().reshape(input_tensor.shape) * scale
  return (None, grad_input.uop, grad_weight.uop)
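
# Gradient derivation sketch: with y[b,s,o] = sum_i x[b,s,i] * w[o,i] and upstream gradient g,
#   dL/dw[o,i]   = sum_{b,s} g[b,s,o] * x[b,s,i]  -> einsum("bso,bsi->oi", g, x)
#   dL/dx[b,s,i] = sum_o    g[b,s,o] * w[o,i]     -> g @ w on the flattened (batch*seq, out) view
# Because g was quantized before the matmuls, both results are rescaled by g's dequantization scale.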

class FP8Linear:
  def __init__(self, in_features:int, out_features:int, bias:bool=True):
    self.weight = Tensor.empty(out_features, in_features, dtype=dtypes.float32)
    self.bias = Tensor.empty(out_features, dtype=dtypes.float32) if bias else None

  def __call__(self, x: Tensor) -> Tensor:
    # promote 2D inputs to (batch, 1, in_features) so the kernel always sees a batched shape
    original_ndim = len(x.shape)
    if original_ndim == 2: x = x.reshape(x.shape[0], 1, x.shape[1])
    batch, seq, _ = x.shape
    w_fp8, w_scale = quantize_to_fp8(self.weight)
    x_fp8, x_scale = quantize_to_fp8(x)
    GPUS = self.weight.device
    # for multi-device weights, allocate the output sharded along the batch axis
    if isinstance(GPUS, tuple) and len(GPUS) > 1:
      y = Tensor(Tensor.empty((batch//len(GPUS), seq, self.weight.shape[0]), dtype=dtypes.float, device=GPUS).uop.multi(0), device=GPUS)
    else:
      y = Tensor.empty((batch, seq, self.weight.shape[0]), dtype=dtypes.float)
    y = Tensor.custom_kernel(y, x_fp8, w_fp8, fxn=custom_matmul, grad_fxn=custom_matmul_backward)[0]
    # dequantize: the float32 accumulator is off by both quantization scales
    y = y * w_scale * x_scale
    if self.bias is not None: y = y + self.bias
    if original_ndim == 2: y = y.reshape(batch, self.weight.shape[0])
    return y.cast(x.dtype)
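
# Drop-in usage sketch (shapes are hypothetical; assumes the device supports fp8 dtypes).
# Note the weights here are uninitialized Tensor.empty buffers; in practice they are
# replaced with trained nn.Linear parameters by _replace_linear below.
#   layer = FP8Linear(512, 256)
#   out = layer(Tensor.randn(8, 512))      # 2D input -> (8, 256)
#   out = layer(Tensor.randn(4, 16, 512))  # 3D input -> (4, 16, 256)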

def _replace_linear(layer: nn.Linear):
  fp8_linear = FP8Linear(layer.weight.shape[1], layer.weight.shape[0], layer.bias is not None)
  fp8_linear.weight = layer.weight
  if layer.bias is not None: fp8_linear.bias = layer.bias
  return fp8_linear

def _swap_linear_with_fp8(model, module_filter_fn:Callable[[Any, str],bool]|None=None, fqn:str="", parent:Any|None=None,
                          attr_name:str="", visited:set|None=None):
  # recursively walk attributes, lists, and dicts, swapping every nn.Linear for an FP8Linear;
  # `visited` guards against reference cycles and `fqn` tracks the fully-qualified name for the filter
  if visited is None: visited = set()
  if id(model) in visited: return
  visited.add(id(model))
  if isinstance(model, (str, int, float, bool, type(None), Tensor, UOp)): return
  elif isinstance(model, nn.Linear):
    if module_filter_fn is not None and not module_filter_fn(model, fqn): return
    fp8_linear = _replace_linear(model)
    if parent is not None and attr_name:
      setattr(parent, attr_name, fp8_linear)
  elif isinstance(model, list):
    for i, item in enumerate(model):
      child_fqn = f"{fqn}.{i}" if fqn else str(i)
      if isinstance(item, nn.Linear) and (module_filter_fn is None or module_filter_fn(item, child_fqn)): model[i] = _replace_linear(item)
      else: _swap_linear_with_fp8(item, module_filter_fn, child_fqn, None, "", visited)
  elif isinstance(model, dict):
    for key, item in list(model.items()):
      child_fqn = f"{fqn}.{key}" if fqn else str(key)
      if isinstance(item, nn.Linear) and (module_filter_fn is None or module_filter_fn(item, child_fqn)): model[key] = _replace_linear(item)
      else: _swap_linear_with_fp8(item, module_filter_fn, child_fqn, None, "", visited)
  elif hasattr(model, "__dict__"):
    for attr_key in list(vars(model).keys()):
      try: attr = getattr(model, attr_key)
      except Exception: continue
      child_fqn = f"{fqn}.{attr_key}" if fqn else attr_key
      _swap_linear_with_fp8(attr, module_filter_fn, child_fqn, model, attr_key, visited)

def convert_to_float8_training(model, module_filter_fn:Callable[[Any,str],bool]|None=None):
  _swap_linear_with_fp8(model, module_filter_fn, "", None, "")
  return model
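
# A minimal end-to-end sketch, assuming the default device supports fp8e4m3.
# TinyNet and its layer names are hypothetical, not part of this module.
if __name__ == "__main__":
  class TinyNet:
    def __init__(self):
      self.l1 = nn.Linear(16, 32)
      self.l2 = nn.Linear(32, 8)
    def __call__(self, x: Tensor) -> Tensor: return self.l2(self.l1(x).relu())

  net = convert_to_float8_training(TinyNet())
  assert isinstance(net.l1, FP8Linear) and isinstance(net.l2, FP8Linear)
  print(net(Tensor.randn(4, 16)).numpy())  # (4, 8) output computed through fp8 matmuls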