diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index e86a11211f..bf13b407ba 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -7,57 +7,25 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
 
-def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
-  sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
-  if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
-    val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
-  else:
-    val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
-
-  # NOTE: this relies on the load not happening if it's in the unselected branch
-  if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
-  return val
-
-def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp) -> str:
+def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
   sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
-  if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
-    prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
-    val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
-  else:
-    val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
-  return val
+  if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
+    return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
+  return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
 
-def render_alu(r:CStyleLanguage, x:UOp):
-  if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
-  else: operands = [r[v] for v in x.src]
-  return r.code_for_op[x.arg](*operands, x.dtype)
-
-def render_gep(r:CStyleLanguage, x:UOp):
-  from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-  return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \
-    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \
-      or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")
-
-base_pm = PatternMatcher([
+base_rewrite = PatternMatcher([
   (UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
   (UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
   (UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
   (UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
   (UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
-  # load/store image
-  (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
-    lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
-  (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
-    lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
-  (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
-    lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
   # r method accesses
   (UPat(UOps.RANGE, name="x"),
    lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
   (UPat(UOps.VECTORIZE, name="x"), lambda r,x: f"{r.float4.replace('float4', r.render_dtype(x.dtype))}" + \
     (f"{{{','.join([r[y] for y in x.src])}}}" if r.device == "CLANG" else f"({','.join([r[y] for y in x.src])})")),
-  (UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)),
-  (UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)),
+  (UPat(UOps.CAST, name="x"), lambda r,x: f"({r.render_dtype(x.dtype)})({r[x.src[0]]})"),
+  (UPat(UOps.BITCAST, name="x"), lambda r,x: f"(*(({r.buffer_prefix}{r.render_dtype(x.dtype)}*)&{r[x.src[0]]}))"),
   (UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"),
   (UPat(UOps.BARRIER), lambda r: r.barrier),
   (UPat(UOps.NOOP, name="x"), lambda r,x: r[x.src[0]]),
@@ -72,11 +40,18 @@ base_pm = PatternMatcher([
   (UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda r,x: f"{x.arg}u"),
   (UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda r,x: "1" if x.arg else "0"),
   (UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
-  # function calls
-  (UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
-  (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
-  (UPat(UOps.ALU, name="x"), render_alu),
-  (UPat(UOps.GEP, name="x"), render_gep),
+  # load/store
+  (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
+   lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
+  (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
+   lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)),
+  (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
+   lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
+  # alu/gep
+  (UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
+    *([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
+  (UPat(UOps.GEP, name="x"), lambda r,x: r[x.src[0]] + \
+    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
 ])
 
 extra_pm = PatternMatcher([
@@ -120,13 +95,9 @@ class CStyleLanguage(Renderer):
     BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
     TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
+  string_rewrite = base_rewrite
   extra_matcher = extra_pm
 
-  # returns a str expression of the casted xs with the given type
-  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
-    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
-    return f"({self.render_dtype(var_dtype)})({x})"
-
   def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
   def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else ""  # noqa: E501
@@ -174,7 +145,7 @@ class CStyleLanguage(Renderer):
                   UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
         r[u] = f"{prefix}{c[prefix]}"
 
-      l = cast(str, base_pm.rewrite(u, ctx=self))
+      l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
 
       if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
@@ -242,8 +213,17 @@ class OpenCLRenderer(CStyleLanguage):
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
-  def render_cast(self, x, var_dtype, bitcast=False) -> str:
-    return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
+
+  string_rewrite = PatternMatcher([
+    (UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_{r.render_dtype(x.dtype)}({r[x.src[0]]})"),
+    # load/store image (OpenCL)
+    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
+      lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
+    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)))),
+      lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
+    (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
+      lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
+  ]) + base_rewrite
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
@@ -253,12 +233,14 @@ class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
   tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
     st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+
+  string_rewrite = PatternMatcher([
+    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda r,x: f"intel_convert_bfloat16_as_ushort({r[x[0]]})"),
+    (UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"),
+  ]) + OpenCLRenderer.string_rewrite
+
   def render_dtype(self, var_dtype:DType) -> str: return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype)
-  def render_cast(self, x, var_dtype, bitcast=False, from_dtype=None) -> str:
-    return f"intel_convert_bfloat16_as_ushort({x[0]})" if (var_dtype, from_dtype) == (dtypes.bfloat16, dtypes.float) else \
-      (f"intel_convert_as_bfloat16_float({x[0]})" if (var_dtype, from_dtype) == (dtypes.float, dtypes.bfloat16) else \
-      super().render_cast(x, var_dtype, bitcast))
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = []
@@ -296,8 +278,9 @@ class MetalRenderer(CStyleLanguage):
     UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
     UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
 
-  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
-    return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
+  string_rewrite = PatternMatcher([
+    (UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"),
+  ]) + base_rewrite
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
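
The idiom this diff moves to, per-renderer string_rewrite tables built by prepending device-specific rules to a shared base table with "+", can be illustrated with a small self-contained sketch. ToyMatcher, base, and opencl below are invented stand-ins for illustration only, not tinygrad's actual PatternMatcher/UPat API; the sketch only assumes the composition behavior visible in the diff (a device table concatenated in front of base_rewrite so its rules take priority).

# Toy sketch of the table-composition idiom above (NOT tinygrad's PatternMatcher/UPat;
# names like ToyMatcher, base, and opencl are invented for illustration).
from typing import Callable, List, Tuple

class ToyMatcher:
  def __init__(self, rules:List[Tuple[str, Callable[[str], str]]]): self.rules = rules
  # "+" concatenates rule tables; the left operand's rules are tried first
  def __add__(self, other:"ToyMatcher") -> "ToyMatcher": return ToyMatcher(self.rules + other.rules)
  # first rule whose key matches wins, so prepended device rules shadow base rules
  def rewrite(self, op:str, arg:str) -> str:
    for key, fn in self.rules:
      if key == op: return fn(arg)
    raise KeyError(f"no rule for {op}")

base = ToyMatcher([
  ("CAST",    lambda x: f"(float)({x})"),       # generic C-style cast
  ("BITCAST", lambda x: f"(*((float*)&{x}))"),  # generic C-style bitcast
])

# device table overrides BITCAST only and falls back to base for everything else
opencl = ToyMatcher([("BITCAST", lambda x: f"as_float({x})")]) + base

print(opencl.rewrite("BITCAST", "a"))  # -> as_float(a)
print(opencl.rewrite("CAST", "a"))     # -> (float)(a)

Because the device table is concatenated in front, its BITCAST rule shadows the base rule while every other op falls through, which mirrors the override relationship the render_cast subclassing used to express.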