mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 22:54:59 -05:00
Refactor the use of pattern matcher in ptx (#4043)
This commit is contained in:
@@ -43,34 +43,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
kernel:List[str] = []
|
||||
bufs = []
|
||||
|
||||
def eq_rep(root, x, y):
|
||||
root.arg = BinaryOps.XOR
|
||||
return uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
|
||||
|
||||
def lt_rep(x, y):
|
||||
u.vin = (uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)), u.vin[1])
|
||||
u.arg = BinaryOps.MUL
|
||||
|
||||
def ld_rep(root, x, y):
|
||||
root.dtype = dtypes.uint8
|
||||
ptr_ar(root)
|
||||
return uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
|
||||
|
||||
def st_rep(root, z):
|
||||
root.vin = root.vin[:2] + (uops.add(UOps.CAST, dtypes.uint8, (z,), insert_before=uops.uops.index(root)),)
|
||||
ptr_ar(root)
|
||||
|
||||
def gate_rep(root, x, y, z, k):
|
||||
root.vin = (x,y,z,uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root)))
|
||||
return ld_rep(root,x,y)
|
||||
|
||||
def half_alu_rep(root):
|
||||
if root.arg not in lang.supports_half:
|
||||
root.vin = tuple([uops.add(UOps.CAST, dtypes.float32, (vv,), insert_before=(uops.uops.index(root))) for vv in root.vin])
|
||||
root.dtype = dtypes.float32
|
||||
return uops.add(UOps.CAST, dtypes.half, (root,), insert_before=(uops.uops.index(root)+1))
|
||||
|
||||
def ptr_ar(root):
|
||||
def ptr_ar(root, uops):
|
||||
assert root.arg in {'.shared', '.global', None}
|
||||
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
|
||||
@@ -83,16 +56,22 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
root.vin = (fptr, zero) + root.vin[2:]
|
||||
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
|
||||
({"__name__": "root", "uop": UOps.ALU, "dtype": dtypes.half}, half_alu_rep),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool,
|
||||
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})}, st_rep),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})}, st_rep),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
|
||||
({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
|
||||
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
|
||||
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
|
||||
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
|
||||
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
|
||||
for op in lang.asm_for_op.keys() if op not in lang.supports_half],
|
||||
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
|
||||
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
|
||||
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
|
||||
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
|
||||
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
|
||||
])
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of PTX
|
||||
@@ -106,6 +85,18 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
if o in u.vin and u is not n:
|
||||
u.vin = tuple(n if x == o else x for x in u.vin)
|
||||
if rew := matcher.rewrite(u): replace[u] = rew
|
||||
|
||||
for o,n in replace.items():
|
||||
queue = [n]
|
||||
while queue:
|
||||
if all([qq in uops.uops for qq in queue[-1].vin]):
|
||||
q = queue.pop()
|
||||
new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1)
|
||||
for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
|
||||
else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
|
||||
uops.uops.remove(o)
|
||||
|
||||
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
|
||||
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
|
||||
Reference in New Issue
Block a user