Optimize PTX loops (#4263)

* Optimize PTX loops

* Update assembly.py
This commit is contained in:
Szymon Ożóg
2024-04-23 10:20:14 +02:00
committed by GitHub
parent 967638f0d5
commit 6c25f1abf7
2 changed files with 5 additions and 2 deletions

View File

@@ -83,7 +83,6 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)),
])
# here we do a pretransform on UOps to fix some shortcomings of PTX
@@ -92,6 +91,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE}))
uops.optimize_loops()
def kk(*s: str):
  """Join the given strings with newlines and append the result to `kernel` as one entry.

  `kernel` is not defined in this block; it is expected to be a list-like
  object provided by the surrounding scope.
  """
  joined = "\n".join(s)
  kernel.append(joined)