diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7af6294c83..f63a9e9a71 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase): def _test_no_nested_ranges(self, lins, skip=None): for l in lins: range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG]) - ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)] + ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)] for i,u in enumerate(ranges): if skip and i in skip: continue assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" @@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase): # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE uops = get_program(ast, opts=opt).uops begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1] - end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0] + end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0] for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype) for u in uops: if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG: diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index d5d75462f3..ca93f3a0cf 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -674,7 +674,7 @@ class TestUOpGraph(unittest.TestCase): store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf)) uops = to_uops_list([store]) ranges = [x for x in uops if x.op is Ops.RANGE] - endranges = [x for x in uops if x.op is Ops.ENDRANGE] + endranges = [x for x in uops if x.op is Ops.END] # ranges are closed in the right order self.assertEqual(endranges[-1].src[0], ranges[0]) diff --git a/tinygrad/codegen/late/linearize.py b/tinygrad/codegen/late/linearize.py index d860125adf..af6727819e 100644 --- a/tinygrad/codegen/late/linearize.py +++ b/tinygrad/codegen/late/linearize.py @@ -105,7 +105,7 @@ def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp while len(ends_to_add): r:UOp = ends_to_add.pop(-1) new_ctx = tuple([z for z in new_ctx if z is not r]) - end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)) + end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.END, src=(r,)) base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt)) return base_block @@ -215,7 +215,7 @@ def remove_blockend(x:UOp): # NOTE: DEFINE_ACC doesn't have to be handled in any special way late_ops = list(x.arg.lst) # NOTE: we have to add a barrier at the start if barrier is used in the range - if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE: + if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.END: late_ops = [UOp(Ops.BARRIER)] + late_ops # peephole opt, remove any BARRIERs next to each other for i in range(len(late_ops)-1): diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 849ec9d48e..a1d8f89d5f 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -47,7 +47,7 @@ class Estimates: mults *= cast(sint, u.src[0].ssimplify()) # SPECIAL are already counted in mults mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults - elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1) + elif u.op is Ops.END: mults = mult_stack.pop(-1) elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): lds += u.dtype.itemsize * mults diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index e6d01bfc97..5afcd0711a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -11,7 +11,7 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu base_rewrite = PatternMatcher([ (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), - (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"), + (UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"), (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), # r method accesses (UPat(Ops.RANGE, name="x"), @@ -173,7 +173,7 @@ class CStyleLanguage(Renderer): l = cast(str, self.string_rewrite.rewrite(u, ctx=self)) assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" - if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1 + if u.op in {Ops.ENDIF, Ops.END}: depth -= 1 if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \ (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \ diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 032532e75c..b67bd9cb32 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -108,7 +108,7 @@ base_rewrite = PatternMatcher([ f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n" f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n" f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"), - (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x: + (UPat(Ops.END, name="x"), lambda ctx,x: f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n" f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n" f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n" diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index efaeddbecd..eec9cade89 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -186,7 +186,7 @@ class NIRRenderer(Renderer): nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype) mesa.nir_push_loop(self.b) self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype) - elif u.op == Ops.ENDRANGE: + elif u.op == Ops.END: nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]), functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break)) mesa.nir_pop_loop(self.b, None) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index a57ee6a838..cc95e357a3 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([ if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"), (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []), (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]), - (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [ + (UPat(Ops.END, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [ ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]), ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]), f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]), @@ -219,7 +219,7 @@ class PTXRenderer(Renderer): [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)], [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]] r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] - prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None), + prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]), Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None)) if prefix: r[u] = ssa(prefix, u, dtype) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index afb1bb87f7..9a8ade8e18 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -52,11 +52,11 @@ class PythonProgram: loop_ends: dict[int, int] = {} while i < len(self.uops): uop, dtype, idp, arg = self.uops[i] - void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} + void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE} inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) - if uop is Ops.ENDRANGE: + if uop is Ops.END: loop_ends[idp[0]] = i i = idp[0] continue diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index cfdfa39a8a..fcb62cd1f3 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -70,7 +70,7 @@ class Ops(FastEnum): WHERE = auto(); MULACC = auto() # noqa: E702 # control flow ops - BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702 + BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702 # consts. VCONST is a vectorized const VCONST = auto(); CONST = auto() # noqa: E702 diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index ca9b32c7ec..667936cb49 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -195,7 +195,7 @@ spec = PatternMatcher([ (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)), - (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), + (UPat(Ops.END, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), # WMMA has a (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index da864578ce..827d672509 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -20,7 +20,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", - Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866"} + Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866", + Ops.END: "#524C46"} # VIZ API