diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 7817261c83..ae6ddd6d4b 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -338,7 +338,7 @@ class UOpGraph: self.replace_op(u, new) return True - def uoptimize(self): + def optimize_loops(self): # get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent" acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list) for u in self.uops: @@ -356,6 +356,9 @@ class UOpGraph: while self.uops_optimization(get_recursive_parents): pass self.simplify_phi_loops(get_recursive_parents) + def uoptimize(self): + self.optimize_loops() + # (recursively) remove childless uops # TODO: remove DEFINE_GLOBAL from here self.remove_childless(set(x for x in self.uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.STORE})) diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index c772a065d1..74144ee998 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -83,7 +83,6 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)), ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})}, lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,), None),), root.arg)), - ]) # here we do a pretransform on UOps to fix some shortcomings of PTX @@ -92,6 +91,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: for pointer_op in list(filter(lambda uop: uop.uop in [UOps.LOAD, UOps.STORE], uops.uops)): ptr_ar(pointer_op, uops) uops.remove_childless(set(x for x in uops if x.uop in {UOps.DEFINE_GLOBAL, UOps.PHI, UOps.ENDIF, UOps.ENDLOOP, UOps.STORE})) + uops.optimize_loops() def kk(*s: str): kernel.append("\n".join(s))