Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-11 15:15:13 -05:00
Move transformations to PatternMatcher + clean up existing patterns (#3997)
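Note: judging from the pattern list in the second hunk below, each rule pairs a dict describing a UOp (its uop, arg, dtype, and a vin tuple of sub-patterns, with "__name__" binding the matched node to a name) with a rewrite callback that receives the bound nodes as keyword arguments. The sketch below illustrates that dispatch idea in isolation; it is not tinygrad's PatternMatcher, and the toy node type and helper names are invented for the example.

# Minimal standalone sketch of dict-pattern dispatch, for illustration only;
# this is not tinygrad's PatternMatcher, and the toy node type is invented here.
from types import SimpleNamespace

def match(pat, node, store):
  # every key in the pattern must agree with the node; "__name__" binds the node
  # under that name, and "vin" recurses into the node's inputs position by position
  for k, v in pat.items():
    if k == "__name__": store[v] = node
    elif k == "vin":
      if len(node.vin) != len(v) or not all(match(p, n, store) for p, n in zip(v, node.vin)): return False
    elif getattr(node, k, None) != v: return False
  return True

def rewrite(rules, node):
  # first matching (pattern, fn) pair wins; fn receives the named bindings as kwargs
  for pat, fn in rules:
    store = {}
    if match(pat, node, store): return fn(**store)
  return None

# toy usage mirroring eq_rep below: CMPEQ on bools becomes XOR (plus a NEG in the real code)
def eq_rep(root, x, y):
  root.arg = "XOR"
  return root

node = SimpleNamespace(uop="ALU", arg="CMPEQ",
                       vin=(SimpleNamespace(uop="CONST", dtype="bool"),
                            SimpleNamespace(uop="CONST", dtype="bool")))
rules = [({"__name__": "root", "uop": "ALU", "arg": "CMPEQ",
           "vin": ({"__name__": "x", "dtype": "bool"}, {"__name__": "y"})}, eq_rep)]
print(rewrite(rules, node).arg)  # XOR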
@@ -45,25 +45,31 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
  def eq_rep(root, x, y):
    root.arg = BinaryOps.XOR
    new = uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)
    return new
    return uops.add(UOps.ALU, dtypes.bool, (root,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root)+1)

  def lt_rep(x, y):
    new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
    u.vin = (new, u.vin[1])
    u.vin = (uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)), u.vin[1])
    u.arg = BinaryOps.MUL

  def ld_rep(root, x, y):
    root.dtype = dtypes.uint8
    new = uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)
    ptr_ar(root)
    return new
    return uops.add(UOps.CAST, dtypes.bool, (root,), insert_before=uops.uops.index(root)+1)

  def st_rep(root, z):
    root.vin = root.vin[:2] + (uops.add(UOps.CAST, dtypes.uint8, (z,), insert_before=uops.uops.index(root)),)
    ptr_ar(root)

  def gate_rep(root, x, y, z, k):
    new = uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root))
    root.vin = (x,y,z,new)
    root.vin = (x,y,z,uops.add(UOps.CAST, dtypes.uint8, (k,), insert_before=uops.uops.index(root)))
    return ld_rep(root,x,y)

  def half_alu_rep(root):
    if root.arg not in lang.supports_half:
      root.vin = tuple([uops.add(UOps.CAST, dtypes.float32, (vv,), insert_before=(uops.uops.index(root))) for vv in root.vin])
      root.dtype = dtypes.float32
      return uops.add(UOps.CAST, dtypes.half, (root,), insert_before=(uops.uops.index(root)+1))

  def ptr_ar(root):
    assert root.arg in {'.shared', '.global', None}
    if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global'  # move this to the arg
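The two comparison rewrites above lean on boolean identities: a == b is NOT (a XOR b), which is why eq_rep switches CMPEQ to XOR and appends a NEG, and a < b is (NOT a) AND b, realized here as a MUL of the negated first input with the second (assuming UnaryOps.NEG on dtypes.bool acts as logical negation). A brute-force check of both identities as a standalone sketch:

# Brute-force check of the two boolean identities used by eq_rep and lt_rep.
for a in (False, True):
  for b in (False, True):
    assert (a == b) == (not (a ^ b))          # CMPEQ  ->  NEG(XOR(a, b))
    assert (a < b) == bool((not a) * b)       # CMPLT  ->  MUL(NEG(a), b)
print("both identities hold for all boolean inputs")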
@@ -79,9 +85,12 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
  matcher = PatternMatcher([
    ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, eq_rep),
    ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})}, lt_rep),
    ({"__name__": "root", "uop": UOps.ALU, "dtype": dtypes.half}, half_alu_rep),
    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool,
      "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})}, gate_rep),
    ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({"__name__": "x"},{"__name__": "y"})}, ld_rep),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})}, st_rep),
    ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})}, st_rep),
    ({"__name__": "root", "uop": UOps.STORE, "vin": {}}, ptr_ar),
    ({"__name__": "root", "uop": UOps.LOAD, "vin": {}}, ptr_ar),
  ])
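In the pattern list above, the gated four-input bool LOAD is listed before the plain two-input one, and the dtype-specific bool LOAD/STORE rules come before the bare "vin": {} entries that only normalize the address space via ptr_ar, which suggests the more specific shapes are meant to be handled first. The bool rewrites themselves are tuple surgery on vin plus an inserted CAST; a stand-in sketch of the replacement st_rep performs on a store's inputs (plain strings in place of UOps, and cast_u8 standing in for uops.add(UOps.CAST, dtypes.uint8, ...)):

# Stand-in sketch of st_rep's vin surgery: keep the first two inputs and
# replace the stored bool value (third input) with a uint8 cast of it.
def cast_u8(v): return f"CAST_uint8({v})"   # stand-in for uops.add(UOps.CAST, dtypes.uint8, (v,))

store_vin = ("buf", "idx", "bool_val")                 # the first two names are only illustrative
store_vin = store_vin[:2] + (cast_u8(store_vin[2]),)   # what st_rep does to root.vin
print(store_vin)  # ('buf', 'idx', 'CAST_uint8(bool_val)')

gate_rep does the analogous replacement on the fourth input of a gated bool load before handing the node on to ld_rep.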
@@ -155,21 +164,11 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
    if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
    elif uop is UOps.ALU:
      assert vin[0].dtype is not None
      operands = [r[x] for x in vin]
      lab = ssa(u, "alu")
      if needs_upcast := dtype == dtypes.half and args not in lang.supports_half:
        dtype = dtypes.float32
        out_lab, lab = lab, ssa(None, "alu_cast", lang.types[dtype])
        for i, op in enumerate(operands):
          operands[i] = ssa(None, "alu_cast", lang.types[dtype])
          kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
      if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
        # pass in the other dtype here
        kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
        kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
      else:
        kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
        if needs_upcast:
          kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
        kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
    elif uop is UOps.DEFINE_ACC:
      if dtype.count > 1:
        r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
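The hunk header (-155,21 +164,11) shows this ALU block shrinking by ten lines: the inline needs_upcast handling, which cast half operands up to float32, ran the op, and cast the result back to half, appears to be what moves into the half_alu_rep pattern above, leaving a single asm_for_op call per branch. A standalone numeric sketch of that compute-in-float32 strategy, using NumPy half precision purely as a stand-in for the PTX registers:

import numpy as np

# Emulate an ALU op the target cannot run natively in half precision:
# upcast the float16 operands to float32, compute, then cast the result back.
def alu_via_float32(op, *operands: np.float16) -> np.float16:
  upcast = [np.float32(x) for x in operands]   # render_cast: half -> float32
  result = op(*upcast)                         # the op itself runs in float32
  return np.float16(result)                    # render_cast: float32 -> half

a, b = np.float16(1.5), np.float16(2.25)
print(alu_via_float32(lambda x, y: x * y, a, b))  # 3.375, exactly representable in half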
@@ -275,8 +274,6 @@ class PTXLanguage(AssemblyLanguage):
    else: return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]

  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
    if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val}_cast;"]
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}+{offset}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
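With st_rep widening boolean stores to uint8 at the UOp level, render_store's dedicated bool branch, which staged the value through a 16-bit register (".reg .s16 {val}_cast") before the st instruction, appears to be what this hunk drops, and the -275,8 +274,6 header is consistent with removing that two-line branch. The remaining store line is a single f-string; reconstructing it with placeholder inputs (the register names and the u8 memory type are assumptions, since mem_type is not shown in this hunk):

# Reconstruction of the remaining render_store line with example inputs.
# loc/val/gate/ss/offset and the "u8" memory type are placeholder assumptions.
loc, val, gate, ss, offset, mem_type = "%rd1", "%rs2", None, ".global", 0, "u8"
line = (f"@{gate} " if gate else "") + f"st{ss}.{mem_type} [{loc}+{offset}], {val};"
print(line)  # st.global.u8 [%rd1+0], %rs2;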