# tinygrad/tinygrad/llops/ops_gpu.py
from __future__ import annotations
import numpy as np
from typing import List, Tuple, Optional, Dict, Union, Set, Final, Callable
from tinygrad.helpers import prod
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
from tinygrad.ast import ASTKernel, Token, Types
from tinygrad.lazy import IMAGE
from tinygrad.shape import ShapeTracker
from tinygrad.shape.symbolic import ModNode, DivNode, render_python # this will go away when VALIDHACKS does
# div is different in cl than python
render_cl = render_python.copy()
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
from tinygrad.helpers import getenv
CUDA,METAL = getenv("CUDA", 0), getenv("METAL", 0)
if not CUDA and not METAL: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram # NOTE: using CL will not work for the CUDA runtime # noqa: F401
elif CUDA: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
elif METAL: from tinygrad.runtime.metal import CLBuffer, CLImage, CLProgram # type: ignore
VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass
KOPT = getenv("KOPT", 0)
PRINT_AST = getenv("PRINT_AST", "0")
TEST_AST = getenv("TEST_AST", 0)
def group_float4(x):
  assert all(y.typ == Types.FLOAT for y in x) and len(x)%4 == 0
  return [Token(f"(float4)({','.join([x[i+j].tok for j in range(4)])})", Types.FLOAT4) for i in range(0, len(x), 4)]
def split_float4(x):
  assert all(y.typ == Types.FLOAT4 for y in x)
  return sum([[Token(acc.tok+f".s{s}", Types.FLOAT) for s in range(4)] for acc in x], [])
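# For example, group_float4 packs four scalar Tokens into one float4 Token and
# split_float4 is its inverse:
#   group_float4([Token(t, Types.FLOAT) for t in "abcd"])
#     -> [Token("(float4)(a,b,c,d)", Types.FLOAT4)]
#   split_float4([Token("acc0", Types.FLOAT4)])
#     -> [Token("acc0.s0", Types.FLOAT), ..., Token("acc0.s3", Types.FLOAT)]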
class GPURunner:
  def __init__(self, clprg:CLProgram, bufs_to_delete:Set[int], global_work_size:List[int], local_work_size:Optional[List[int]]):
    self.clprg, self.global_work_size, self.local_work_size, self.bufs_to_delete = clprg, global_work_size, local_work_size, bufs_to_delete
  def __call__(self, *bufs):
    return self.clprg(self.global_work_size, self.local_work_size, *[x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete])
class CLASTKernel(ASTKernel):
  code_for_op : Final[Dict[Op, str]] = {
    UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
    UnaryOps.GT0: "(A > 0.)" if CUDA else "((float)1.-step((float)0.,(-A)))",
    UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
    UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
    BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
    BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
    ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)"
  }
  start_for_op : Final[Dict[Op, str]] = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
  def image_idx(self, buf_index, idxy, validhacks=False):
    assert self.buftokens[buf_index].typ == Types.FLOAT4, f"image must be FLOAT4 {self.buftokens[buf_index]} {self.bufs[buf_index].st}"
    idx = (idxy//4)%self.bufs[buf_index]._base_shape[1]
    idy = (idxy//(4*self.bufs[buf_index]._base_shape[1]))%self.bufs[buf_index]._base_shape[0]
    if validhacks: idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)]
    return f"(int2)({idx.render(render_cl)}, {idy.render(render_cl)})"
  def store(self, buf_index, value:List[Token]):
    if len(value) == self.buftokens[buf_index].size()*4: value = group_float4(value)
    if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
    assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
    for v, o in zip(value, self.buftokens[buf_index].offsets()):
      idxy, valid = self.sts[buf_index].expr_idxs(o)
      assert valid.min == 1, "store must always be valid"
      assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
      if isinstance(self.bufs[buf_index]._buf, CLImage):
        self.kernel.append(f"write_imagef(data{buf_index}, {self.image_idx(buf_index, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
      else:
        self.kernel.append(f"data{buf_index}[{(idxy//(4 if v.typ == Types.FLOAT4 else 1)).render(render_cl)}] = {v.tok};\n")
  def load(self, buf_index:int) -> List[Token]:
    # constant folding
    const = None
    if self.bufs[buf_index]._base_shape == (1,) and self.bufs[buf_index]._backing is not None:
      assert self.buftokens[buf_index].typ == Types.FLOAT
      if buf_index != 0: self.bufs_to_delete.add(buf_index)
      const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
    tokens = []
    for o in self.buftokens[buf_index].offsets():
      key = f"val{buf_index}_{o}" if o >= 0 else f"val{buf_index}_m{-o}"
      if (buf_index, o) not in self.loaded_keys:
        idxy, valid = self.sts[buf_index].expr_idxs(o)
        if const is not None:
          ldr = const
        elif isinstance(self.bufs[buf_index]._buf, CLImage):
          ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
        else:
          ldr = Token(f"{self.buftokens[buf_index].tok}[{(idxy//(4 if self.buftokens[buf_index].typ == Types.FLOAT4 else 1)).render(render_cl)}]", self.buftokens[buf_index].typ)
        ldr = ldr if valid.min == 1 or (VALIDHACKS and isinstance(self.bufs[buf_index]._buf, CLImage)) else (Token(f"({valid.render(render_cl)} ? {ldr.tok} : 0.0f)", ldr.typ) if valid.max == 1 else Token("0.0f", ldr.typ))
        if const is not None:
          self.loaded_keys[(buf_index,o)] = ldr
        else:
          self.kernel.append(f"{ldr.decltype()} {key} = {ldr.tok};\n")
          self.loaded_keys[(buf_index,o)] = Token(key, ldr.typ)
      tokens.append(self.loaded_keys[(buf_index,o)])
    return tokens
  def ast_parse(self, x:Union[GPUBuffer, LazyOp], acc:List[Token], do_reduce=False) -> List[Token]:
    if not isinstance(x, LazyOp): return self.load(self.bufs.index(x))
    if isinstance(x.op, ReduceOps) and not do_reduce: return acc
    values = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
    code = CLASTKernel.code_for_op[x.op]  # TODO: replace this with a function
    if len(values) == 2:
      # TODO: sometimes this is split, sometimes it's multiply
      if isinstance(x.op, ReduceOps) and values[0][0].typ == Types.FLOAT4 and len(values[0])*4 == len(values[1]): values[0] = split_float4(values[0])
      if values[0][0].typ != values[1][0].typ:
        if isinstance(x.op, ReduceOps):
          if x.op == ReduceOps.SUM: self.prekernel.add("float clreduce(float4 x) { return x.x + x.y + x.z + x.w; }\n")
          elif x.op == ReduceOps.MAX: self.prekernel.add("float clreduce(float4 x) { return max(max(x.x, x.y), max(x.z, x.w)); }\n")
          values[1] = [Token(f"clreduce({x.tok})", Types.FLOAT) for x in values[1]]
        elif values[0][0].typ == Types.FLOAT: values[0] = group_float4(values[0])
        elif values[1][0].typ == Types.FLOAT: values[1] = group_float4(values[1])
      assert len(values[0]) == len(values[1]), f"values mismatch {values}"
      return [Token(code.replace("A", a.tok).replace("B", b.tok), a.typ) for a,b in zip(values[0], values[1])]
    else:
      return [Token(code.replace("A", a.tok), a.typ) for a in values[0]]
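  # The templates in code_for_op are expanded by string substitution, e.g.
  # BinaryOps.MUL on loaded Tokens val1_0 and val2_0 yields Token("(val1_0*val2_0)",
  # Types.FLOAT), and a ReduceOps.SUM step into accumulator acc0 yields "acc0+=val1_0".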
  def hand_coded_optimizations(self):
    # if there are images in the earlybufs, we have to make an axis the 4 loading one
    # shove the axis to the end and remove it
    if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs):
      eb_valids = [True] * self.shape_len
      for i in range(len(self.bufs)):
        if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs:
          valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)]
          eb_valids = [x and y for x,y in zip(eb_valids, valids)]
      assert any(eb_valids), f"invalid op with images {eb_valids}"
      eb_valid = eb_valids.index(True)
      if DEBUG >= 3: print(f"early merging axis {eb_valid} from {eb_valids}")
      # no change, we added a dimension
      self.reshape_and_permute(
        lambda x: list(x[0:eb_valid]) + ([x[eb_valid]//4, 4] if x[eb_valid] > 1 else [1,1]) + list(x[eb_valid+1:]),
        [i for i in range(self.shape_len+1) if i != eb_valid+1] + [eb_valid+1])
      # drop the last dimension
      self.upcast()
    # simplify (sets first_reduce)
    self.simplify_ones()
    # are we grouping?
    if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
      for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
        if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
          self.group_for_reduce.append(sz)
          break
    # TODO: this makes re_S3_32_3_3 at least 10x faster
    #if self.first_reduce == 4 and self.shape_len == 7:
    #  self.group_for_reduce.append(112//2)
    #  # TODO: this shouldn't have to be permuted
    #  self.reshape_and_permute(None, [0,1,2,3,6,4,5])
    # if there are images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape
    if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and self.buftokens[0].typ != Types.FLOAT4:
      lb_valids = [True] * self.shape_len
      for i in range(len(self.bufs)):
        valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
        lb_valids = [x and y for x,y in zip(lb_valids, valids)]
      assert any(lb_valids), f"invalid op with images {lb_valids}"
      lb_valid = lb_valids.index(True)
      assert lb_valid < self.first_reduce, f"can't be in the reduce {lb_valid}"
      if DEBUG >= 3: print(f"late merging axis {lb_valid} from {lb_valids}")
      # no change, we added a dimension
      self.reshape_and_permute(
        lambda x: list(x[0:lb_valid]) + [x[lb_valid]//4, 4] + list(x[lb_valid+1:]),
        [i for i in range(self.shape_len+1) if i != lb_valid+1] + [lb_valid+1])
      if self.group_for_reduce and self.first_reduce <= 2:
        self.upcast_in_mid_reduce = True
        self.group_for_reduce.append(4)
      else:
        # drop the last dimension
        self.upcast()
    # simplify (sets first_reduce)
    self.simplify_ones()
    # split to 4 float4s
    if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
      xb_choices = []
      for i in range(self.first_reduce):
        if all(st.shape[i]%4 == 0 for st in self.sts):
          xb_choices.append((sum(st.views[-1].strides[i]>0 for st in self.sts), sum(st.views[-1].strides[i] for st in self.sts), i))
      if len(xb_choices):
        xb_choice = sorted(xb_choices)[0][2]
        if DEBUG >= 3: print(f"float4 merging axis {xb_choice} : {xb_choices}")
        # this leaves the last axis in place
        self.reshape_and_permute(
          lambda x: list(x[0:xb_choice]) + [x[xb_choice]//4, 4] + list(x[xb_choice+1:]),
          [i for i in range(self.shape_len+1) if i != xb_choice+1] + [xb_choice+1])
        # drop the last dimension
        self.upcast()
        # re-simplify
        self.simplify_ones()
    # use more opencl indexing
    if self.first_reduce == 2 and isinstance(self.bufs[0]._buf, CLImage):
      base_shape = self.bufs[0]._base_shape
      if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]):
        if DEBUG >= 3: print("split opencl", base_shape, self.sts[0].shape)
        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
        self.simplify_ones()
    # group for reduce
    if len(self.group_for_reduce):
      # with permute for memory coalescing
      if len(self.group_for_reduce) == 2:
        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len))
      else:
        permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1))
      self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis)
    # if last dim <= 3 and it's a reduce dim, upcast (loop unrolling)
    end_dimension = max([st.shape[-1] for st in self.sts])
    if self.first_reduce < self.shape_len and end_dimension > 1 and end_dimension <= 3 and max([x.size() for i,x in enumerate(self.buftokens) if self.bufs[i] in self.earlybufs]) <= 4:
      self.upcast()
  def printbufs(self, prefix=""):
    print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
    for i in range(len(self.sts)):
      print(prefix, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'} image:{'T' if i < len(self.bufs) and isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides)
  # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD
  # group_for_reduce will have to be better first
  def codegen(self) -> Callable:
    self.process()
    self.upcast_in_mid_reduce = False
    if not KOPT: self.hand_coded_optimizations()
    # add a local buffer for multistage reduce
    if len(self.group_for_reduce):
      self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))))
      self.buftokens.append(Token("temp", Types.FLOAT, ptr=True))
    self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
    if DEBUG >= 3:
      print("output shape", self.output_shape)
      self.printbufs("new:")
    self.bufs_to_delete : Set[int] = set()
    self.loaded_keys : Dict[Tuple[int,int], Token] = {}
    self.prekernel : Set[str] = set()
    self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(isinstance(buf._buf, CLImage) for buf in self.bufs) else []
    # output_shape[-1] is get_global_id(0)
    MAX_OUTPUT_SHAPE = 3
    self.kernel += [f"int idx{len(self.output_shape)-1-i} = {CLProgram.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
    if len(self.output_shape) > MAX_OUTPUT_SHAPE:
      # sometimes there are more dimensions. compact all the extra dimensions into the first one
      # TODO: these compactions should be searchable
      final_dimension = len(self.output_shape)-MAX_OUTPUT_SHAPE
      for i in range(final_dimension-1, -1, -1):
        self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"]
      self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:])
      if DEBUG >= 3: print(f"replaced output shape with {self.output_shape}")
    # early ast
    accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
    if self.reduceop is not None:
      full_shape_candidates = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
      full_shape : Tuple[int, ...] = self.sts[0].shape if len(full_shape_candidates) == 0 else full_shape_candidates[0]
      acc_offsets = self.buftokens[self.bufs.index(self.earlybufs[0])].acc_offsets()
      assert self.reduceopop is not None
      self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceopop]};\n" for accumulator in accumulators]
      self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
      expanded_accumulators = split_float4(accumulators) if accumulators[0].typ == Types.FLOAT4 and len(accumulators)*4 == len(acc_offsets) else accumulators
      self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, [expanded_accumulators[off] for off in acc_offsets], do_reduce=True)] + ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce)))
    # middle
    if self.group_for_reduce:
      lidx, lvalid = self.sts[-1].expr_idxs()
      assert lvalid.min == 1, "local buffer must always be valid"
      self.kernel.append(f"int mid_idx = {lidx.render(render_cl)};\n")
      for i,acc in enumerate(accumulators):
        self.kernel.append(CLProgram.smem_prefix + f"{acc.decltype()} {self.buftokens[-1].tok}{i}[{prod(self.group_for_reduce)}];")
        self.kernel.append(f"{self.buftokens[-1].tok}{i}[mid_idx] = {acc.tok};\n")
      self.kernel.append(CLProgram.barrier+"\n")
      if self.upcast_in_mid_reduce:
        assert len(self.group_for_reduce) == 2
        # it should be the last dimension
        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
        self.upcast()
      assert self.reduceopop is not None
      self.kernel.append("if (mid_idx == 0) {\n")
      new_accumulators = [Token(f"output{i}", self.buftokens[0].typ) for i in range(len(accumulators))]
      for i,acc in enumerate(new_accumulators):
        self.kernel.append(f"{acc.decltype()} {acc.tok} = 0.0;")
        if self.upcast_in_mid_reduce:
          self.kernel.append(f"for (int mid = 0; mid < {prod(self.group_for_reduce)//4}; mid++) {{ {CLASTKernel.code_for_op[self.reduceopop].replace('A', acc.tok).replace('B', f'vload4(0, &temp{i}[mid*4])')}; }}\n")
        else:
          self.kernel.append(f"for (int mid = 0; mid < {prod(self.group_for_reduce)}; mid++) {{ {CLASTKernel.code_for_op[self.reduceopop].replace('A', acc.tok).replace('B', f'temp{i}[mid]')}; }}\n")
      accumulators = new_accumulators
    # late ast
    self.store(0, self.ast_parse(self.ast, accumulators))
    if self.group_for_reduce: self.kernel.append("}")
    self.kernel.append("}")
    # kernel function definition
    function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else (CLProgram.buffer_prefix+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)]
    self.kernel = list(self.prekernel) + [f"{CLProgram.kernel_prefix} void {function_name}(",] + \
      [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + (['uint3 gid [[thread_position_in_grid]]'] if METAL else []))] + \
      [") {\n"] + self.kernel
    # compile kernel
    self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs))
    if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
    return GPURunner(self.fxn, self.bufs_to_delete, self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None)
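  # For a small contiguous elementwise add, the generated source looks roughly like
  # this (assuming the OpenCL runtime, where CLProgram.kernel_prefix is "__kernel",
  # buffer_prefix is "__global" and gid[0] is "get_global_id(0)"):
  #   __kernel void ew_S16(__global float *data0, __global float *data1, __global float *data2) {
  #     int idx0 = get_global_id(0); /* 16 */
  #     float val1_0 = data1[idx0];
  #     float val2_0 = data2[idx0];
  #     data0[idx0] = (val1_0+val2_0);
  #   }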
  def print(self):
    super().print()
    for i in range(len(self.bufs)):
      print(self.buftokens[i], self.bufs[i] in self.earlybufs, self.sts[i])
    print(self.fxn.prg)
class GPUBuffer(ExplicitExecAST):
  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None, force_create=False):
    super().__init__(shape, hostbuf)
    self._buf : Optional[Union[CLImage, CLBuffer]] = hostbuf._buf if hostbuf is not None else None
    self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
    # early copy in for large buffers
    if (self._backing is not None and self._backing.shape != (1,)) or force_create:
      self.cl
  # TODO: refactor this to return self._buf and not import pyopencl
  @property
  def cl(self) -> Union[CLBuffer, CLImage]:
    if self._buf is None:
      self._buf = CLImage(self._base_shape) if (len(self._base_shape) == 3 and self._base_shape[2] == 4 and IMAGE >= 2) else CLBuffer(4*prod(self._base_shape))
    assert self._buf is not None
    if self._backing is not None:
      assert GlobalCounters.cache is None, f"can't copy in {self._backing.shape} while caching"
      self._buf.copyin(self._backing)
      self._backing = None
    return self._buf._cl
  # TODO: we don't always need a hostbuf
  def __repr__(self): return f"GPUBuffer(shape={self.st}, hostbuf=GPUBuffer(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.float32)))" if self._backing else ", force_create=True))")
  @staticmethod
  def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
  def toCPU(self) -> np.ndarray:
    cl_buf = self.contiguous()
    cl_buf.cl  # force buffer creation, happens if it's a backed buffer that hasn't been created yet
    cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, tuple(list(self.shape)+[1])).contiguous()
    assert prod(cl_buf._base_shape) == prod(self.shape), f"shape product mismatch {cl_buf._base_shape} vs {self.shape}"
    data = np.empty(self.shape, dtype=np.float32)
    assert GlobalCounters.cache is None, f"can't copy out {self} while caching"
    cl_buf._buf.copyout(data)
    return data
  @classmethod
  def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GPUBuffer]=None):
    k = CLASTKernel(ast, output_buffer)
    if KOPT:
      from extra.kernel_search import apply_optimization
      apply_optimization(k, ast, max_interventions=KOPT)
    prg = k.codegen()
    if GlobalCounters.cache is not None: GlobalCounters.cache.append((prg, k.bufs))
    prg(*k.bufs)
    if PRINT_AST == "1" or (hasattr(k, "fxn") and PRINT_AST == k.fxn.name):
      print(k.fxn.name)
      k.print()
    if TEST_AST:
      from extra.lib_test_ast import test_ast  # type: ignore
      test_ast(k)
    return k.ret
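# A rough usage sketch: build a tiny elementwise AST directly and run it through
# exec_ast. This assumes LazyOp is constructed as LazyOp(op, src), that exec_ast
# allocates the output GPUBuffer when output_buffer is None, and that an OpenCL
# (or CUDA/Metal) device is available.
if __name__ == "__main__":
  a = GPUBuffer.fromCPU(np.arange(16, dtype=np.float32).reshape(4, 4))
  b = GPUBuffer.fromCPU(np.ones((4, 4), dtype=np.float32))
  out = GPUBuffer.exec_ast(LazyOp(BinaryOps.ADD, (a, b)))
  print(out.toCPU())  # expect arange(16) + 1, back in shape (4, 4)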