Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 14:58:46 -05:00
cleanup triton (#2092)
* Revert "disable flaky triton test"
  This reverts commit 1e15fdaee7.
* Update test.yml
* check if has shared for matvec
* disable ocelot cache for triton
* disable ocelot cache
* disable ocelot cache
* pass shared to triton uops tests
* temporary debugs for CI crash
* Revert "temporary debugs for CI crash"
  This reverts commit fee3ea96c8.
* Revert "triton isn't tested, and allows this refactor (#2007)"
  This reverts commit dea8bb0938.
* add runtime_args to every renderer, move triton local size override to runtime args
* Add binary to args, correct type returned
* update to new loops
* Update test.yml
* cleanup triton
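The thread running through these commits is that a renderer now returns a runtime_args dict alongside the program text, so backend-specific launch details (whether prg is already a compiled binary, how many bytes of shared memory to request, a forced local size) travel with the program instead of being special-cased for Triton. A minimal sketch of that contract, using made-up render/launch helpers that are not tinygrad's actual API:

# Sketch only: render/launch are hypothetical names illustrating the
# runtime_args contract described in the commit message, not tinygrad code.
from typing import Dict, Optional, Tuple

def render() -> Tuple[str, Dict]:
  ptx = "..."  # Triton already produces PTX, so the program ships as a binary
  return ptx, {"binary": True, "shared": 2048, "local_size_override": [32, 8, 1]}

def launch(prg: str, runtime_args: Optional[Dict] = None) -> None:
  args = runtime_args if runtime_args is not None else {}  # same default as ASTRunner
  needs_compile = not args.get("binary")        # False for Triton: prg is device code
  local_size = args.get("local_size_override")  # forced launch shape, or None
  print(needs_compile, local_size)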
@@ -178,12 +178,6 @@ class ASTRunner:
     if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
 
-  @staticmethod
-  def from_linearizer(k, src:str):
-    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
-                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name, runtime_args={"binary": False})
-
   def optimize_local_size(self, global_size, rawbufs) -> List[int]:
     assert self.global_size is not None, "needs a global size to optimize local size"
     MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024
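With from_linearizer gone, callers construct ASTRunner directly and supply runtime_args themselves; the DEBUG >= 4 guard then prints prg only when it is still human-readable source, not when runtime_args marks it as an already-compiled binary. A standalone sketch of just that guard, with DEBUG pinned for illustration:

# Sketch of the DEBUG guard in the hunk above, isolated so it runs alone.
DEBUG = 4

def maybe_print(prg, runtime_args=None):
  # Print only when the program is source text, i.e. not flagged binary=True.
  if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']):
    print(prg)

maybe_print("float4 acc = ...;", runtime_args={"binary": False})  # printed
maybe_print("<ptx bytes>", runtime_args={"binary": True})         # suppressed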
@@ -118,8 +118,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
   codeObject = compile(prg, fn, "exec")
   exec(codeObject, globals()) # pylint: disable=W0122
   compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
-  prg = compiled.asm["ptx"]
-  if getenv("CUDACPU"): prg = remove_single_scalar_curly_braces(prg.split(".file")[0].split(".visible .func")[0])
+  prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
   max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
   for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
   return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
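The clamp at the end of this hunk reads the .maxntid directive that Triton emits into the PTX, which records the largest thread-block shape the kernel was compiled for, and shrinks the requested local size to fit before padding it to three dimensions for the override. The same parsing on a fabricated PTX fragment:

# Sketch of the .maxntid clamp above, run against a made-up PTX snippet.
ptx = """.visible .entry r_256_16(
.maxntid 32, 8, 1
) { ... }"""

max_local_size = [int(x) for x in ptx.split(".maxntid ")[1].split("\n")[0].split(", ")]
local_size = [64, 16]
local_size = [min(l, m) for l, m in zip(local_size, max_local_size)]
print(local_size)                                # [32, 8]
print(local_size + [1] * (3 - len(local_size)))  # [32, 8, 1], the padded override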