new uops is an actual graph (#4560)

* new uops is an actual graph (see the sketch after this list)

* it's way slower

* simpler

* fix define acc

* render_loop unique

* ops test pass

* add pattern matcher back, there's bugs

* rewrite

* use priority queue

* recursive children

* fix tests

* fix tests with SINK

* fix abstractions

* fix assembly

* simpler

* link define_acc

* fix DEFINE_ACC placement

* type verify

* full cmp

* fix cmp

* ACCESS_ACC

* insert DEFINE_ACC

* fix PHI

* recursive rewrite

* fix many tests

* sum collapse

* more patterns

* correct change

* fold arange

* fix that lin test

* space

* big folding rule works

* close

* has more maxes, meh

* cached node replace

* set changed

* simplest folding yet

* works

* works

* DIV

* all tests pass

* del

* fuzz linearizer fails

* sum_collapse

* test depth 2 cf

* fix lin test 14

* fix clang depth

* disable that

* failure 14 is fixed

* fix ptx

* failure 27 is fixed

* fix llama

* run_cnt

* Revert "Optimize PTX gated loads index calculation (#4304)"

This reverts commit d97d5a7689.

* fix uops loop

* fix ptx bugs

* add barrier

* print

* mem_type in ptx direct

* bypass tests that fail in CI but pass locally

* ptx remove ptr_ar

* more ptx passing

* fix ptx tests

* assert compile support

* remove model inference benchmark from red
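
The bullets above compress a lot; a minimal standalone sketch of the three ideas they name — uops as graph nodes, a recursive pattern rewrite, and linearization that walks children recursively and orders ready nodes with a priority queue — might look like the following. Node, rewrite, and linearize are invented names for illustration, not tinygrad's actual API.

import heapq
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

@dataclass(frozen=True)
class Node:
  op: str                       # e.g. "CONST", "ADD", "STORE", "SINK"
  src: Tuple["Node", ...] = ()  # graph edges to inputs; no linear order yet
  arg: Any = None

def rewrite(n: Node, rules: List[Callable[[Node], Optional[Node]]]) -> Node:
  # recursive rewrite: rewrite inputs first, then retry rules to a fixpoint
  n = Node(n.op, tuple(rewrite(s, rules) for s in n.src), n.arg)
  for rule in rules:
    if (r := rule(n)) is not None: return rewrite(r, rules)
  return n

def linearize(sink: Node) -> List[Node]:
  # recursively walk children from the SINK to collect reachable nodes
  nodes: List[Node] = []
  seen: set = set()
  def walk(n: Node):
    if id(n) in seen: return
    seen.add(id(n))
    for s in n.src: walk(s)
    nodes.append(n)
  walk(sink)
  idx = {id(n): i for i, n in enumerate(nodes)}
  indeg = [len(n.src) for n in nodes]
  consumers: Dict[int, List[int]] = {i: [] for i in range(len(nodes))}
  for i, n in enumerate(nodes):
    for s in n.src: consumers[idx[id(s)]].append(i)
  # priority queue: among nodes whose inputs are all emitted, pick a
  # deterministic next one (here: by op name, then discovery order)
  heap = [(n.op, i) for i, n in enumerate(nodes) if indeg[i] == 0]
  heapq.heapify(heap)
  order: List[Node] = []
  while heap:
    _, i = heapq.heappop(heap)
    order.append(nodes[i])
    for j in consumers[i]:
      indeg[j] -= 1
      if indeg[j] == 0: heapq.heappush(heap, (nodes[j].op, j))
  return order

# sum collapse in miniature: ADD of two CONSTs folds away before linearization
def fold_add(n: Node) -> Optional[Node]:
  if n.op == "ADD" and all(s.op == "CONST" for s in n.src):
    return Node("CONST", (), sum(s.arg for s in n.src))
  return None

sink = Node("SINK", (Node("STORE", (Node("ADD", (Node("CONST", (), 2), Node("CONST", (), 3))),)),))
print([f"{n.op}({n.arg})" for n in linearize(rewrite(sink, [fold_add]))])
# ['CONST(5)', 'STORE(None)', 'SINK(None)']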
George Hotz authored on 2024-05-17 18:00:18 -07:00, committed by GitHub
parent daf57af3eb, commit 07b350a8f4
14 changed files with 431 additions and 451 deletions


@@ -1,6 +1,7 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
 import struct, copy
 from collections import defaultdict
+from tinygrad.helpers import DEBUG
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
@@ -14,37 +15,6 @@ def render_val(x, dtype):
     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
-def ptr_ar(root, uops):
-  assert root.arg in {'.shared', '.global', None}
-  if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
-  val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
-  if root.vin[1].uop is UOps.ALU and root.vin[1].arg in [BinaryOps.ADD, BinaryOps.SUB] and root.vin[1].vin[1].uop is UOps.CONST:
-    offset = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[0], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
-    offset = uops.add(UOps.CAST, dtypes.uint64, (offset,), insert_before=uops.uops.index(root))
-    cache = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], offset), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
-    ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
-    if root.vin[1].arg == BinaryOps.SUB: ptr = uops.add(UOps.ALU, dtypes.int, (ptr,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root))
-    root.vin = (cache, ptr) + root.vin[2:]
-  else:
-    ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
-    if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
-    else:
-      zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, insert_before=uops.uops.index(root))
-      bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
-      fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
-      root.vin = (fptr, zero) + root.vin[2:]
-def optimize_gated_loads(uops: UOpGraph):
-  def successors(uop): return list(filter(lambda u: uop in u.vin, uops.uops))
-  for gl in list(filter(lambda u:u.uop is UOps.LOAD and len(u.vin)>3, uops.uops)):
-    uops.uops.insert(uops.uops.index(gl), gate:=UOp(UOps.IF, None, (gl.vin[2],)))
-    uops.uops.insert(uops.uops.index(gl)+1, end:=UOp(UOps.ENDIF, None, (gate,) + (gl, gl.vin[3])))
-    for u in reversed(uops.uops.copy()[:uops.uops.index(gate)]):
-      if (u.uop not in [UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR, UOps.DEFINE_LOCAL, UOps.PHI, UOps.STORE, UOps.ENDIF, UOps.ENDLOOP] and
-          all(uops.uops.index(s)>uops.uops.index(gate) and uops.uops.index(s)<=uops.uops.index(end) for s in successors(u))):
-        uops.uops.insert(uops.uops.index(gate), uops.uops.pop(uops.uops.index(u)))
-    gl.vin = gl.vin[:2]
 class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
@@ -101,12 +71,13 @@
   def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
-  def render_bra(self, b1, pred=None, neg=False) -> List[str]: return [f"@{'!' if neg else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"]
+  def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
   def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]
-  def render_load(self, loc, dest, dtype, ss="", offset=0) -> List[str]:
+  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
     assert dtype is not dtypes.bool
+    if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
     return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
   def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
@@ -134,38 +105,8 @@ class PTXRenderer(Renderer):
     kernel:List[str] = []
     bufs = []
-    matcher = PatternMatcher([
-      ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
-       lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
-      ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
-       lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
-      ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
-        "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
-       lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
-      *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
-         lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
-        for op in self.asm_for_op.keys() if op not in self.supports_half],
-      ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
-        "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
-       lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
-      ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
-       lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
-      ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
-       lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
-      ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
-       lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
-      ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g"})},
-       lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,), root.arg),))),
-    ])
-    # here we do a pretransform on UOps to fix some shortcomings of PTX
-    # all uops must be a register
-    matcher.rewrite_graph(uops)
-    for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
-    uops.remove_childless(set(x for x in uops if x.uop in {UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
-    uops.optimize_loops()
-    optimize_gated_loads(uops)
+    uops.linearize(ptx_matcher)
+    if DEBUG >= 4: uops.print()
     def kk(*s: str): kernel.append("\n".join(s))
@@ -193,7 +134,7 @@
       return self.render_const(x, dtype)
     def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
-      if atype == dtype:
+      if atype == dtype or isinstance(atype, PtrDType):
        if u: r[u] = a
        return a
      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
@@ -203,27 +144,24 @@
       uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
       if uop is UOps.IF:
         assert vin[0].dtype is not None
-        kk(*self.render_bra(ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), neg=True))
+        kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
       elif uop is UOps.ENDLOOP:
         kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
            self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
-        kk(*self.render_bra(r_label[vin[0]], pred))
+        kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
       elif uop is UOps.ENDIF:
-        kk(f"@{_cast(r[vin[0].vin[0]], dtypes.bool, vin[0].vin[0].dtype, u=u, pred=True)} bra {r_label[vin[0]]}_true;")
         kk(f"{r_label[vin[0]]}:")
-        if len(vin) > 1 and vin[1].dtype.count > 1:
-          kk(*[f"mov.b{self.types[vin[1].dtype.scalar()][1:]} {dd}, {r[vin[2]][i]};" for i, dd in enumerate(r[vin[1]])])
-        elif len(vin) > 1:
-          kk(*[f"mov.b{self.types[vin[1].dtype][1:]} {r[vin[1]]}, {r[vin[2]]};" ])
-        kk(f"{r_label[vin[0]]}_true:")
       elif uop is UOps.STORE:
-        assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
+        assert vin[0].dtype is not None and vin[2].dtype is not None
+        assert vin[0].dtype is dtypes.int64, "store isn't int64"
+        assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
+        mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
         if vin[2].dtype.count > 1:
           kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
-             f"st{u.arg}.v{vin[2].dtype.count}.{self.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+             f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_type(vin[2].dtype.scalar())} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
         else:
-          kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
+          kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
       else:
         assert dtype is not None, f"None dtype for uop {uop}"
         if uop is UOps.LOOP: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
@@ -249,14 +187,23 @@
           else: r[u] = const(args, dtype, mov=True)
         elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
         elif uop is UOps.LOAD:
-          assert vin[1].dtype is not None
+          assert vin[0].dtype is dtypes.int64, "load isn't int64"
+          assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
+          mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
           if dtype.count > 1:
             r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            kk(f"ld{u.arg}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+            if(len(vin)>3):
+              for v in r[u]: kk(f"mov.{self.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
+            kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+               + f" ld{mem_type}.v{dtype.count}.{self.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
           else:
-            kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, ss=u.arg, offset=vin[1].arg))
+            kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
+                                 alt=r[vin[3]] if len(vin) > 3 else None, ss=mem_type, offset=vin[1].arg))
         elif uop is UOps.PHI:
-          kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
+          if dtype.count > 1:
+            for x0, x1 in zip(r[vin[0]], r[vin[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+          else:
+            kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
           r[u] = r[vin[0]]
         elif uop in {UOps.CAST, UOps.BITCAST}:
           assert vin[0].dtype is not None
@@ -288,3 +235,38 @@
         else: raise NotImplementedError(f"no code for {uop}")
     return self.render_kernel(kernel, name, bufs, c.items())
+ptx_matcher = PatternMatcher([
+  ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
+   lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
+  ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
+   lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+  ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD, "dtype": set([dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]),
+    "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
+   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
+  *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
+     lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
+    for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
+  ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
+    "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
+   lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
+  ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
+   lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
+  ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
+   lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+  ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
+   lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+  ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
+   lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
+  # ptr_ar (load/store)
+  ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
+    {"__name__": "const", "uop":UOps.CONST})},
+   lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
+                                                  UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
+                                                 )+root.vin[2:])),
+  ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
+    {"__name__": "alu"})},  # no const here
+   lambda root, alu: UOp(root.uop, root.dtype,
+                         (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
+                          UOp.const(dtypes.int64, 0))+root.vin[2:])),
+])
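
The two closing "ptr_ar" patterns are what replaced the deleted imperative ptr_ar pass: a LOAD/STORE whose base is a DEFINE_LOCAL/DEFINE_GLOBAL gets the base cast to int64, a constant index is folded into a byte offset (index * itemsize), and a computed index is folded into the pointer arithmetic with a zero offset — which is what the new "load isn't int64" / "load isn't const" asserts in render() rely on. A runnable toy of the constant-index case, using an invented U node type rather than tinygrad's UOp:

from dataclasses import dataclass
from typing import Any, Optional, Tuple

@dataclass(frozen=True)
class U:
  op: str                     # "DEFINE_GLOBAL", "CONST", "LOAD", "CAST", ...
  src: Tuple["U", ...] = ()
  arg: Any = None
  itemsize: int = 4           # bytes per element, for pointer-typed nodes

def ptr_ar_const(root: U) -> Optional[U]:
  # LOAD/STORE(buf, CONST idx, ...) -> LOAD/STORE(CAST(int64, buf), CONST idx*itemsize, ...)
  if (root.op in {"LOAD", "STORE"} and len(root.src) >= 2
      and root.src[0].op in {"DEFINE_LOCAL", "DEFINE_GLOBAL"} and root.src[1].op == "CONST"):
    base, idx = root.src[0], root.src[1]
    return U(root.op, (U("CAST", (base,), "int64"),
                       U("CONST", (), idx.arg * base.itemsize)) + root.src[2:], root.arg)
  return None

buf = U("DEFINE_GLOBAL", (), 0, itemsize=4)   # a float32 buffer: 4 bytes/element
print(ptr_ar_const(U("LOAD", (buf, U("CONST", (), 3)))))
# the element index 3 becomes the byte offset 12 on an int64 base pointer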