Mirror of https://github.com/tinygrad/tinygrad.git
Use ShapeTracker for tracking shapes in kernels (#485)
* local is a normal buffer
* remove extra shapes and strides
* fix opt
* fix llvm
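For orientation, here is a minimal sketch (mine, not part of the commit) of the pattern this diff applies throughout: instead of keeping parallel shapes/strides/offsets lists per buffer, each kernel holds one ShapeTracker per buffer and reads shape, strides, and offset off that tracker's last View. The View and ShapeTracker classes below are simplified stand-ins for illustration, not tinygrad's real implementations.

from typing import List, NamedTuple, Tuple

# Simplified stand-ins for tinygrad's View and ShapeTracker, for illustration only.
class View(NamedTuple):
  shape: Tuple[int, ...]
  strides: Tuple[int, ...]
  offset: int = 0

class ShapeTracker:
  def __init__(self, view: View):
    self.views: List[View] = [view]
  @property
  def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
  @property
  def offset(self) -> int: return self.views[-1].offset
  def copy(self) -> "ShapeTracker":
    new = ShapeTracker(self.views[-1])
    new.views = self.views[:]
    return new

# Before: three parallel lists that had to be kept in sync by hand.
shapes, strides, offsets = [(4, 4), (4, 1)], [(4, 1), (1, 0)], [0, 0]

# After: one tracker per buffer; everything is derived from its last view.
sts = [ShapeTracker(View((4, 4), (4, 1))), ShapeTracker(View((4, 1), (1, 0)))]
assert [st.shape for st in sts] == shapes                  # replaces k.shapes
assert [st.views[-1].strides for st in sts] == strides     # replaces k.strides
assert [st.offset for st in sts] == offsets                # replaces k.offsets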
@@ -192,13 +192,10 @@ class LLVMBuffer(ExplicitExecAST):
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
return k.ret

# cache miss, we have to process the kernel
k.process()

if DEBUG >= 2:
print(k.ast)
print("old:", k.shapes)
print("old:", k.strides)
print("old:", [x.shape for x in k.sts])
print("old:", [x.views[-1].strides for x in k.sts])

# this stuff can't be hand coded
kernel_output_axis : List[int] = []

@@ -242,12 +239,12 @@ class LLVMBuffer(ExplicitExecAST):
"""

# the 4x4 need to go all the way at the end, even after reduce
output_shape = k.shapes[0]
full_shape = [x for x in k.shapes if x != output_shape]
full_shape = output_shape if len(full_shape) == 0 else full_shape[0]
output_shape = k.sts[0].shape
full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]

full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
kernel_output_dim = prod([k.shapes[0][a] for a in kernel_output_axis])
kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)

def get_idxs(builder, idx, buf_index):

@@ -279,13 +276,13 @@ class LLVMBuffer(ExplicitExecAST):
loop_exit = loop_exit[::-1]

# add the buffer indexing
idx_level = [[int_const(o)] for o in k.offsets]
idx_level = [[int_const(st.offset)] for st in k.sts]
for i in range(len(full_shape)):
for j in range(len(k.bufs)):
# stride
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
si.add_incoming(si_ps, loop_exit[i+1]._block)
idx_level[j].append(si)
@@ -3,7 +3,7 @@ import itertools
from typing import List, Tuple
from tinygrad.helpers import prod, dedup, all_same
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops
from tinygrad.shape import ShapeTracker
from tinygrad.shape import ShapeTracker, View

def get_first_reduce(shapes):
for i in range(len(shapes[0])):

@@ -50,6 +50,7 @@ class ASTKernel:
if hasattr(self.ret, "cl"): self.ret.cl # does the allocation of unbacked buffer, pylint: disable=W0104
self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
self.buftokens = [Token(f"data{i}", Types.FLOAT, ptr=True) for i in range(len(self.bufs))]
self.group_for_reduce : List[int] = []

# check valid AST kernel
assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"

@@ -57,9 +58,9 @@ class ASTKernel:
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"

# process
# TODO: fetch from quick cache before processing
self.process()
self.group_for_reduce : List[int] = []
self.sts : List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel
self.simplify_ones()
self.simplify_merge_adjacent()

def print(self):
buf_count = -1

@@ -84,30 +85,21 @@ class ASTKernel:
return cache[x]
print_ast(self.input_ast, "ast")

def process(self):
# get shape, strides, and offset
# if it's a multiview buffer we take the final view
self.shapes = [x.shape for x in self.bufs]
self.strides = [x.st.views[-1].strides for x in self.bufs]
self.offsets = [x.st.views[-1].offset for x in self.bufs] # include the offsets (as is)
self.simplify_ones()
self.simplify_merge_adjacent()
@property
def shape_len(self): return len(self.sts[0].shape)

def simplify_ones(self):
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
all_ones = [all(s[i]==1 for s in self.shapes) for i in range(len(self.shapes[0]))]
all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
# keep at least 1 one
if all(all_ones):
all_ones[-1] = False
self.shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.shapes]
self.strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.strides]
if all(all_ones): all_ones[-1] = False
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
# find first mismatch, don't reduce this
self.first_reduce = get_first_reduce(self.shapes)
self.first_reduce = get_first_reduce([x.shape for x in self.sts])

def simplify_merge_adjacent(self):
shapes, strides = self.shapes, self.strides
shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]

# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO

@@ -125,45 +117,35 @@ class ASTKernel:
rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else:
rets[j].append((shapes[j][i], strides[j][i]))
self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
self.first_reduce = get_first_reduce(self.shapes)

@property
def shape_len(self): return len(self.shapes[0])
for i,x in enumerate(rets): self.sts[i].reshape(*[y[0] for y in x])
self.first_reduce = get_first_reduce([x.shape for x in self.sts])

# this should be aware of the three parts to the shape
# * the input/output dimensions
# * the reduce dimensions
# * the size outputted by each kernel
def reshape_and_permute(self, new_shape_fxn, axis):
new_shapes, new_strides = [], []
for shape, stride in zip(self.shapes, self.strides):
st = ShapeTracker(tuple(shape))
st.strided(*zip(shape, stride))
# TODO: handle reduced shape here
if new_shape_fxn is not None: st.reshape(*new_shape_fxn(shape))
for st in self.sts:
if new_shape_fxn is not None: st.reshape(*new_shape_fxn(st.shape))
if axis is not None: st.permute(*axis)
assert len(st.views) == 1
new_shapes.append(st.shape)
new_strides.append(st.strides)
self.shapes, self.strides = new_shapes, new_strides

# drops the final dimension
def upcast(self):
upcasted = [x[-1] for x in self.shapes if x[-1] != 1]
upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1]
assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}"
for i in range(len(self.bufs)):
if self.shapes[i][-1] == upcasted[0]:
st = self.sts[i]
if st.shape[-1] == upcasted[0]:
# multiview shapetrackers can slice through a float4, so don't allow them
can_merge = (not self.bufs[i].st.needs_valid() and len(self.bufs[i].st.views) == 1) or "Image" in str(type(self.bufs[i]._buf)) # TODO: terrible hack
if self.shapes[i][-1] == 4 and self.buftokens[i].typ == Types.FLOAT and self.strides[i][-1] == 1 and can_merge:
can_merge = (not st.needs_valid() and len(st.views) == 1) or "Image" in str(type(self.bufs[i]._buf)) # TODO: terrible hack
if st.shape[-1] == 4 and self.buftokens[i].typ == Types.FLOAT and st.views[-1].strides[-1] == 1 and can_merge:
# this is an upcast to FLOAT4
self.buftokens[i].typ = Types.FLOAT4
assert all(x%upcasted[0] == 0 for x in self.strides[i][0:-1])
assert self.offsets[i]%upcasted[0] == 0
assert all(st.views[-1].strides[i]%upcasted[0] == 0 or st.views[-1].shape[i] == 1 for i in range(len(st.shape)-1))
assert self.sts[i].offset % upcasted[0] == 0
else:
self.buftokens[i].array(upcasted[0], self.strides[i][-1])
self.buftokens[i].array(upcasted[0], st.views[-1].strides[-1])

# remove the last dimension
self.shapes = [x[:-1] for x in self.shapes]
self.strides = [x[:-1] for x in self.strides]
for st in self.sts: st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset)
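Aside (mine, not from the diff): the FLOAT4 upcast branch above boils down to an alignment check on the buffer's last view, where upcasted[0] is 4 in that branch. A hypothetical helper restating it on plain shape/stride/offset tuples:

from typing import Tuple

# Hypothetical restatement of the FLOAT4 upcast test: last dim of size 4 with
# stride 1, offset aligned to 4, and every other stride a multiple of 4 unless
# that dimension is broadcast (size 1).
def can_upcast_to_float4(shape: Tuple[int, ...], strides: Tuple[int, ...], offset: int) -> bool:
  if shape[-1] != 4 or strides[-1] != 1 or offset % 4 != 0: return False
  return all(strides[i] % 4 == 0 or shape[i] == 1 for i in range(len(shape) - 1))

assert can_upcast_to_float4((8, 4), (4, 1), 0)
assert not can_upcast_to_float4((8, 4), (4, 2), 0)  # a strided inner dim would split the float4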
@@ -6,7 +6,7 @@ from tinygrad.helpers import prod
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
from tinygrad.ast import ASTKernel, Token, Types
from tinygrad.lazy import IMAGE
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.shape.symbolic import Variable, ModNode

CUDA = int(os.getenv("CUDA", "0"))

@@ -39,9 +39,8 @@ class CLASTKernel(ASTKernel):
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}

# TODO: move to shapetracker
def compute_buf_index_symbolic(self, st, buf_index, offset=0):
view = View(self.shapes[buf_index], self.strides[buf_index], self.offsets[buf_index] + offset)
idx = view.expr_idxs([f"idx{i}" for i in range(self.shape_len)])
def compute_buf_index_symbolic(self, st, offset=0):
idx = st.views[-1].expr_idxs([f"idx{i}" for i in range(self.shape_len)], offset)
valid = Variable.num(1)
for v in st.views[0:-1][::-1]:
if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)

@@ -62,7 +61,7 @@ class CLASTKernel(ASTKernel):
if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
for v, o in zip(value, self.buftokens[buf_index].offsets()):
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
assert str(valid) == "1", "store must always be valid"
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
if isinstance(self.bufs[buf_index]._buf, CLImage):

@@ -80,7 +79,7 @@ class CLASTKernel(ASTKernel):
const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
if self.bufs[buf_index].st.needs_valid():
for o in self.buftokens[buf_index].offsets():
_, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
_, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0f)", const.typ) if str(valid) != "1" else const)
return tokens
else:

@@ -89,7 +88,7 @@ class CLASTKernel(ASTKernel):
# not constant folded
for o in self.buftokens[buf_index].offsets():
if (buf_index, o) not in self.loaded_keys:
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
if isinstance(self.bufs[buf_index]._buf, CLImage):
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
else:

@@ -125,10 +124,10 @@ class CLASTKernel(ASTKernel):
# if there's images in the earlybufs, we have to make an axis the 4 loading one
# shove the axis to the end and remove
if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs):
eb_valids = [True] * len(self.shapes[0])
eb_valids = [True] * self.shape_len
for i in range(len(self.bufs)):
if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs:
valids = [self.shapes[i][j]%4 == 0 and self.strides[i][j] == 1 for j in range(len(self.shapes[i]))]
valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)]
eb_valids = [x and y for x,y in zip(eb_valids, valids)]
assert any(eb_valids), f"invalid op with images {eb_valids}"
eb_valid = eb_valids.index(True)

@@ -146,9 +145,9 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()

# are we grouping?
if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.shapes[0][:self.first_reduce]) <= 2048:
for sz in ([256, 16] if prod(self.shapes[0][:self.first_reduce]) <= 32 else [16]):
if all([x[self.first_reduce] % sz == 0 or x[self.first_reduce] == 1 for x in self.shapes]):
if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
self.group_for_reduce.append(sz)
break

@@ -161,9 +160,9 @@ class CLASTKernel(ASTKernel):
# if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape
self.upcast_in_mid_reduce = False
if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and self.buftokens[0].typ != Types.FLOAT4:
lb_valids = [True] * len(self.shapes[0])
lb_valids = [True] * self.shape_len
for i in range(len(self.bufs)):
valids = [self.shapes[i][j]%4 == 0 and (self.strides[i][j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(len(self.shapes[i]))]
valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
lb_valids = [x and y for x,y in zip(lb_valids, valids)]
assert any(lb_valids), f"invalid op with images {lb_valids}"
lb_valid = lb_valids.index(True)

@@ -186,11 +185,11 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()

# split to 4 float4s
if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.shapes[0][:self.first_reduce]) >= 2048 and not self.group_for_reduce:
if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
xb_choices = []
for i in range(self.first_reduce):
if all(x[i]%4 == 0 for x in self.shapes):
xb_choices.append((sum(x[i]>0 for x in self.strides), sum(x[i] for x in self.strides), i))
if all(st.shape[i]%4 == 0 for st in self.sts):
xb_choices.append((sum(st.views[-1].strides[i]>0 for st in self.sts), sum(st.views[-1].strides[i] for st in self.sts), i))

if len(xb_choices):
xb_choice = sorted(xb_choices)[0][2]

@@ -210,7 +209,7 @@ class CLASTKernel(ASTKernel):
# use more opencl indexing
if self.first_reduce == 2 and isinstance(self.bufs[0]._buf, CLImage):
base_shape = self.bufs[0]._base_shape
if all([(base_shape[0]*base_shape[1])%x[0] == 0 and x[0]//base_shape[0] != 0 for x in self.shapes]):
if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]):
if DEBUG >= 3: print("split opencl", base_shape, self.shapes[0])
self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
self.simplify_ones()

@@ -228,17 +227,24 @@ class CLASTKernel(ASTKernel):
# group_for_reduce will have to be better first
def codegen(self):
if DEBUG >= 3:
print("old:", self.shapes)
print("old:", self.strides)

print("old:", [x.shape for x in self.sts])
print("old:", [x.views[-1].strides for x in self.sts])

if not CUDA: self.hand_coded_optimizations()

self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
local_buffer = GPUBuffer([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))
self.bufs.append(local_buffer)
self.sts.append(local_buffer.st.copy())
self.buftokens.append(Token("temp", Types.FLOAT, ptr=True))

self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
if DEBUG >= 3:
print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
print("output shape", self.output_shape)
for i in range(len(self.bufs)):
print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.shapes[i], self.strides[i])
print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.sts[i])

self.bufs_to_delete : Set[int] = set()
self.loaded_keys : Dict[Tuple[int,int], Token] = {}

@@ -261,8 +267,8 @@ class CLASTKernel(ASTKernel):
# early ast
accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
if self.reduceop:
full_shape = [x for x in self.shapes if x != self.shapes[0]]
full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
full_shape = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]

self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]

@@ -270,17 +276,17 @@ class CLASTKernel(ASTKernel):

# middle
if self.group_for_reduce:
self.kernel.append(f"__local {accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")
lidx, lvalid = self.compute_buf_index_symbolic(local_buffer.st)
assert str(lvalid) == "1", "local buffer must be valid"

self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")
self.kernel.append(f"int mid_idx = {lidx.cl}; {self.buftokens[-1].tok}[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")

if self.upcast_in_mid_reduce:
assert len(self.group_for_reduce) == 2
# it should be the last dimension
self.kernel.append(f"int mid_idx = idx{self.first_reduce}*{self.group_for_reduce[1]} + idx{self.first_reduce+1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
self.upcast()
else:
assert len(self.group_for_reduce) == 1
self.kernel.append(f"int mid_idx = idx{self.first_reduce}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")

self.kernel.append("if (mid_idx == 0) {\n")
accumulators = [Token("output", self.buftokens[0].typ)]
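A worked example (mine, not in the diff) of why the symbolic local-buffer index can replace the hand-written mid_idx strings above: take first_reduce = 2 and group_for_reduce = [16, 16], so the local buffer's shape is (1, 1, 16, 16) padded with trailing 1s, and assume contiguous strides. Dimensions of size 1 drop out of the index expression, so the index over the local buffer's view reduces to idx2*16 + idx3, which is exactly what the old code formatted by hand; in the new code the expression comes from compute_buf_index_symbolic(local_buffer.st).

# Assumed shape/strides for illustration: a contiguous local buffer for
# first_reduce=2, group_for_reduce=[16, 16], one trailing reduce dim of size 1.
shape   = (1, 1, 16, 16, 1)
strides = (256, 256, 16, 1, 1)

# Size-1 (and stride-0) dimensions contribute nothing to the linear index.
def flat(idxs):
  return sum(i * st for i, (sh, st) in zip(idxs, zip(shape, strides)) if sh != 1 and st != 0)

idx2, idx3 = 3, 7
assert flat((0, 0, idx2, idx3, 0)) == idx2 * 16 + idx3   # matches "idx2*16 + idx3"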
@@ -304,7 +310,7 @@ class CLASTKernel(ASTKernel):

# compile kernel
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops)
mem_estimate = sum(prod(x) for x in self.shapes)
mem_estimate = sum(prod(x.shape) for x in self.sts)

if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
def runner(*bufs):
@@ -45,8 +45,8 @@ class View:
return 'idx=' + str(self.expr_node(Variable('idx', 0, prod([x[0] for x in self.shape_strides])-1)))

# generate an expression if you have a variable or expression for each index
def expr_idxs(self, idxs):
return Variable.sum([Variable.num(self.offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
def expr_idxs(self, idxs, offset=0):
return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])

class ZeroView:
def __init__(self, old_shape:Tuple[int, ...], arg):
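The expr_idxs change above threads an extra offset straight into the symbolic sum instead of building a throwaway View, which is what the old compute_buf_index_symbolic did. A numeric restatement of the expression it builds, using plain ints where the real code uses symbolic Variables; flat_index is a hypothetical name for illustration:

from typing import Sequence, Tuple

# index = (view offset + extra offset) + sum(idx[i] * stride[i]),
# skipping dimensions of size 1 or stride 0, mirroring expr_idxs above.
def flat_index(idxs: Sequence[int], shape: Tuple[int, ...], strides: Tuple[int, ...],
               view_offset: int, offset: int = 0) -> int:
  return view_offset + offset + sum(
    idx * st for idx, (sh, st) in zip(idxs, zip(shape, strides)) if sh != 1 and st != 0)

# A contiguous 3x4 view: element (2, 1) sits at 2*4 + 1*1 = 9; the offset shifts it.
assert flat_index((2, 1), (3, 4), (4, 1), view_offset=0) == 9
assert flat_index((2, 1), (3, 4), (4, 1), view_offset=0, offset=2) == 11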
@@ -95,6 +95,7 @@ class ShapeTracker:
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
def copy(self): return ShapeTracker(self.shape, self.views[:])

@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
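Finally, the new ShapeTracker.copy is what lets each kernel own its trackers: self.sts = [x.st.copy() for x in self.bufs] earlier in the diff gives the kernel lists it can reshape and permute without mutating the buffers' own trackers. A small illustration with a stand-in class (not tinygrad's); copying only the views list appears safe here because the kernel code replaces views rather than mutating them in place, as the upcast above does by building a new View.

from typing import List, Tuple

# Stand-in tracker: copy() duplicates the views *list* but shares the View
# objects, which is fine as long as Views are never modified in place.
class TrackerSketch:
  def __init__(self, views: List[Tuple]): self.views = views
  def copy(self) -> "TrackerSketch": return TrackerSketch(self.views[:])

buf_st = TrackerSketch([("base view",)])
kernel_st = buf_st.copy()                   # what x.st.copy() does per buffer
kernel_st.views.append(("kernel-local view",))
assert buf_st.views == [("base view",)]     # the buffer's tracker is untouched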