Mirror of https://github.com/tinygrad/tinygrad.git — synced 2026-04-07 03:00:26 -04:00
GPURunner class will replace CL cache eventually
This commit is contained in:
def split_float4(x):
  """Expand FLOAT4 tokens into their four scalar components.

  Every token in *x* must have type FLOAT4; each is replaced by four FLOAT
  tokens addressing its OpenCL vector components (`.s0` through `.s3`).
  Returns the flattened list, preserving the order of *x*.
  """
  assert all(vec.typ == Types.FLOAT4 for vec in x)
  scalars = []
  for vec in x:
    for comp in range(4):
      scalars.append(Token(vec.tok + f".s{comp}", Types.FLOAT))
  return scalars
|
||||
|
||||
class GPURunner:
  """Callable that binds a compiled CL program to fixed launch dimensions.

  Holds the program plus its global/local work sizes so a kernel launch
  only needs the runtime buffers. Buffer positions listed in
  ``bufs_to_delete`` are skipped when forwarding arguments to the program.
  """
  def __init__(self, clprg:CLProgram, bufs_to_delete:Set[int], global_work_size:List[int], local_work_size:Optional[List[int]]):
    self.clprg = clprg
    self.bufs_to_delete = bufs_to_delete
    self.global_work_size = global_work_size
    self.local_work_size = local_work_size

  def __call__(self, *bufs):
    # Drop buffers whose index was marked for deletion, then pass the
    # underlying .cl handles of the survivors to the compiled program.
    kept = [buf.cl for idx, buf in enumerate(bufs) if idx not in self.bufs_to_delete]
    return self.clprg(self.global_work_size, self.local_work_size, *kept)
|
||||
|
||||
class CLASTKernel(ASTKernel):
|
||||
code_for_op : Final[Dict[Op, str]] = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
|
||||
@@ -310,10 +316,7 @@ class CLASTKernel(ASTKernel):
|
||||
# compile kernel
|
||||
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
|
||||
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
|
||||
def runner(*bufs):
|
||||
clbufs = [x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete]
|
||||
return self.fxn(self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, *clbufs)
|
||||
return runner
|
||||
return GPURunner(self.fxn, self.bufs_to_delete, self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None)
|
||||
|
||||
def print(self):
|
||||
super().print()
|
||||
|
||||
Reference in New Issue
Block a user