LOOP -> RANGE (#4650)

This commit is contained in:
George Hotz
2024-05-19 06:40:20 -07:00
committed by GitHub
parent 286b4dbdf2
commit 4753283221
8 changed files with 23 additions and 23 deletions

View File

@@ -337,7 +337,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(*sched[0].ast)
assert not any(u.uop is UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
assert not any(u.uop is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
@@ -391,7 +391,7 @@ class TestLinearizer(unittest.TestCase):
# ignore kernel optimized IF/LOOP statements for now
if if_op:=next((u for u in uops if u.uop is UOps.IF), None):
uops = uops[:uops.index(if_op)]
assert len(set([u.uop for u in uops if u.uop in {UOps.LOOP, UOps.SPECIAL}])) == 1, "has either specials or loops, not both"
assert len(set([u.uop for u in uops if u.uop in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or loops, not both"
assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"

View File

@@ -174,7 +174,7 @@ class Linearizer(Kernel):
# render loop
def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, (
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
self.loop_uops.update(new_loops)

View File

@@ -22,9 +22,9 @@ class UOps(Enum):
# memory/assignment ops
LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
# control flow ops
BARRIER = auto(); IF = auto(); LOOP = auto() # noqa: E702
BARRIER = auto(); IF = auto(); RANGE = auto() # noqa: E702
# these two are not graph nodes
ENDLOOP = auto(); ENDIF = auto() # noqa: E702
ENDRANGE = auto(); ENDIF = auto() # noqa: E702
@dataclass(eq=False)
class UOp:
@@ -135,13 +135,13 @@ constant_folder = PatternMatcher([
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
[{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
"vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.LOOP, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
"vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
{"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
# sum collapse to mul (with possible GEP)
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.LOOP, "__name__": "loop"},)},
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.LOOP, "__name__": "loop"},)},)},
"vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
{"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
# deal with UNMUL
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
@@ -312,7 +312,7 @@ class UOpGraph:
add_parents(x)
in_degree[u] += 1
graph[x].append(u)
if u.uop is UOps.LOOP: loops.append(u)
if u.uop is UOps.RANGE: loops.append(u)
if u.uop is UOps.IF: ifs.append(u)
sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
add_parents(sink)
@@ -350,7 +350,7 @@ class UOpGraph:
for u, ss in loops_children.items():
if x in ss:
ss.remove(x)
if len(ss) == 0: self._uops.append(UOp(UOps.ENDLOOP, None, (u,)))
if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
for u in graph[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)
@@ -376,10 +376,10 @@ class UOpGraph:
mults: sint = 1
mult_stack = []
for u in self.uops:
if u.uop is UOps.LOOP:
if u.uop is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.vin[1])
elif u.uop is UOps.ENDLOOP:
elif u.uop is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.uop is UOps.ALU:
flops += mults

View File

@@ -91,10 +91,10 @@ def print_tree(lazyop:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s
def graph_uops(uops:List[UOp]):
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
G = nx.DiGraph()
for u in uops:
if u.uop in {UOps.ENDLOOP, UOps.ENDIF}: continue
if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')

View File

@@ -146,7 +146,7 @@ class PTXRenderer(Renderer):
assert vin[0].dtype is not None
kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDLOOP:
elif uop is UOps.ENDRANGE:
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
@@ -164,7 +164,7 @@ class PTXRenderer(Renderer):
kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
elif uop is UOps.ALU:
assert vin[0].dtype is not None
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:

View File

@@ -112,7 +112,7 @@ class CStyleLanguage(Renderer):
kk(f"if ({r[vin[0]]}) {{")
depth += 1
elif uop is UOps.BARRIER: kk(self.barrier)
elif uop in {UOps.ENDLOOP, UOps.ENDIF}:
elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
depth -= 1
kk("}")
elif uop is UOps.STORE:
@@ -121,7 +121,7 @@ class CStyleLanguage(Renderer):
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
if uop is UOps.RANGE:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
depth += 1
elif uop is UOps.ALU:

View File

@@ -108,7 +108,7 @@ class LLVMRenderer(Renderer):
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
else:
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
elif uop is UOps.ENDLOOP:
elif uop is UOps.ENDRANGE:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
@@ -117,7 +117,7 @@ class LLVMRenderer(Renderer):
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
if uop is UOps.RANGE:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)

View File

@@ -39,7 +39,7 @@ class PythonProgram:
loop_ends: Dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {UOps.STORE, UOps.ENDLOOP, UOps.BARRIER, UOps.IF, UOps.ENDIF}
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
if uop is UOps.DEFINE_ACC: idp.clear()
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
@@ -62,7 +62,7 @@ class PythonProgram:
if g: _store(m, o, v)
i += 1
continue
elif uop is UOps.ENDLOOP:
elif uop is UOps.ENDRANGE:
loop_ends[idp[0]] = i
i = idp[0]
continue
@@ -90,7 +90,7 @@ class PythonProgram:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
elif uop is UOps.DEFINE_ACC:
ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
elif uop is UOps.LOOP:
elif uop is UOps.RANGE:
if i not in ul: ul[i] = [inp[0][0]] * warp_size
else:
for j in range(len(ul[i])):