mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rename Ops.ENDRANGE -> Ops.END (#12824)
This commit is contained in:
@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
|
||||
for i,u in enumerate(ranges):
|
||||
if skip and i in skip: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
@@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
|
||||
uops = get_program(ast, opts=opt).uops
|
||||
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
|
||||
for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
|
||||
for u in uops:
|
||||
if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
|
||||
|
||||
@@ -674,7 +674,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
|
||||
uops = to_uops_list([store])
|
||||
ranges = [x for x in uops if x.op is Ops.RANGE]
|
||||
endranges = [x for x in uops if x.op is Ops.ENDRANGE]
|
||||
endranges = [x for x in uops if x.op is Ops.END]
|
||||
# ranges are closed in the right order
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp
|
||||
while len(ends_to_add):
|
||||
r:UOp = ends_to_add.pop(-1)
|
||||
new_ctx = tuple([z for z in new_ctx if z is not r])
|
||||
end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
|
||||
end_uop = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.END, src=(r,))
|
||||
base_block = UOp(Ops.BLOCKEND, src=(base_block,)*cnt, arg=BasicBlock((end_uop,), tuple(new_ctx), end=r, cnt=cnt))
|
||||
return base_block
|
||||
|
||||
@@ -215,7 +215,7 @@ def remove_blockend(x:UOp):
|
||||
# NOTE: DEFINE_ACC doesn't have to be handled in any special way
|
||||
late_ops = list(x.arg.lst)
|
||||
# NOTE: we have to add a barrier at the start if barrier is used in the range
|
||||
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
|
||||
if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.END:
|
||||
late_ops = [UOp(Ops.BARRIER)] + late_ops
|
||||
# peephole opt, remove any BARRIERs next to each other
|
||||
for i in range(len(late_ops)-1):
|
||||
|
||||
@@ -47,7 +47,7 @@ class Estimates:
|
||||
mults *= cast(sint, u.src[0].ssimplify())
|
||||
# SPECIAL are already counted in mults
|
||||
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
||||
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.dtype.itemsize * mults
|
||||
|
||||
@@ -11,7 +11,7 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
|
||||
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
|
||||
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
|
||||
# r method accesses
|
||||
(UPat(Ops.RANGE, name="x"),
|
||||
@@ -173,7 +173,7 @@ class CStyleLanguage(Renderer):
|
||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
|
||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
|
||||
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
|
||||
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
|
||||
|
||||
@@ -108,7 +108,7 @@ base_rewrite = PatternMatcher([
|
||||
f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n"
|
||||
f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"),
|
||||
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
|
||||
(UPat(Ops.END, name="x"), lambda ctx,x:
|
||||
f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n"
|
||||
f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n"
|
||||
f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
|
||||
|
||||
@@ -186,7 +186,7 @@ class NIRRenderer(Renderer):
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
mesa.nir_push_loop(self.b)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
|
||||
elif u.op == Ops.ENDRANGE:
|
||||
elif u.op == Ops.END:
|
||||
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]),
|
||||
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break))
|
||||
mesa.nir_pop_loop(self.b, None)
|
||||
|
||||
@@ -115,7 +115,7 @@ string_rewrite = PatternMatcher([
|
||||
if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
|
||||
(UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
|
||||
(UPat(Ops.END, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
|
||||
ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
|
||||
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
|
||||
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
|
||||
@@ -219,7 +219,7 @@ class PTXRenderer(Renderer):
|
||||
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
|
||||
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
|
||||
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
|
||||
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
|
||||
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
|
||||
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
|
||||
Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
|
||||
if prefix: r[u] = ssa(prefix, u, dtype)
|
||||
|
||||
@@ -52,11 +52,11 @@ class PythonProgram:
|
||||
loop_ends: dict[int, int] = {}
|
||||
while i < len(self.uops):
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
|
||||
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
|
||||
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
||||
if uop is Ops.ENDRANGE:
|
||||
if uop is Ops.END:
|
||||
loop_ends[idp[0]] = i
|
||||
i = idp[0]
|
||||
continue
|
||||
|
||||
@@ -70,7 +70,7 @@ class Ops(FastEnum):
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
||||
# control flow ops
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto() # noqa: E702
|
||||
|
||||
# consts. VCONST is a vectorized const
|
||||
VCONST = auto(); CONST = auto() # noqa: E702
|
||||
|
||||
@@ -195,7 +195,7 @@ spec = PatternMatcher([
|
||||
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
|
||||
|
||||
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
||||
(UPat(Ops.END, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
||||
|
||||
# WMMA has a <a, b, acc>
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat()), name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 8),
|
||||
|
||||
@@ -20,7 +20,8 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866"}
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.SUBSTITUTE: "#ffff00", Ops.AFTER: "#8A7866",
|
||||
Ops.END: "#524C46"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
||||
Reference in New Issue
Block a user