render cast is rewrite rules [run_process_replay] (#6772)
* render cast is rewrite rules [run_process_replay]
* move load/store to rewrite rules
* render_alu smaller
* render_gep
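The diff moves string rendering for cast/bitcast, load/store, ALU, and GEP ops out of standalone helpers and Renderer methods (render_cast) and into PatternMatcher rewrite rules: each UOp is turned into a string by the first (UPat, lambda) rule that matches it, and backends customize rendering by adding rules in front of the shared list. As a rough standalone sketch of that idea — the Op, Pat, Matcher, and ToyRenderer classes below are simplified stand-ins, not tinygrad's real UOp/UPat/PatternMatcher API:

from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

@dataclass(frozen=True)
class Op:
  name: str                      # e.g. "CAST", "BITCAST", "CONST"
  dtype: str = "float"
  src: Tuple["Op", ...] = ()
  arg: object = None

@dataclass(frozen=True)
class Pat:
  name: str                      # the op name this pattern matches
  def match(self, op: Op) -> bool: return op.name == self.name

class Matcher:
  def __init__(self, rules: List[Tuple[Pat, Callable]]): self.rules = rules
  def __add__(self, other: "Matcher") -> "Matcher":
    # earlier rules win, which is how a backend's matcher takes priority over the base one
    return Matcher(self.rules + other.rules)
  def rewrite(self, op: Op, ctx) -> Optional[str]:
    for pat, fn in self.rules:
      if pat.match(op): return fn(ctx, op)
    return None

class ToyRenderer:
  def render_dtype(self, dt: str) -> str: return dt
  # pretend every source op was already rendered to a string
  def __getitem__(self, op: Op) -> str: return str(op.arg)

base = Matcher([
  (Pat("CAST"),    lambda r, x: f"({r.render_dtype(x.dtype)})({r[x.src[0]]})"),
  (Pat("BITCAST"), lambda r, x: f"(*(({r.render_dtype(x.dtype)}*)&{r[x.src[0]]}))"),
  (Pat("CONST"),   lambda r, x: str(x.arg)),
])

print(base.rewrite(Op("CAST", dtype="int", src=(Op("CONST", arg=1.5),)), ToyRenderer()))  # (int)(1.5)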
@@ -7,57 +7,25 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
else:
val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
return val
def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp) -> str:
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
else:
val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
return val
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
def render_alu(r:CStyleLanguage, x:UOp):
if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
else: operands = [r[v] for v in x.src]
return r.code_for_op[x.arg](*operands, x.dtype)
def render_gep(r:CStyleLanguage, x:UOp):
from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \
or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")
base_pm = PatternMatcher([
base_rewrite = PatternMatcher([
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
(UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
# load/store image
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
# r method accesses
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"),
lambda r,x: f"{r.float4.replace('float4', r.render_dtype(x.dtype))}" + \
(f"{{{','.join([r[y] for y in x.src])}}}" if r.device == "CLANG" else f"({','.join([r[y] for y in x.src])})")),
(UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)),
(UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)),
(UPat(UOps.CAST, name="x"), lambda r,x: f"({r.render_dtype(x.dtype)})({r[x.src[0]]})"),
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"(*(({r.buffer_prefix}{r.render_dtype(x.dtype)}*)&{r[x.src[0]]}))"),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"),
(UPat(UOps.BARRIER), lambda r: r.barrier),
(UPat(UOps.NOOP, name="x"), lambda r,x: r[x.src[0]]),
@@ -72,11 +40,18 @@ base_pm = PatternMatcher([
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda r,x: f"{x.arg}u"),
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda r,x: "1" if x.arg else "0"),
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
# function calls
(UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
(UPat(UOps.ALU, name="x"), render_alu),
(UPat(UOps.GEP, name="x"), render_gep),
# load/store
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
# alu/gep
(UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
*([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
(UPat(UOps.GEP, name="x"), lambda r,x: r[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
extra_pm = PatternMatcher([
@@ -120,13 +95,9 @@ class CStyleLanguage(Renderer):
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
string_rewrite = base_rewrite
extra_matcher = extra_pm
# returns a str expression of the casted xs with the given type
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
return f"({self.render_dtype(var_dtype)})({x})"
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
@@ -174,7 +145,7 @@ class CStyleLanguage(Renderer):
UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, base_pm.rewrite(u, ctx=self))
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
@@ -242,8 +213,17 @@ class OpenCLRenderer(CStyleLanguage):
float4 = "(float4)"
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
def render_cast(self, x, var_dtype, bitcast=False) -> str:
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_{r.render_dtype(x.dtype)}({r[x.src[0]]})"),
# load/store image (OpenCL)
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)))),
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
@@ -253,12 +233,14 @@ class IntelRenderer(OpenCLRenderer):
device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
string_rewrite = PatternMatcher([
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda r,x: f"intel_convert_bfloat16_as_ushort({r[x[0]]})"),
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"),
]) + OpenCLRenderer.string_rewrite
def render_dtype(self, var_dtype:DType) -> str:
return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype)
def render_cast(self, x, var_dtype, bitcast=False, from_dtype=None) -> str:
return f"intel_convert_bfloat16_as_ushort({x[0]})" if (var_dtype, from_dtype) == (dtypes.bfloat16, dtypes.float) else \
(f"intel_convert_as_bfloat16_float({x[0]})" if (var_dtype, from_dtype) == (dtypes.float, dtypes.bfloat16) else \
super().render_cast(x, var_dtype, bitcast))
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = []
@@ -296,8 +278,9 @@ class MetalRenderer(CStyleLanguage):
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
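A note on the pattern the backends above follow: instead of overriding a render_cast method, OpenCLRenderer, IntelRenderer, and MetalRenderer each build a small PatternMatcher and prepend it to the shared rules (`string_rewrite = PatternMatcher([...]) + base_rewrite`, or `+ OpenCLRenderer.string_rewrite` for Intel), so their rules shadow the base ones while everything else falls through. A minimal sketch of that first-match-wins composition, using plain (op-name, function) tuples as hypothetical stand-ins for the real UPat rules:

from typing import Callable, List, Tuple

# (op name, fn(dtype, rendered_src) -> str); a stand-in for tinygrad's (UPat, lambda) rules
Rule = Tuple[str, Callable[[str, str], str]]

base_rules: List[Rule] = [
  ("CAST",    lambda dtype, src: f"({dtype})({src})"),
  ("BITCAST", lambda dtype, src: f"(*(({dtype}*)&{src}))"),   # generic C-style bitcast
]
opencl_rules: List[Rule] = [
  ("BITCAST", lambda dtype, src: f"as_{dtype}({src})"),       # OpenCL's as_<type>() builtin
]
metal_rules: List[Rule] = [
  ("BITCAST", lambda dtype, src: f"as_type<{dtype}>({src})"), # Metal's as_type<T>()
]

def render(rules: List[Rule], op: str, dtype: str, src: str) -> str:
  # the first matching rule wins, so backend rules placed in front shadow the base rule
  for name, fn in rules:
    if name == op: return fn(dtype, src)
  raise RuntimeError(f"no rule for {op}")

print(render(opencl_rules + base_rules, "BITCAST", "float", "x"))  # as_float(x)
print(render(metal_rules + base_rules,  "BITCAST", "float", "x"))  # as_type<float>(x)
print(render(metal_rules + base_rules,  "CAST",    "int",   "x"))  # (int)(x)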