Refactor the use of pattern matcher in ptx (#4043)

This commit is contained in:
Szymon Ożóg
2024-04-02 23:19:51 +02:00
committed by GitHub
parent 85edc493b0
commit ccf3c16d6a

View File

@@ -43,34 +43,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
kernel:List[str] = []
bufs = []
# Rewrite callback for a CMPEQ ALU uop (`root`).
# Mutates `root` in place so its op becomes XOR, then adds a bool-typed NEG ALU
# uop that takes `root` as its only input and returns that new uop.
# The new uop is positioned with insert_before = index(root) + 1
# (presumably directly after `root` -- confirm against UOpGraph.add).
# `x` and `y` are accepted but not used in the body.
def eq_rep(root, x, y):
root.arg = BinaryOps.XOR
return uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
# Rewrite callback for a CMPLT ALU uop.
# Replaces the first input of `u` with a bool-typed NEG of that input (added
# with insert_before = index(u)), keeps the second input, and changes the op of
# `u` to MUL. Has no return statement, so it returns None.
# NOTE(review): `u` is not a parameter of this function; it is resolved from an
# enclosing scope (presumably the uop currently being visited -- confirm).
# `x` and `y` are accepted but not used in the body.
def lt_rep(x, y):
u.vin = (uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)), u.vin[1])
u.arg = BinaryOps.MUL
# Rewrite callback for a LOAD uop (`root`).
# Changes the dtype of `root` to uint8 in place, runs ptr_ar(root) on it, then
# adds a CAST-to-bool uop that takes `root` as input and returns that cast.
# The cast is positioned with insert_before = index(root) + 1.
# `x` and `y` are accepted but not used in the body.
def ld_rep(root, x, y):
root.dtype = dtypes.uint8
ptr_ar(root)
return uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
# Rewrite callback for a STORE uop (`root`) whose stored value is `z`.
# Rebuilds root.vin as its first two inputs followed by a CAST of `z` to uint8
# (the cast is added with insert_before = index(root)), then runs ptr_ar(root).
# Has no return statement, so it returns None.
# NOTE(review): the rebuilt vin always has exactly three entries, so any input
# of `root` beyond index 2 is dropped here -- confirm that is intended.
def st_rep(root, z):
root.vin = root.vin[:2] + (uops.add(UOps.CAST, dtypes.uint8, (z,), insert_before=uops.uops.index(root)),)
ptr_ar(root)
# Rewrite callback for a four-input LOAD uop (`root`) with inputs x, y, z, k.
# Replaces root.vin with (x, y, z, CAST(k -> uint8)), where the cast is added
# with insert_before = index(root), then delegates to ld_rep(root, x, y) and
# returns whatever ld_rep returns.
def gate_rep(root, x, y, z, k):
root.vin = (x,y,z,uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root)))
return ld_rep(root,x,y)
# Rewrite callback for an ALU uop (`root`).
# Only acts when root.arg is not listed in lang.supports_half:
#   - every input of `root` is replaced by a CAST of that input to float32
#     (each cast added with insert_before = index(root)),
#   - the dtype of `root` is changed to float32 in place,
#   - a CAST-to-half uop taking `root` as input is added with
#     insert_before = index(root) + 1 and returned.
# When root.arg is in lang.supports_half nothing is changed and the function
# falls through, returning None.
def half_alu_rep(root):
if root.arg not in lang.supports_half:
root.vin = tuple([uops.add(UOps.CAST, dtypes.float32, (vv,), insert_before=(uops.uops.index(root))) for vv in root.vin])
root.dtype = dtypes.float32
return uops.add(UOps.CAST, dtypes.half, (root,), insert_before=(uops.uops.index(root)+1))
def ptr_ar(root):
def ptr_ar(root, uops):
assert root.arg in {'.shared', '.global', None}
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
@@ -83,16 +56,22 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
root.vin = (fptr, zero) + root.vin[2:]
matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
({"__name__": "root", "uop": UOps.ALU, "dtype": dtypes.half}, half_alu_rep),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool,
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})}, st_rep),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})}, st_rep),
({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
*[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
for op in lang.asm_for_op.keys() if op not in lang.supports_half],
({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
"vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
@@ -106,6 +85,18 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
if o in u.vin and u is not n:
u.vin = tuple(n if x == o else x for x in u.vin)
if rew := matcher.rewrite(u): replace[u] = rew
for o,n in replace.items():
queue = [n]
while queue:
if all([qq in uops.uops for qq in queue[-1].vin]):
q = queue.pop()
new = uops.add(q.uop, q.dtype, q.vin, q.arg, insert_before=max([uops.uops.index(vv) for vv in q.vin])+1)
for vv in uops.uops + queue: vv.vin = tuple(new if x is q else x for x in vv.vin)
else: queue.extend([qq for qq in queue[-1].vin if qq not in uops.uops])
uops.uops.remove(o)
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
# Append one entry to `kernel`: all string arguments joined with newlines.
# `kernel` is not a parameter; it is taken from the enclosing scope.
def kk(*s: str): kernel.append("\n".join(s))