Use ShapeTracker for tracking shapes in kernels (#485)

* local is a normal buffer

* remove extra shapes and strides

* fix opt

* fix llvm
George Hotz (committed by GitHub), 2023-01-28 11:56:32 -08:00
parent 259c48f235, commit b3e4e678e8
4 changed files with 73 additions and 87 deletions
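The core of the change: the kernel no longer keeps parallel `shapes`, `strides`, and `offsets` lists that can drift apart; each buffer instead gets its own copied `ShapeTracker` in `self.sts`, and everything downstream queries that. A minimal, self-contained sketch of the idea, using simplified stand-ins rather than tinygrad's real `View`/`ShapeTracker` classes:

```python
from typing import List, NamedTuple, Tuple

class View(NamedTuple):                      # simplified stand-in, not tinygrad.shape.View
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]
    offset: int = 0

class ShapeTracker:                          # simplified stand-in, not tinygrad.shape.ShapeTracker
    def __init__(self, views: List[View]): self.views = views
    def copy(self) -> "ShapeTracker": return ShapeTracker(self.views[:])
    @property
    def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
    @property
    def offset(self) -> int: return self.views[-1].offset

# one tracker per buffer, copied so the kernel can reshape/permute without touching the buffer
buf_trackers = [ShapeTracker([View((4, 8), (8, 1))]), ShapeTracker([View((4, 8), (1, 4))])]
sts = [st.copy() for st in buf_trackers]

# the old self.shapes / self.strides / self.offsets become queries against self.sts
assert [st.shape for st in sts] == [(4, 8), (4, 8)]
assert [st.views[-1].strides for st in sts] == [(8, 1), (1, 4)]
assert [st.offset for st in sts] == [0, 0]
```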

Changed file 1 of 4: the LLVM backend (class LLVMBuffer)

@@ -192,13 +192,10 @@ class LLVMBuffer(ExplicitExecAST):
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
return k.ret
# cache miss, we have to process the kernel
k.process()
if DEBUG >= 2:
print(k.ast)
print("old:", k.shapes)
print("old:", k.strides)
print("old:", [x.shape for x in k.sts])
print("old:", [x.views[-1].strides for x in k.sts])
# this stuff can't be hand coded
kernel_output_axis : List[int] = []
@@ -242,12 +239,12 @@ class LLVMBuffer(ExplicitExecAST):
"""
# the 4x4 need to go all the way at the end, even after reduce
output_shape = k.shapes[0]
full_shape = [x for x in k.shapes if x != output_shape]
full_shape = output_shape if len(full_shape) == 0 else full_shape[0]
output_shape = k.sts[0].shape
full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]
full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
kernel_output_dim = prod([k.shapes[0][a] for a in kernel_output_axis])
kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)
def get_idxs(builder, idx, buf_index):
@@ -279,13 +276,13 @@ class LLVMBuffer(ExplicitExecAST):
loop_exit = loop_exit[::-1]
# add the buffer indexing
idx_level = [[int_const(o)] for o in k.offsets]
idx_level = [[int_const(st.offset)] for st in k.sts]
for i in range(len(full_shape)):
for j in range(len(k.bufs)):
# stride
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
si.add_incoming(si_ps, loop_exit[i+1]._block)
idx_level[j].append(si)
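The LLVM lowering now takes the starting offset and the per-axis strides from `k.sts[j].views[-1]` when building the loop-carried index phis. A sketch of the arithmetic that `si`/`si_ps` phi chain accumulates, written as plain Python rather than IR-builder calls:

```python
from typing import Sequence

def flat_index(idxs: Sequence[int], strides: Sequence[int], offset: int) -> int:
    """Flat buffer position for a multi-dimensional index under a strided view.

    The phi chain in the LLVM backend builds this incrementally: start at the
    view's offset, then add stride*idx one loop level at a time.
    """
    pos = offset
    for i, s in zip(idxs, strides):
        pos += i * s
    return pos

# a (4, 8) buffer viewed with strides (1, 4) (i.e. transposed) and offset 0
assert flat_index((2, 3), (1, 4), 0) == 14   # 2*1 + 3*4
```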

Changed file 2 of 4: the kernel AST machinery, tinygrad.ast (class ASTKernel)

@@ -3,7 +3,7 @@ import itertools
from typing import List, Tuple
from tinygrad.helpers import prod, dedup, all_same
from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops
from tinygrad.shape import ShapeTracker
from tinygrad.shape import ShapeTracker, View
def get_first_reduce(shapes):
for i in range(len(shapes[0])):
@@ -50,6 +50,7 @@ class ASTKernel:
if hasattr(self.ret, "cl"): self.ret.cl # does the allocation of unbacked buffer, pylint: disable=W0104
self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
self.buftokens = [Token(f"data{i}", Types.FLOAT, ptr=True) for i in range(len(self.bufs))]
self.group_for_reduce : List[int] = []
# check valid AST kernel
assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
@@ -57,9 +58,9 @@ class ASTKernel:
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
# process
# TODO: fetch from quick cache before processing
self.process()
self.group_for_reduce : List[int] = []
self.sts : List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel
self.simplify_ones()
self.simplify_merge_adjacent()
def print(self):
buf_count = -1
@@ -84,30 +85,21 @@ class ASTKernel:
return cache[x]
print_ast(self.input_ast, "ast")
def process(self):
# get shape, strides, and offset
# if it's a multiview buffer we take the final view
self.shapes = [x.shape for x in self.bufs]
self.strides = [x.st.views[-1].strides for x in self.bufs]
self.offsets = [x.st.views[-1].offset for x in self.bufs] # include the offsets (as is)
self.simplify_ones()
self.simplify_merge_adjacent()
@property
def shape_len(self): return len(self.sts[0].shape)
def simplify_ones(self):
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
all_ones = [all(s[i]==1 for s in self.shapes) for i in range(len(self.shapes[0]))]
all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
# keep at least 1 one
if all(all_ones):
all_ones[-1] = False
self.shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.shapes]
self.strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.strides]
if all(all_ones): all_ones[-1] = False
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
# find first mismatch, don't reduce this
self.first_reduce = get_first_reduce(self.shapes)
self.first_reduce = get_first_reduce([x.shape for x in self.sts])
def simplify_merge_adjacent(self):
shapes, strides = self.shapes, self.strides
shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]
# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO
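`simplify_ones` now expresses the same transformation as a reshape handed to `reshape_and_permute`, which applies it to every `ShapeTracker` in `self.sts`. A sketch of the dimension-dropping rule itself, on plain shape tuples rather than trackers:

```python
from typing import List, Tuple

def drop_all_ones(shapes: List[Tuple[int, ...]]) -> List[Tuple[int, ...]]:
    """Drop axes that are 1 in *every* buffer's shape, keeping at least one axis."""
    ndim = len(shapes[0])
    all_ones = [all(s[i] == 1 for s in shapes) for i in range(ndim)]
    if all(all_ones): all_ones[-1] = False   # keep at least one dimension
    return [tuple(x for i, x in enumerate(s) if not all_ones[i]) for s in shapes]

# axis 1 is 1 everywhere, so it is removed; axis 2 is 1 only in the second shape, so it stays
assert drop_all_ones([(4, 1, 8), (4, 1, 1)]) == [(4, 8), (4, 1)]
```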
@@ -125,45 +117,35 @@ class ASTKernel:
rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else:
rets[j].append((shapes[j][i], strides[j][i]))
self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
self.first_reduce = get_first_reduce(self.shapes)
@property
def shape_len(self): return len(self.shapes[0])
for i,x in enumerate(rets): self.sts[i].reshape(*[y[0] for y in x])
self.first_reduce = get_first_reduce([x.shape for x in self.sts])
# this should be aware of the three parts to the shape
# * the input/output dimensions
# * the reduce dimensions
# * the size outputted by each kernel
def reshape_and_permute(self, new_shape_fxn, axis):
new_shapes, new_strides = [], []
for shape, stride in zip(self.shapes, self.strides):
st = ShapeTracker(tuple(shape))
st.strided(*zip(shape, stride))
# TODO: handle reduced shape here
if new_shape_fxn is not None: st.reshape(*new_shape_fxn(shape))
for st in self.sts:
if new_shape_fxn is not None: st.reshape(*new_shape_fxn(st.shape))
if axis is not None: st.permute(*axis)
assert len(st.views) == 1
new_shapes.append(st.shape)
new_strides.append(st.strides)
self.shapes, self.strides = new_shapes, new_strides
# drops the final dimension
def upcast(self):
upcasted = [x[-1] for x in self.shapes if x[-1] != 1]
upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1]
assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}"
for i in range(len(self.bufs)):
if self.shapes[i][-1] == upcasted[0]:
st = self.sts[i]
if st.shape[-1] == upcasted[0]:
# multiview shapetrackers can slice through a float4, so don't allow them
can_merge = (not self.bufs[i].st.needs_valid() and len(self.bufs[i].st.views) == 1) or "Image" in str(type(self.bufs[i]._buf)) # TODO: terrible hack
if self.shapes[i][-1] == 4 and self.buftokens[i].typ == Types.FLOAT and self.strides[i][-1] == 1 and can_merge:
can_merge = (not st.needs_valid() and len(st.views) == 1) or "Image" in str(type(self.bufs[i]._buf)) # TODO: terrible hack
if st.shape[-1] == 4 and self.buftokens[i].typ == Types.FLOAT and st.views[-1].strides[-1] == 1 and can_merge:
# this is an upcast to FLOAT4
self.buftokens[i].typ = Types.FLOAT4
assert all(x%upcasted[0] == 0 for x in self.strides[i][0:-1])
assert self.offsets[i]%upcasted[0] == 0
assert all(st.views[-1].strides[i]%upcasted[0] == 0 or st.views[-1].shape[i] == 1 for i in range(len(st.shape)-1))
assert self.sts[i].offset % upcasted[0] == 0
else:
self.buftokens[i].array(upcasted[0], self.strides[i][-1])
self.buftokens[i].array(upcasted[0], st.views[-1].strides[-1])
# remove the last dimension
self.shapes = [x[:-1] for x in self.shapes]
self.strides = [x[:-1] for x in self.strides]
for st in self.sts: st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset)
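The upcast path now reads the trailing dimension's size and stride from each buffer's `views[-1]` and the offset from the tracker itself. A condensed sketch of the FLOAT4-vs-array decision on plain tuples; `can_merge` stands in for the single-view/needs_valid/Image check in the diff, and the diff's asserts are folded into the decision here for brevity:

```python
from typing import Tuple

def upcast_kind(shape: Tuple[int, ...], strides: Tuple[int, ...],
                offset: int, can_merge: bool) -> str:
    """Condensed version of the FLOAT4 decision in ASTKernel.upcast.

    'float4' -> the trailing dimension of 4 is contiguous and safe to load as one float4
    'array'  -> the trailing dimension is unrolled into separate scalar accesses
    """
    last = shape[-1]
    aligned = (all(s % last == 0 or sh == 1 for sh, s in zip(shape[:-1], strides[:-1]))
               and offset % last == 0)       # kept as asserts in the actual diff
    if last == 4 and strides[-1] == 1 and can_merge and aligned:
        return "float4"
    return "array"

assert upcast_kind((32, 4), (4, 1), 0, can_merge=True) == "float4"
assert upcast_kind((32, 4), (1, 32), 0, can_merge=True) == "array"   # last axis not contiguous
```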

Changed file 3 of 4: the GPU/OpenCL codegen (class CLASTKernel)

@@ -6,7 +6,7 @@ from tinygrad.helpers import prod
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
from tinygrad.ast import ASTKernel, Token, Types
from tinygrad.lazy import IMAGE
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.shape.symbolic import Variable, ModNode
CUDA = int(os.getenv("CUDA", "0"))
@@ -39,9 +39,8 @@ class CLASTKernel(ASTKernel):
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
# TODO: move to shapetracker
def compute_buf_index_symbolic(self, st, buf_index, offset=0):
view = View(self.shapes[buf_index], self.strides[buf_index], self.offsets[buf_index] + offset)
idx = view.expr_idxs([f"idx{i}" for i in range(self.shape_len)])
def compute_buf_index_symbolic(self, st, offset=0):
idx = st.views[-1].expr_idxs([f"idx{i}" for i in range(self.shape_len)], offset)
valid = Variable.num(1)
for v in st.views[0:-1][::-1]:
if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
@@ -62,7 +61,7 @@ class CLASTKernel(ASTKernel):
if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
for v, o in zip(value, self.buftokens[buf_index].offsets()):
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
assert str(valid) == "1", "store must always be valid"
assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
if isinstance(self.bufs[buf_index]._buf, CLImage):
@@ -80,7 +79,7 @@ class CLASTKernel(ASTKernel):
const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
if self.bufs[buf_index].st.needs_valid():
for o in self.buftokens[buf_index].offsets():
_, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
_, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0f)", const.typ) if str(valid) != "1" else const)
return tokens
else:
@@ -89,7 +88,7 @@ class CLASTKernel(ASTKernel):
# not constant folded
for o in self.buftokens[buf_index].offsets():
if (buf_index, o) not in self.loaded_keys:
idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
if isinstance(self.bufs[buf_index]._buf, CLImage):
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
else:
@@ -125,10 +124,10 @@ class CLASTKernel(ASTKernel):
# if there's images in the earlybufs, we have to make an axis the 4 loading one
# shove the axis to the end and remove
if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs):
eb_valids = [True] * len(self.shapes[0])
eb_valids = [True] * self.shape_len
for i in range(len(self.bufs)):
if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs:
valids = [self.shapes[i][j]%4 == 0 and self.strides[i][j] == 1 for j in range(len(self.shapes[i]))]
valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)]
eb_valids = [x and y for x,y in zip(eb_valids, valids)]
assert any(eb_valids), f"invalid op with images {eb_valids}"
eb_valid = eb_valids.index(True)
@@ -146,9 +145,9 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()
# are we grouping?
if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.shapes[0][:self.first_reduce]) <= 2048:
for sz in ([256, 16] if prod(self.shapes[0][:self.first_reduce]) <= 32 else [16]):
if all([x[self.first_reduce] % sz == 0 or x[self.first_reduce] == 1 for x in self.shapes]):
if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
self.group_for_reduce.append(sz)
break
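The grouping heuristic now reads shapes off the trackers: group only when the pre-reduce output is small enough, and pick a workgroup size that divides (or is 1 in) every buffer's first reduce dimension. A sketch of just that selection, using the same candidate sizes as the diff (the FLOAT4 and `first_reduce <= 2` guards are omitted):

```python
from math import prod
from typing import List, Optional, Tuple

def pick_group_size(shapes: List[Tuple[int, ...]], first_reduce: int) -> Optional[int]:
    """Pick a local workgroup size for a multistage reduce, or None if we don't group."""
    out_elems = prod(shapes[0][:first_reduce])
    if out_elems > 2048: return None
    for sz in ([256, 16] if out_elems <= 32 else [16]):
        if all(s[first_reduce] % sz == 0 or s[first_reduce] == 1 for s in shapes):
            return sz
    return None

# output is 1x1 before the reduce (a full reduction), reduce dim is 4096 -> group by 256
assert pick_group_size([(1, 1, 1), (1, 1, 4096)], first_reduce=2) == 256
```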
@@ -161,9 +160,9 @@ class CLASTKernel(ASTKernel):
# if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape
self.upcast_in_mid_reduce = False
if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and self.buftokens[0].typ != Types.FLOAT4:
lb_valids = [True] * len(self.shapes[0])
lb_valids = [True] * self.shape_len
for i in range(len(self.bufs)):
valids = [self.shapes[i][j]%4 == 0 and (self.strides[i][j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(len(self.shapes[i]))]
valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
lb_valids = [x and y for x,y in zip(lb_valids, valids)]
assert any(lb_valids), f"invalid op with images {lb_valids}"
lb_valid = lb_valids.index(True)
@@ -186,11 +185,11 @@ class CLASTKernel(ASTKernel):
self.simplify_ones()
# split to 4 float4s
if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.shapes[0][:self.first_reduce]) >= 2048 and not self.group_for_reduce:
if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
xb_choices = []
for i in range(self.first_reduce):
if all(x[i]%4 == 0 for x in self.shapes):
xb_choices.append((sum(x[i]>0 for x in self.strides), sum(x[i] for x in self.strides), i))
if all(st.shape[i]%4 == 0 for st in self.sts):
xb_choices.append((sum(st.views[-1].strides[i]>0 for st in self.sts), sum(st.views[-1].strides[i] for st in self.sts), i))
if len(xb_choices):
xb_choice = sorted(xb_choices)[0][2]
@@ -210,7 +209,7 @@ class CLASTKernel(ASTKernel):
# use more opencl indexing
if self.first_reduce == 2 and isinstance(self.bufs[0]._buf, CLImage):
base_shape = self.bufs[0]._base_shape
if all([(base_shape[0]*base_shape[1])%x[0] == 0 and x[0]//base_shape[0] != 0 for x in self.shapes]):
if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]):
if DEBUG >= 3: print("split opencl", base_shape, self.shapes[0])
self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
self.simplify_ones()
@@ -228,17 +227,24 @@ class CLASTKernel(ASTKernel):
# group_for_reduce will have to be better first
def codegen(self):
if DEBUG >= 3:
print("old:", self.shapes)
print("old:", self.strides)
print("old:", [x.shape for x in self.sts])
print("old:", [x.views[-1].strides for x in self.sts])
if not CUDA: self.hand_coded_optimizations()
self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
local_buffer = GPUBuffer([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))
self.bufs.append(local_buffer)
self.sts.append(local_buffer.st.copy())
self.buftokens.append(Token("temp", Types.FLOAT, ptr=True))
self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
if DEBUG >= 3:
print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
print("output shape", self.output_shape)
for i in range(len(self.bufs)):
print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.shapes[i], self.strides[i])
print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.sts[i])
self.bufs_to_delete : Set[int] = set()
self.loaded_keys : Dict[Tuple[int,int], Token] = {}
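With grouping enabled, codegen now appends a real `GPUBuffer`, its `ShapeTracker`, and a `Token` for the local scratch memory instead of hard-coding a bare `temp[]` array, so later stages can index it like any other buffer. Its shape is 1 everywhere except the grouped reduce axes; a small sketch of that shape computation:

```python
from typing import List

def local_buffer_shape(first_reduce: int, group_for_reduce: List[int], shape_len: int) -> List[int]:
    """Shape of the __local scratch buffer used by the second reduce stage."""
    return [1] * first_reduce + group_for_reduce + [1] * (shape_len - len(group_for_reduce) - first_reduce)

# e.g. 2 output dims, one grouped reduce axis of 16, 4 axes total -> [1, 1, 16, 1]
assert local_buffer_shape(first_reduce=2, group_for_reduce=[16], shape_len=4) == [1, 1, 16, 1]
```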
@@ -261,8 +267,8 @@ class CLASTKernel(ASTKernel):
# early ast
accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
if self.reduceop:
full_shape = [x for x in self.shapes if x != self.shapes[0]]
full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
full_shape = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
@@ -270,17 +276,17 @@ class CLASTKernel(ASTKernel):
# middle
if self.group_for_reduce:
self.kernel.append(f"__local {accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")
lidx, lvalid = self.compute_buf_index_symbolic(local_buffer.st)
assert str(lvalid) == "1", "local buffer must be valid"
self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")
self.kernel.append(f"int mid_idx = {lidx.cl}; {self.buftokens[-1].tok}[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
if self.upcast_in_mid_reduce:
assert len(self.group_for_reduce) == 2
# it should be the last dimension
self.kernel.append(f"int mid_idx = idx{self.first_reduce}*{self.group_for_reduce[1]} + idx{self.first_reduce+1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
self.upcast()
else:
assert len(self.group_for_reduce) == 1
self.kernel.append(f"int mid_idx = idx{self.first_reduce}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
self.kernel.append("if (mid_idx == 0) {\n")
accumulators = [Token("output", self.buftokens[0].typ)]
@@ -304,7 +310,7 @@ class CLASTKernel(ASTKernel):
# compile kernel
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops)
mem_estimate = sum(prod(x) for x in self.shapes)
mem_estimate = sum(prod(x.shape) for x in self.sts)
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
def runner(*bufs):

Changed file 4 of 4: shape tracking, tinygrad.shape (View and ShapeTracker)

@@ -45,8 +45,8 @@ class View:
return 'idx=' + str(self.expr_node(Variable('idx', 0, prod([x[0] for x in self.shape_strides])-1)))
# generate an expression if you have a variable or expression for each index
def expr_idxs(self, idxs):
return Variable.sum([Variable.num(self.offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
def expr_idxs(self, idxs, offset=0):
return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
class ZeroView:
def __init__(self, old_shape:Tuple[int, ...], arg):
@@ -95,6 +95,7 @@ class ShapeTracker:
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
def copy(self): return ShapeTracker(self.shape, self.views[:])
@property
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
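`View.expr_idxs` grows an `offset` parameter so the GPU codegen can fold each per-element offset into the symbolic index instead of building a throwaway `View` (which is what the old `compute_buf_index_symbolic` did). A numeric sketch of the formula, assuming the symbolic `Variable` sum reduces to plain arithmetic:

```python
from typing import Sequence, Tuple

def expr_idxs(shape: Tuple[int, ...], strides: Tuple[int, ...],
              view_offset: int, idxs: Sequence[int], offset: int = 0) -> int:
    """Numeric equivalent of View.expr_idxs: base offset plus stride-weighted indices.

    Axes with size 1 or stride 0 contribute nothing, exactly as in the symbolic version.
    """
    return view_offset + offset + sum(idx * st for idx, (sh, st) in zip(idxs, zip(shape, strides))
                                      if sh != 1 and st != 0)

# a (4, 1, 8) view with strides (8, 0, 1): the middle axis drops out of the expression
assert expr_idxs((4, 1, 8), (8, 0, 1), view_offset=0, idxs=(2, 0, 5), offset=3) == 2*8 + 5*1 + 3
```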