diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cc1bd34854..6837da254c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -178,12 +178,6 @@ class ASTRunner: if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg) self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {} - @staticmethod - def from_linearizer(k, src:str): - return ASTRunner(k.function_name, src, k.global_size, k.local_size, - op_estimate=k.info.flops, mem_estimate=k.mem_estimate, - display_name=k.display_name, runtime_args={"binary": False}) - def optimize_local_size(self, global_size, rawbufs) -> List[int]: assert self.global_size is not None, "needs a global size to optimize local size" MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024 diff --git a/tinygrad/renderer/triton.py b/tinygrad/renderer/triton.py index 0e0dcdfa4c..9f5b864c58 100644 --- a/tinygrad/renderer/triton.py +++ b/tinygrad/renderer/triton.py @@ -118,8 +118,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]): codeObject = compile(prg, fn, "exec") exec(codeObject, globals()) # pylint: disable=W0122\ compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None)) - prg = compiled.asm["ptx"] - if getenv("CUDACPU"): prg = remove_single_scalar_curly_braces(prg.split(".file")[0].split(".visible .func")[0]) + prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0]) max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")] for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i]) return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}