From 572a3c15c6fc78911649a2ee174a3624556728ba Mon Sep 17 00:00:00 2001
From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
Date: Thu, 4 Sep 2025 09:31:44 +0200
Subject: [PATCH] Move Ops.SPECIAL arg to src (#11918)

* initial moving bound to src
* arg to src
* remove import
* fixup linearizer
* arg to src
* fix test_uop_graph
* fix more tests
* fix python renderer
* get const value from const uop
* ssimplify uop estimates
* fix webgpu locals
* fix old test
* gate Ops.SPECIAL in linearizer
* use ssimplify() for local/global_size
* remove toposort gate_parents_instead_of_self
* fix rendering in comment
* cleanup
* rename and add comments
* add BottomUpGate with test
---
 extra/backends/rdna.py                   |  8 +++---
 test/test_linearizer.py                  | 12 ++++-----
 test/test_renderer_failures.py           |  8 +++---
 test/test_uop_graph.py                   | 22 ++++++++--------
 test/test_uops.py                        |  8 +++---
 test/unit/test_block_reorder.py          |  4 +--
 test/unit/test_graph_rewrite.py          |  2 +-
 test/unit/test_rewrite_bottom_up_gate.py | 28 ++++++++++++++++++++
 test/unit/test_simplify_valid_idx.py     |  2 +-
 test/unit/test_uop_vmin_vmax.py          |  2 +-
 tinygrad/codegen/gpudims.py              |  4 +--
 tinygrad/codegen/late/linearize.py       | 11 +++++---
 tinygrad/renderer/__init__.py            |  8 +++---
 tinygrad/renderer/cstyle.py              |  7 ++---
 tinygrad/renderer/llvmir.py              |  7 ++---
 tinygrad/renderer/ptx.py                 | 11 ++++----
 tinygrad/renderer/wgsl.py                |  2 +-
 tinygrad/runtime/ops_python.py           |  7 ++---
 tinygrad/uop/ops.py                      | 33 +++++++++++++-----------
 tinygrad/uop/spec.py                     |  4 +--
 20 files changed, 114 insertions(+), 76 deletions(-)
 create mode 100644 test/unit/test_rewrite_bottom_up_gate.py

diff --git a/extra/backends/rdna.py b/extra/backends/rdna.py
index 32e9b8d9ea..a5b775b734 100644
--- a/extra/backends/rdna.py
+++ b/extra/backends/rdna.py
@@ -29,10 +29,10 @@ def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
   r: Dict[UOp, str] = {}
   for u in uops:
     if u.uop == UOps.SPECIAL:
-      if u.arg[1].startswith("lidx"):
-        r[u] = f'v{u.arg[0]}'
-      elif u.arg[1].startswith("gidx"):
-        r[u] = f's{2+u.arg[0]}'
+      if u.arg.startswith("lidx"):
+        r[u] = f'v{u.src[0].arg}'
+      elif u.arg.startswith("gidx"):
+        r[u] = f's{2+u.src[0].arg}'
       else: raise NotImplementedError
     elif u.uop == UOps.CONST:
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 722fba84a8..c2f8b67786 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -473,8 +473,8 @@ class TestLinearizer(unittest.TestCase):
   def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
     idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
     loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
-    loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
-    sizes = [x.arg[1] for x in loop_idxs]
+    loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
+    sizes = [x.src[0].arg for x in loop_idxs]
     assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
     if assert_same_length: assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
@@ -547,10 +547,10 @@ class TestLinearizer(unittest.TestCase):
     k = helper_linearizer_opt(t+1)[0]
     uops = get_program(k.ast, k.opts, k.applied_opts).uops
     idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL])
-    idxs = sorted(idxs, key=lambda uop: uop.arg[0])
-    assert idxs[0].arg == ('gidx0', 6), idxs[0].arg
-    assert idxs[1].arg == ('gidx1', 5), idxs[1].arg
-    assert idxs[2].arg == ('gidx2', 4), idxs[2].arg
+    idxs = sorted(idxs, key=lambda uop: uop.arg)
+    assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0]
+    assert (idxs[1].arg, idxs[1].src[0].arg) == ('gidx1', 5), idxs[1].arg
+    assert (idxs[2].arg, idxs[2].src[0].arg) == ('gidx2', 4), idxs[2].arg

   def test_sum_collapse(self):
     t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index bcfad46430..8092914be7 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase):
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
   def test_gated_store_with_alu(self):
     a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
+    gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
     uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
@@ -56,8 +56,8 @@ class TestRendererFailures(unittest.TestCase):
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
   def test_gated_store_with_alu_2d(self):
     a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
-    gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0)
+    gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
+    gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
     uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
@@ -101,7 +101,7 @@ class TestPTXFailures(unittest.TestCase):
   @unittest.skip("INDEX can only have a gate ALU parent, not an IF")
   def test_gated_store_with_if(self):
     a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
+    gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
     val = UOp.const(dtypes.int, 1)
     if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index b56fe1a922..7d705e0080 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -458,8 +458,8 @@ class TestUOpGraph(unittest.TestCase):
     sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

     # Define indices, valids and barrier
-    gidx = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 416))
-    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 10))
+    gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
     gate = (gidx<400) & (lidx<8)
@@ -512,7 +512,7 @@ class TestUOpGraph(unittest.TestCase):
   def test_in_out_bounds_access_with_mask(self):
     with Context(IGNORE_OOB=0):
       glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
-      gidx0 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx0", 42))
+      gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 42),), "gidx0")
       ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0, (5<gidx0)&(gidx0<16)),))
       ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ld0, (ld0>=0)&(ld0<32)),))
       to_uops_list([ld1])
@@ -559,7 +559,7 @@ class TestUOpGraph(unittest.TestCase):
   def test_fold_gated_load_local(self):
     glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
     smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
-    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
     st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
     barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
     ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+1, UOp.const(dtypes.bool, False)), barrier))
@@ -756,8 +756,8 @@ class TestIFUOps(unittest.TestCase):
   def test_create_ifs(self):
     gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
     sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, addrspace=AddrSpace.LOCAL), (), "smem")
-    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
-    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
+    valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "gidx0")<5
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "lidx0")
     gate = valid&(lidx.ne(2))
     idx = UOp.const(dtypes.int, 0)
     st = UOp(Ops.STORE, dtypes.void, (sbuf.index(idx), UOp.const(dtypes.float, 42)))
@@ -775,8 +775,8 @@ class TestIFUOps(unittest.TestCase):
   def test_expand_ifs_one_gate(self):
     gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
     sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem")
-    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1
-    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
+    valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "gidx0")<1
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")
     gate = valid&(lidx.ne(2))
     st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
     barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
@@ -794,8 +794,8 @@ class TestIFUOps(unittest.TestCase):
   @unittest.expectedFailure
   def test_expand_ifs_dumb(self):
     buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
-    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
+    valid = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "gidx0")<5
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), "lidx0")
     gate = valid&(lidx.ne(2))
     stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
     sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
diff --git a/test/test_uops.py b/test/test_uops.py
index defdcf00ef..4a0433ee5e 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -270,7 +270,7 @@ class TestConstantFolding(unittest.TestCase):
 class TestGatedStoreRewrite(unittest.TestCase):
   def test_tiny_gate_store(self):
     gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
+    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'gidx0')
     gate = gidx0 len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   # try to split up dims: (a,) -> (b, c)
   if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (sint_to_uop(s),), (f"{prefix}{i}")) for i,s in enumerate(limited)]
   if len(limited) < len(dims):
     ret = []
     if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
diff --git a/tinygrad/codegen/late/linearize.py b/tinygrad/codegen/late/linearize.py
index 54ff2cbfc7..a8ef31f3e4 100644
--- a/tinygrad/codegen/late/linearize.py
+++ b/tinygrad/codegen/late/linearize.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import heapq
 from collections import defaultdict
 from dataclasses import dataclass, replace
-from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
+from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat, GroupOp, BottomUpGate
 from tinygrad.helpers import dedup, all_same, flatten, BLOCK_REORDER

 # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
@@ -76,12 +76,13 @@ class BlockContext:
   def from_sink(sink:UOp) -> BlockContext:
     # get children and all block contexts
     ctx = BlockContext({}, {}, {})
-    for u in sink.toposort():
+    for u in sink.toposort(gate=lambda u:u.op is not Ops.SPECIAL):
       this_block_ctx: list[UOp] = []
       ctx.child_count[u] = 0

       # get children and accumulate the last_ctx
       for s in u.src:
+        if s.op is Ops.SPECIAL: continue
         # NOTE: if a parent appears multiple times in the src, it counts multiple times as a child
         ctx.child_count[s] += 1
         this_block_ctx += ctx.last_ctx(s)
@@ -142,7 +143,7 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp):

   # add unmergables to sources
   srcs = []
-  for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs[u], current_ctx, cnt=cnt)]*cnt
+  for u,cnt in unmergable.items(): srcs += [add_blockends(u, ctx.block_ctxs.get(u,()), current_ctx, cnt=cnt)]*cnt

   # add blockseeds, with blockends as needed
   for (new_ctx, new_child_ctx), v in blockseeds.items():
@@ -154,8 +155,12 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp):
   bb = BasicBlock(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
   return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)

+# we prevent the source of the SPECIAL from being linearized since it's not part of the kernel
+def raise_bottom_up_gate(): raise BottomUpGate()
+
 block_create = PatternMatcher([
   (UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
+  (UPat(Ops.SPECIAL), raise_bottom_up_gate)
 ])

 # ***** blockend merging ****
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index d721440f81..22660bc80d 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -46,7 +46,7 @@ class Estimates:
         # SPECIAL are already counted in mults
         mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
       elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
-      elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+      elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
       elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): lds += u.dtype.itemsize * mults
       elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
@@ -82,9 +82,9 @@ class ProgramSpec:
       if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL])
       if u.op is Ops.SPECIAL:
         # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
-        if u.arg[0][0] == 'i': self.local_size = None
-        special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
-        if special_size is not None: special_size[int(u.arg[0][-1])] = u.arg[1]
+        special_size = self.local_size if u.arg[0] == 'l' else self.global_size
+        assert special_size is not None, f"special_size is None but found SPECIAL in uops {u}"
+        special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify()) # TODO: the type here should be sint
     self.vars = sorted(self.vars, key=lambda v: v.arg)
     self.outs = sorted(dedup(self.outs))
     self.ins = sorted(dedup(self.ins))
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 506ee465cd..d04fa73567 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -26,7 +26,7 @@
   (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
   (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
   (UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
-  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {sint_to_uop(x.arg[1]).render()} */"),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
   # const
   (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
   (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"),
@@ -111,7 +111,8 @@ class CStyleLanguage(Renderer):
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else ""  # noqa: E501
     buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
                  self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
-    launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+    local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
+    launch_bounds = sint_to_uop(prod(local_dims)).vmax
     prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
                   [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
                   [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
@@ -156,7 +157,7 @@ class CStyleLanguage(Renderer):
       # naming
       prefix = None
-      if u.op is Ops.SPECIAL: r[u] = u.arg[0]
+      if u.op is Ops.SPECIAL: r[u] = u.arg
       elif u.op is Ops.RANGE: r[u] = "ridx"+'_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
       else:
         prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 47ed182bf5..44c7bd0f7e 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -3,7 +3,7 @@ import math, struct, sys
 from tinygrad.codegen.opt import tc
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import AMDRenderer
-from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, sint_to_uop
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
 from tinygrad.helpers import prod, AMX
@@ -207,7 +207,7 @@ class AMDLLVMRenderer(LLVMRenderer):
   abi = "amdgpu_kernel"
   code_for_op = {**LLVMRenderer.code_for_op, **{op: lambda: None for op in llvm_intrinsics}}
   string_rewrite = PatternMatcher([
-    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f"  {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
+    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f"  {ctx[x]} = " + f"{ code_for_workitem[x.arg[0]](x.arg[-1])}; "),
    (UPat(tuple(llvm_intrinsics), name="x"), lambda ctx, x:
     f"  {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
    (UPat(Ops.BARRIER), lambda ctx: barrier),
@@ -220,7 +220,8 @@ class AMDLLVMRenderer(LLVMRenderer):
   ])
   def _render_footer(self, uops: list[UOp]) -> str:
     # TODO: this is copied from cstyle
-    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+    local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
+    requiredMaxThreadsPerBlock = sint_to_uop(prod(local_dims)).vmax
     attributes = ["alwaysinline", "nounwind", '"no-builtins"', f'"amdgpu-flat-work-group-size"="1,{requiredMaxThreadsPerBlock}"',
                   '"no-trapping-math"="true"']
     return 'attributes #0 = { ' + ' '.join(attributes) + ' }'
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index 3c78bb220d..cc764d1b0b 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -2,7 +2,7 @@ from typing import cast, Callable
 import struct
 from collections import defaultdict
 from tinygrad.codegen.opt import tc
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, sint_to_uop
 from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
@@ -91,7 +91,7 @@ string_rewrite = PatternMatcher([
   (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \
     f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
     f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
-  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg}, %{'ctaid' if x.arg[0] == 'g' else 'tid'}.{chr(120+int(x.arg[-1]))};"),
   (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
   (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
    lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
@@ -155,7 +155,8 @@ class PTXRenderer(Renderer):
   def render_kernel(self, kernel, function_name, bufs, regs, uops) -> str:
     def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
     kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]))
-    launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+    local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
+    launch_bounds = sint_to_uop(prod(local_dims)).vmax
     params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
{name}" for name,dtype in bufs]) return f"{self.kernel_prefix.format(launch_bounds=launch_bounds)} {function_name} (\n\t{params}\n)\n.maxntid {launch_bounds}\n{{\n{kernel}\n}}" @@ -202,7 +203,7 @@ class PTXRenderer(Renderer): typ = "pred" if u.src[1].dtype == dtypes.bool else ("b"+self.types[u.src[1].dtype][1:]) kernel.append(f"mov.{typ} {self.r[u.src[0]]}, {self.r[u.src[1]]};") continue - if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0] + if u.op is Ops.SPECIAL: r[u] = "%" + u.arg elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype)) elif u.op is Ops.LOAD: assert u.src[0].dtype == dtypes.int64, "load isn't int64" @@ -223,5 +224,5 @@ class PTXRenderer(Renderer): raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}") kernel.extend([l] if isinstance(l, str) else l) - if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel + if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg};"] + kernel return self.render_kernel(kernel, name, bufs, c.items(), uops) diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 4063a23c76..49a952ba36 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -84,7 +84,7 @@ class WGSLRenderer(CStyleLanguage): def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x def buf_map(self, dt:DType) -> str: return "atomic" if is_packed(dt) else self.type_map[dt.base] def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str: - local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])] + local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)] if not local_size: local_size = [1] bind_it = iter(range(len(bufs))) external_local_bufs = [line.lstrip() for line in kernel if "var" in line] diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index f5550abd7a..72228ab070 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -84,8 +84,8 @@ class PythonProgram: elif uop is Ops.DEFINE_VAR: ul[i] = [pvals.pop(0)] * warp_size elif uop is Ops.SPECIAL: - if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size - elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp] + if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size + elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp] elif uop is Ops.CONST: ul[i] = [arg] * warp_size elif uop is Ops.INDEX: ret:list = [] @@ -220,7 +220,8 @@ class PythonRenderer(Renderer): if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx def render(self, uops:list[UOp]) -> str: - lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops] + # the value of SPECIAL comes from local/global_size, not form its source + lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops] return base64.b64encode(pickle.dumps(lops)).decode() class PythonCompiler(Compiler): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 11ca18be58..421c46a3b7 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -566,11 +566,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax) # NOTE: returned UOp is assumed to 
be CONST if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] - if self.op is Ops.RANGE: return 0, (self.src[0]-1).vmax + if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) - # TODO: Ops.SPECIAL is Ops.DEFINE_VAR - if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax-1 if self.op is Ops.CONST: return self.arg, self.arg if self.op is Ops.VCONST: return (min(self.arg), max(self.arg)) if self.op is Ops.GEP: return self.src[0]._min_max @@ -942,6 +940,7 @@ if TRACK_MATCH_STATS or PROFILE: # *** simple graph rewrite engine *** class RewriteNotReady(Exception): pass +class BottomUpGate(Exception): pass class RewriteContext: def __init__(self, pm, bpm, ctx=None): self.pm: PatternMatcher|None = pm @@ -969,17 +968,20 @@ class RewriteContext: if n in self.replace: continue # skip any nodes we have seen try: if stage == 0: - # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack - if self.bpm is not None: - # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match - test_n: UOp|None = n - seen = set() - while test_n is not None: - if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite") - seen.add(test_n) - new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) - stack.append((n, 1, new_n)) - for x in reversed(new_n.src): stack.append((x, 0, x)) + try: + # if bottom up, we rewrite this node early. in both cases, we add its parents to the stack + if self.bpm is not None: + # apply rewrite rules until a fixed point is reached. 
may return `uop` itself if PatternMatcher doesn't match + test_n: UOp|None = n + seen = set() + while test_n is not None: + if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite") + seen.add(test_n) + new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) + stack.append((n, 1, new_n)) + for x in reversed(new_n.src): stack.append((x, 0, x)) + # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs + except BottomUpGate: self.replace[n] = new_n elif stage == 1: try: new_src = tuple([self.replace[x] for x in new_n.src]) except KeyError: raise RewriteNotReady @@ -1028,7 +1030,8 @@ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"} renderer = PatternMatcher([ - (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), + (UPat((Ops.DEFINE_VAR,), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), + (UPat((Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg)), (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}" if x.arg[0] >= 0 else f"ridxm{-x.arg[0]}")), (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))), (UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")), diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 4d658f0ef1..1a68dc6cbc 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -20,8 +20,6 @@ try: # each uop gets rewritten to NOOP(arg=(solver, z3_object)), the arg has the solver first due to UOpMetaClass caching. z3 objects from different # contexts can have the same hash but error on comparison z3_renderer = PatternMatcher([ - # Ops.SPECIAL can have symbolic arg but it wont be in the toposort beacuse its not a src, we need to add it manually - (UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))), (UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg, 0, x.src[0].arg[1]-1, ctx[0])))), (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])))), (UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(ctx[0],create_bounded(f"ridx{x.arg}", 0, x.src[0].arg[1]-1, ctx[0])))), @@ -157,7 +155,7 @@ spec = PatternMatcher([ (UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), (UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, tuple)), - (UPat(Ops.SPECIAL, src=()), lambda: True), + (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)), (UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)), (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"),
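
Example of the new Ops.SPECIAL form (a minimal sketch against the post-patch UOp API used
throughout this diff; the name "lidx0" and the size 16 are illustrative, not from the PR):

    from tinygrad.uop.ops import UOp, Ops
    from tinygrad.dtype import dtypes

    # before this PR the bound lived in the arg tuple, invisible to toposort:
    #   lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
    # after it, the bound is src[0] (normally a CONST) and arg is just the name:
    lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 16),), "lidx0")

    assert lidx.arg == "lidx0" and lidx.src[0].arg == 16
    # _min_max now derives the range from src[0], the same way Ops.RANGE does
    assert (lidx.vmin, lidx.vmax) == (0, 15)

Because the bound is now a real src edge, block_create raises BottomUpGate on Ops.SPECIAL so
the bottom-up rewrite stops there and the CONST feeding it is never placed in a block: the
size expression is not part of the kernel body.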