From dd575da7ee098ada8ce7a4dfd4505bbd2005f333 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 25 Sep 2024 12:40:46 +0800 Subject: [PATCH] real minimum cstyle change (#6709) * real minimum cstyle change * make it match * bring back DEFINE_GLOBAL store marking writable * bump line count to 9800 * closer * precompute don't render * cast/bitcast too * smem_align * vectorize * more pr match * remove that test * less PR diff --- .github/workflows/test.yml | 3 - tinygrad/renderer/cstyle.py | 187 +++++++++++++++++++----------------- 2 files changed, 99 insertions(+), 91 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1f0c8fbe2e..b2677cface 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -534,9 +534,6 @@ jobs: run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20 - name: Run process replay tests run: | - if [ "${{ matrix.backend }}" == "amd" ] && [ "${GITHUB_REF_NAME}" != "master" ]; then - MAX_DIFF_PCT=1 RUN_PROCESS_REPLAY=0 test/external/process_replay/test_process_replay.sh - fi export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH") export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }}) cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 3639547eca..a42f3a5605 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -1,11 +1,57 @@ -from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable +from __future__ import annotations +from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast import os, math from collections import defaultdict, Counter -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType from tinygrad.renderer import Renderer, TensorCore +def render_load(r:CStyleLanguage, x:UOp): + val = r.render_load(x.dtype, r[x.src[0]], x.src[0].dtype, strip_parens(r[x.src[1]])) + # NOTE: this relies on the load not happening if it's in the unselected branch + if len(x.src) > 3 and x.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[x.src[3]], val, r[x.src[2]], x.dtype) + return val + +def render_store(r:CStyleLanguage, x:UOp): + assert isinstance(x.src[0].dtype, (ImageDType, PtrDType)) + rendered_store = r.render_store(r[x.src[0]], x.src[0].dtype, r[x.src[2]], x.src[2].dtype, strip_parens(r[x.src[1]])) + return f"if ({r[x.src[3]]}) {{ {rendered_store} }}" if len(x.src) > 3 and x.src[3].op is not UOps.IF else rendered_store + +def render_alu(r:CStyleLanguage, x:UOp): + if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src] + elif x.arg is BinaryOps.MAX: operands = [r.render_cast(r[v], v.dtype) if v.op is UOps.CONST else r[v] for v in x.src] + else: operands = [r[v] for v in x.src] + return r.code_for_op[x.arg](*operands, x.dtype) + +def render_gep(r:CStyleLanguage, x:UOp): + from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC} + return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \ + (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \ + or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}") + +base_pm = PatternMatcher([ + (UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]), + (UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"), + (UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"), + (UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"), + (UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"), + # r method accesses + (UPat(UOps.CONST, name="x"), lambda r,x: r.render_const(x.arg, x.dtype) if x.arg >= 0 else f"({r.render_const(x.arg, x.dtype)})"), + (UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"), + (UPat(UOps.VECTORIZE, name="x"), lambda r,x: r.render_vectorize([r[y] for y in x.src], x.dtype)), + (UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)), + (UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)), + (UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"), + (UPat(UOps.BARRIER), lambda r: r.barrier), + (UPat(UOps.SPECIAL, name="x"), lambda r,x: f"{r.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), + # function calls + (UPat(UOps.LOAD, name="x"), render_load), + (UPat(UOps.STORE, name="x"), render_store), + (UPat(UOps.ALU, name="x"), render_alu), + (UPat(UOps.GEP, name="x"), render_gep), +]) + class CStyleLanguage(Renderer): kernel_prefix: str = "" buffer_prefix: str = "" @@ -89,101 +135,66 @@ class CStyleLanguage(Renderer): return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};" return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" - def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];" def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "") + def __getitem__(self, key): return self.r[key] # hacky helper def render(self, name:str, uops:List[UOp]) -> str: - kernel = [] - bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {} - depth = 1 - def kk(s): kernel.append(" "*depth+s) - - c: DefaultDict[str, int] = defaultdict(int) r: Dict[UOp, str] = {} + self.r = r - def ssa(prefix:str, u:Optional[UOp]=None): - nonlocal c, r - ret = f"{prefix}{c[prefix]}" - if u is not None: r[u] = ret - c[prefix] += 1 - return ret - + # get should render child_count = Counter(v for ru in uops for v in ru.src) - - seen_vars = set() + dont_render: Dict[UOp, bool] = {} for u in uops: - uop,dtype,src,args = u.op,u.dtype,u.src,u.arg - # these four uops don't have output dtypes - if uop is UOps.IF: - kk(f"if ({r[src[0]]}) {{") - depth += 1 - elif uop is UOps.BARRIER: kk(self.barrier) - elif uop in {UOps.ENDRANGE, UOps.ENDIF}: - depth -= 1 - kk("}") - elif uop is UOps.STORE: - # mark DEFINE_GLOBAL buf as writable - assert isinstance(src[0].dtype, (ImageDType, PtrDType)) - if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True)) - rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]])) - kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store) + # bitcast src must be rendered (always earlier, so this is safe) + if u.op is UOps.BITCAST: dont_render[u.src[0]] = False + dont_render[u] = u.op in {UOps.CONST, UOps.GEP} or \ + (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST} and child_count[u] == 1 \ + and u.arg is not BinaryOps.MAX and not getenv("EXPAND_SSA")) + + bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {} + kernel = [] + depth = 1 + c: DefaultDict[str, int] = defaultdict(int) + c['temp'] += 1 # hack for process replay + for u in uops: + if u.op is UOps.DEFINE_GLOBAL: + r[u] = f"data{u.arg}" + bufs[u] = (r[u], (u.dtype, False)) + continue + if u.op is UOps.DEFINE_VAR: + r[u] = u.arg[0] + bufs[u] = (r[u], (u.dtype, False)) + continue + + # mark buffers that we store to writable + if u.op is UOps.STORE and u.src[0].op is UOps.DEFINE_GLOBAL: bufs[u.src[0]] = (bufs[u.src[0]][0], (bufs[u.src[0]][1][0], True)) + + # naming + prefix = None + if u.op is UOps.SPECIAL: + r[u] = u.arg[0] else: - if uop is UOps.RANGE: - kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{") - depth += 1 - elif uop is UOps.ALU: - # remove parens if ALU types are the same. TODO: can do more here - if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src] - elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], v.dtype) if v.op is UOps.CONST else r[v] for v in src] - else: operands = [r[v] for v in src] - val = self.code_for_op[args](*operands, dtype) - assert child_count[u] != 0, f"childless ALU op found {u}" - # TODO: fix index rendering issue. fix clang nested max macro issue - if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val - else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};") - elif uop is UOps.SPECIAL: - kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */") - r[u] = args[0] - elif uop is UOps.DEFINE_VAR: - assert args[0] not in seen_vars, f"duplicate variable {args[0]}" - seen_vars.add(args[0]) - bufs[u] = (args[0], (dtype,False)) - r[u] = args[0] - elif uop is UOps.LOAD: - val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]])) - # NOTE: this relies on the load not happening if it's in the unselected branch - if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype) - kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};") - elif uop is UOps.ASSIGN: - kk(f"{r[src[0]]} = {r[src[1]]};") - r[u] = r[src[0]] - elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: - assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation" - if uop is UOps.BITCAST: - precast = ssa('precast') - kk(f"{self.render_dtype(src[0].dtype)} {precast} = {r[src[0]]};") - val = self.render_cast(precast, dtype, bitcast=True) - elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False) - else: val = self.render_vectorize([r[x] for x in src], dtype) - if child_count[u] <= 1: r[u] = val - else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};") - elif uop is UOps.DEFINE_LOCAL: - kk(self.render_local(args[0], dtype, args[1])) - r[u] = args[0] - elif uop is UOps.DEFINE_GLOBAL: - bufs[u] = (nm:=f"data{args}", (dtype, False)) - r[u] = nm - elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});") - elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};") - elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})" - elif uop is UOps.GEP: - assert len(args) == 1 - from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC} - r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \ - (f"[{args[0]}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) \ - or self.device == 'CLANG' else f".{'xyzwabcd'[args[0]]}") - else: raise RuntimeError(f"failed to render {u}") + prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const", + UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast", + UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk") + r[u] = f"{prefix}{c[prefix]}" + + l = cast(str, base_pm.rewrite(u, ctx=self)) + assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" + + if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1 + if dont_render[u]: r[u] = l + else: + if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void: + if u.op is UOps.ASSIGN: r[u] = r[u.src[0]] + else: + l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not UOps.SPECIAL else "") + kernel.append(" "*depth + l) + if prefix: c[prefix] += 1 # if it was used, increment + if u.op in {UOps.IF, UOps.RANGE}: depth += 1 + del self.r # NOTE: this relies on bufs dict preserving order return self.render_kernel(name, kernel, list(bufs.values()), uops)