GPURunner class will replace CL cache eventually

This commit is contained in:
George Hotz
2023-02-08 17:31:36 -06:00
parent a5a55ac19e
commit c656513591

View File

@@ -27,6 +27,12 @@ def split_float4(x):
assert all(y.typ == Types.FLOAT4 for y in x)
return sum([[Token(acc.tok+f".s{s}", Types.FLOAT) for s in range(4)] for acc in x], [])
class GPURunner:
  """Callable that invokes a CL program with work sizes fixed at construction.

  Attributes:
    clprg: the program object; called as clprg(global_work_size, local_work_size, *cl_handles).
    global_work_size: forwarded unchanged as the first argument of every call.
    local_work_size: forwarded unchanged as the second argument (may be None).
    bufs_to_delete: positional indices of call-time buffers that are NOT forwarded.
  """

  def __init__(self, clprg:CLProgram, bufs_to_delete:Set[int], global_work_size:List[int], local_work_size:Optional[List[int]]):
    # The arguments are stored as-is (no copies), so later mutation of the
    # passed-in containers is visible to this runner.
    self.clprg = clprg
    self.global_work_size = global_work_size
    self.local_work_size = local_work_size
    self.bufs_to_delete = bufs_to_delete

  def __call__(self, *bufs):
    """Invoke the program on `bufs`, skipping the positions in `bufs_to_delete`.

    Each remaining buffer contributes its `.cl` attribute, in the original
    order. Returns whatever `clprg` returns.
    """
    forwarded = []
    for position, buf in enumerate(bufs):
      if position in self.bufs_to_delete:
        continue
      forwarded.append(buf.cl)
    return self.clprg(self.global_work_size, self.local_work_size, *forwarded)
class CLASTKernel(ASTKernel):
code_for_op : Final[Dict[Op, str]] = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
@@ -310,10 +316,7 @@ class CLASTKernel(ASTKernel):
# compile kernel
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
def runner(*bufs):
clbufs = [x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete]
return self.fxn(self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, *clbufs)
return runner
return GPURunner(self.fxn, self.bufs_to_delete, self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None)
def print(self):
super().print()