mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
LOOP -> RANGE (#4650)
This commit is contained in:
@@ -337,7 +337,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
assert len(sched) == 1
|
||||
lin = Linearizer(*sched[0].ast)
|
||||
assert not any(u.uop is UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
assert not any(u.uop is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
|
||||
def test_assign_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
@@ -391,7 +391,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# ignore kernel optimized IF/LOOP statements for now
|
||||
if if_op:=next((u for u in uops if u.uop is UOps.IF), None):
|
||||
uops = uops[:uops.index(if_op)]
|
||||
assert len(set([u.uop for u in uops if u.uop in {UOps.LOOP, UOps.SPECIAL}])) == 1, "has either specials or loops, not both"
|
||||
assert len(set([u.uop for u in uops if u.uop in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or loops, not both"
|
||||
assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified"
|
||||
# TODO: once uops track min/max this will be fixed
|
||||
#assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
|
||||
|
||||
@@ -174,7 +174,7 @@ class Linearizer(Kernel):
|
||||
|
||||
# render loop
|
||||
def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
|
||||
new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
|
||||
new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
||||
self.loop_uops.update(new_loops)
|
||||
|
||||
@@ -22,9 +22,9 @@ class UOps(Enum):
|
||||
# memory/assignment ops
|
||||
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
|
||||
# control flow ops
|
||||
BARRIER = auto(); IF = auto(); LOOP = auto() # noqa: E702
|
||||
BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
|
||||
# these two are not graph nodes
|
||||
ENDLOOP = auto(); ENDIF = auto() # noqa: E702
|
||||
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
|
||||
|
||||
@dataclass(eq=False)
|
||||
class UOp:
|
||||
@@ -135,13 +135,13 @@ constant_folder = PatternMatcher([
|
||||
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
|
||||
[{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
|
||||
"vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.LOOP, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
|
||||
"vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
|
||||
{"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.LOOP, "__name__": "loop"},)},
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
||||
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
|
||||
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.LOOP, "__name__": "loop"},)},)},
|
||||
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
|
||||
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
|
||||
# deal with UNMUL
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
|
||||
@@ -312,7 +312,7 @@ class UOpGraph:
|
||||
add_parents(x)
|
||||
in_degree[u] += 1
|
||||
graph[x].append(u)
|
||||
if u.uop is UOps.LOOP: loops.append(u)
|
||||
if u.uop is UOps.RANGE: loops.append(u)
|
||||
if u.uop is UOps.IF: ifs.append(u)
|
||||
sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
|
||||
add_parents(sink)
|
||||
@@ -350,7 +350,7 @@ class UOpGraph:
|
||||
for u, ss in loops_children.items():
|
||||
if x in ss:
|
||||
ss.remove(x)
|
||||
if len(ss) == 0: self._uops.append(UOp(UOps.ENDLOOP, None, (u,)))
|
||||
if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
|
||||
for u in graph[x]:
|
||||
in_degree[u] -= 1
|
||||
if in_degree[u] == 0: push(u)
|
||||
@@ -376,10 +376,10 @@ class UOpGraph:
|
||||
mults: sint = 1
|
||||
mult_stack = []
|
||||
for u in self.uops:
|
||||
if u.uop is UOps.LOOP:
|
||||
if u.uop is UOps.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= uop_alu_resolve(u.vin[1])
|
||||
elif u.uop is UOps.ENDLOOP:
|
||||
elif u.uop is UOps.ENDRANGE:
|
||||
mults = mult_stack.pop(-1)
|
||||
elif u.uop is UOps.ALU:
|
||||
flops += mults
|
||||
|
||||
@@ -91,10 +91,10 @@ def print_tree(lazyop:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s
|
||||
def graph_uops(uops:List[UOp]):
|
||||
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
|
||||
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
|
||||
UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
|
||||
UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
|
||||
G = nx.DiGraph()
|
||||
for u in uops:
|
||||
if u.uop in {UOps.ENDLOOP, UOps.ENDIF}: continue
|
||||
if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
|
||||
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
|
||||
for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
|
||||
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
|
||||
|
||||
@@ -146,7 +146,7 @@ class PTXRenderer(Renderer):
|
||||
assert vin[0].dtype is not None
|
||||
kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
elif uop is UOps.ENDRANGE:
|
||||
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
|
||||
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
|
||||
kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
@@ -164,7 +164,7 @@ class PTXRenderer(Renderer):
|
||||
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
|
||||
|
||||
@@ -112,7 +112,7 @@ class CStyleLanguage(Renderer):
|
||||
kk(f"if ({r[vin[0]]}) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.BARRIER: kk(self.barrier)
|
||||
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
|
||||
elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
|
||||
depth -= 1
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
@@ -121,7 +121,7 @@ class CStyleLanguage(Renderer):
|
||||
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
if uop is UOps.RANGE:
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
|
||||
@@ -108,7 +108,7 @@ class LLVMRenderer(Renderer):
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
else:
|
||||
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
||||
elif uop is UOps.ENDLOOP:
|
||||
elif uop is UOps.ENDRANGE:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
@@ -117,7 +117,7 @@ class LLVMRenderer(Renderer):
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
if uop is UOps.RANGE:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class PythonProgram:
|
||||
loop_ends: Dict[int, int] = {}
|
||||
while i < len(self.uops):
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
void_ops = {UOps.STORE, UOps.ENDLOOP, UOps.BARRIER, UOps.IF, UOps.ENDIF}
|
||||
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
|
||||
if uop is UOps.DEFINE_ACC: idp.clear()
|
||||
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
@@ -62,7 +62,7 @@ class PythonProgram:
|
||||
if g: _store(m, o, v)
|
||||
i += 1
|
||||
continue
|
||||
elif uop is UOps.ENDLOOP:
|
||||
elif uop is UOps.ENDRANGE:
|
||||
loop_ends[idp[0]] = i
|
||||
i = idp[0]
|
||||
continue
|
||||
@@ -90,7 +90,7 @@ class PythonProgram:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
|
||||
elif uop is UOps.LOOP:
|
||||
elif uop is UOps.RANGE:
|
||||
if i not in ul: ul[i] = [inp[0][0]] * warp_size
|
||||
else:
|
||||
for j in range(len(ul[i])):
|
||||
|
||||
Reference in New Issue
Block a user