From b3e4e678e8faeb69a5d339af2035304ca86411f8 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 28 Jan 2023 11:56:32 -0800
Subject: [PATCH] Use ShapeTracker for tracking shapes in kernels (#485)

* local is a normal buffer

* remove extra shapes and strides

* fix opt

* fix llvm
---
 accel/llvm/ops_llvm.py     | 19 +++++------
 tinygrad/ast.py            | 68 ++++++++++++++------------------
 tinygrad/llops/ops_gpu.py  | 68 +++++++++++++++++++++-----------------
 tinygrad/shape/__init__.py |  5 +--
 4 files changed, 73 insertions(+), 87 deletions(-)

diff --git a/accel/llvm/ops_llvm.py b/accel/llvm/ops_llvm.py
index 62f72f1c77..53daf69605 100644
--- a/accel/llvm/ops_llvm.py
+++ b/accel/llvm/ops_llvm.py
@@ -192,13 +192,10 @@ class LLVMBuffer(ExplicitExecAST):
       LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
       return k.ret
 
-    # cache miss, we have to process the kernel
-    k.process()
-
     if DEBUG >= 2:
       print(k.ast)
-      print("old:", k.shapes)
-      print("old:", k.strides)
+      print("old:", [x.shape for x in k.sts])
+      print("old:", [x.views[-1].strides for x in k.sts])
 
     # this stuff can't be hand coded
     kernel_output_axis : List[int] = []
@@ -242,12 +239,12 @@ class LLVMBuffer(ExplicitExecAST):
     """
 
     # the 4x4 need to go all the way at the end, even after reduce
-    output_shape = k.shapes[0]
-    full_shape = [x for x in k.shapes if x != output_shape]
-    full_shape = output_shape if len(full_shape) == 0 else full_shape[0]
+    output_shape = k.sts[0].shape
+    full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
+    full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]
     full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
 
-    kernel_output_dim = prod([k.shapes[0][a] for a in kernel_output_axis])
+    kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
     kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)
 
     def get_idxs(builder, idx, buf_index):
@@ -279,13 +276,13 @@ class LLVMBuffer(ExplicitExecAST):
     loop_exit = loop_exit[::-1]
 
     # add the buffer indexing
-    idx_level = [[int_const(o)] for o in k.offsets]
+    idx_level = [[int_const(st.offset)] for st in k.sts]
    for i in range(len(full_shape)):
      for j in range(len(k.bufs)):
        # stride
        si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
        si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
-        si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
+        si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
        si.add_incoming(si_ps, loop_exit[i+1]._block)
        idx_level[j].append(si)
 
diff --git a/tinygrad/ast.py b/tinygrad/ast.py
index 01351d2668..3ea8b49cc1 100644
--- a/tinygrad/ast.py
+++ b/tinygrad/ast.py
@@ -3,7 +3,7 @@ import itertools
 from typing import List, Tuple
 from tinygrad.helpers import prod, dedup, all_same
 from tinygrad.ops import LazyOp, MovementOps, get_lazyop_info, get_buffers, ReduceOps, get_lazyops
-from tinygrad.shape import ShapeTracker
+from tinygrad.shape import ShapeTracker, View
 
 def get_first_reduce(shapes):
   for i in range(len(shapes[0])):
@@ -50,6 +50,7 @@ class ASTKernel:
     if hasattr(self.ret, "cl"): self.ret.cl  # does the allocation of unbacked buffer, pylint: disable=W0104
     self.bufs = [type(self.ret)(self.info.shape, hostbuf=self.ret)] + self.bufs
     self.buftokens = [Token(f"data{i}", Types.FLOAT, ptr=True) for i in range(len(self.bufs))]
+    self.group_for_reduce : List[int] = []
 
     # check valid AST kernel
     assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
@@ -57,9 +58,9 @@ class ASTKernel:
     assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
 
     # process
-    # TODO: fetch from quick cache before processing
-    self.process()
-    self.group_for_reduce : List[int] = []
+    self.sts : List[ShapeTracker] = [x.st.copy() for x in self.bufs]  # create new shapetrackers inside this kernel
+    self.simplify_ones()
+    self.simplify_merge_adjacent()
 
   def print(self):
     buf_count = -1
@@ -84,30 +85,21 @@ class ASTKernel:
       return cache[x]
     print_ast(self.input_ast, "ast")
-
-  def process(self):
-    # get shape, strides, and offset
-    # if it's a multiview buffer we take the final view
-    self.shapes = [x.shape for x in self.bufs]
-    self.strides = [x.st.views[-1].strides for x in self.bufs]
-    self.offsets = [x.st.views[-1].offset for x in self.bufs]  # include the offsets (as is)
-    self.simplify_ones()
-    self.simplify_merge_adjacent()
+  @property
+  def shape_len(self): return len(self.sts[0].shape)
 
   def simplify_ones(self):
     # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
-    all_ones = [all(s[i]==1 for s in self.shapes) for i in range(len(self.shapes[0]))]
+    all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
    # keep at least 1 one
-    if all(all_ones):
-      all_ones[-1] = False
-    self.shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.shapes]
-    self.strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.strides]
+    if all(all_ones): all_ones[-1] = False
+    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
    # find first mismatch, don't reduce this
-    self.first_reduce = get_first_reduce(self.shapes)
+    self.first_reduce = get_first_reduce([x.shape for x in self.sts])
 
   def simplify_merge_adjacent(self):
-    shapes, strides = self.shapes, self.strides
+    shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]
    # merge dimensions if we can, multi get_shape_strides
    # TODO: does this always preserve the reduce dimension, NO
@@ -125,45 +117,35 @@ class ASTKernel:
           rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
         else:
           rets[j].append((shapes[j][i], strides[j][i]))
-    self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
-    self.first_reduce = get_first_reduce(self.shapes)
-
-  @property
-  def shape_len(self): return len(self.shapes[0])
+    for i,x in enumerate(rets): self.sts[i].reshape(*[y[0] for y in x])
+    self.first_reduce = get_first_reduce([x.shape for x in self.sts])
 
   # this should be aware of the three parts to the shape
   # * the input/output dimensions
   # * the reduce dimensions
   # * the size outputted by each kernel
   def reshape_and_permute(self, new_shape_fxn, axis):
-    new_shapes, new_strides = [], []
-    for shape, stride in zip(self.shapes, self.strides):
-      st = ShapeTracker(tuple(shape))
-      st.strided(*zip(shape, stride))
-      # TODO: handle reduced shape here
-      if new_shape_fxn is not None: st.reshape(*new_shape_fxn(shape))
+    for st in self.sts:
+      if new_shape_fxn is not None: st.reshape(*new_shape_fxn(st.shape))
       if axis is not None: st.permute(*axis)
-      assert len(st.views) == 1
-      new_shapes.append(st.shape)
-      new_strides.append(st.strides)
-    self.shapes, self.strides = new_shapes, new_strides
 
   # drops the final dimension
   def upcast(self):
-    upcasted = [x[-1] for x in self.shapes if x[-1] != 1]
+    upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1]
     assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}"
     for i in range(len(self.bufs)):
-      if self.shapes[i][-1] == upcasted[0]:
+      st = self.sts[i]
+      if st.shape[-1] == upcasted[0]:
         # multiview shapetrackers can slice through a float4, so don't allow them
-        can_merge = (not self.bufs[i].st.needs_valid() and len(self.bufs[i].st.views) == 1) or "Image" in str(type(self.bufs[i]._buf))  # TODO: terrible hack
-        if self.shapes[i][-1] == 4 and self.buftokens[i].typ == Types.FLOAT and self.strides[i][-1] == 1 and can_merge:
+        can_merge = (not st.needs_valid() and len(st.views) == 1) or "Image" in str(type(self.bufs[i]._buf))  # TODO: terrible hack
+        if st.shape[-1] == 4 and self.buftokens[i].typ == Types.FLOAT and st.views[-1].strides[-1] == 1 and can_merge:
          # this is an upcast to FLOAT4
          self.buftokens[i].typ = Types.FLOAT4
-          assert all(x%upcasted[0] == 0 for x in self.strides[i][0:-1])
-          assert self.offsets[i]%upcasted[0] == 0
+          assert all(st.views[-1].strides[i]%upcasted[0] == 0 or st.views[-1].shape[i] == 1 for i in range(len(st.shape)-1))
+          assert self.sts[i].offset % upcasted[0] == 0
        else:
-          self.buftokens[i].array(upcasted[0], self.strides[i][-1])
+          self.buftokens[i].array(upcasted[0], st.views[-1].strides[-1])
 
    # remove the last dimension
-    self.shapes = [x[:-1] for x in self.shapes]
-    self.strides = [x[:-1] for x in self.strides]
+    for st in self.sts: st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset)
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 73c071dc8d..362b1240a0 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -6,7 +6,7 @@ from tinygrad.helpers import prod
 from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
 from tinygrad.ast import ASTKernel, Token, Types
 from tinygrad.lazy import IMAGE
-from tinygrad.shape import ShapeTracker, View, ZeroView
+from tinygrad.shape import ShapeTracker, ZeroView
 from tinygrad.shape.symbolic import Variable, ModNode
 
 CUDA = int(os.getenv("CUDA", "0"))
@@ -39,9 +39,8 @@ class CLASTKernel(ASTKernel):
   start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
 
   # TODO: move to shapetracker
-  def compute_buf_index_symbolic(self, st, buf_index, offset=0):
-    view = View(self.shapes[buf_index], self.strides[buf_index], self.offsets[buf_index] + offset)
-    idx = view.expr_idxs([f"idx{i}" for i in range(self.shape_len)])
+  def compute_buf_index_symbolic(self, st, offset=0):
+    idx = st.views[-1].expr_idxs([f"idx{i}" for i in range(self.shape_len)], offset)
     valid = Variable.num(1)
     for v in st.views[0:-1][::-1]:
       if isinstance(v, ZeroView): valid = v.expr_node(valid, idx)
@@ -62,7 +61,7 @@ class CLASTKernel(ASTKernel):
     if len(value)*4 == self.buftokens[buf_index].size(): value = split_float4(value)
     assert len(value) == self.buftokens[buf_index].size(), f"size mismatch {len(value)} != {self.buftokens[buf_index].size()}"
     for v, o in zip(value, self.buftokens[buf_index].offsets()):
-      idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
+      idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
       assert str(valid) == "1", "store must always be valid"
       assert self.buftokens[buf_index].typ == v.typ, f"buf must be {v.typ}"
       if isinstance(self.bufs[buf_index]._buf, CLImage):
@@ -80,7 +79,7 @@ class CLASTKernel(ASTKernel):
       const = Token(f"({self.bufs[buf_index]._backing[0]}f)", self.buftokens[buf_index].typ)
       if self.bufs[buf_index].st.needs_valid():
         for o in self.buftokens[buf_index].offsets():
-          _, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
+          _, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
           tokens.append(Token(f"({valid.cl} ? {const.tok} : 0.0f)", const.typ) if str(valid) != "1" else const)
         return tokens
       else:
@@ -89,7 +88,7 @@ class CLASTKernel(ASTKernel):
     # not constant folded
     for o in self.buftokens[buf_index].offsets():
       if (buf_index, o) not in self.loaded_keys:
-        idxy, valid = self.compute_buf_index_symbolic(self.bufs[buf_index].st, buf_index, o)
+        idxy, valid = self.compute_buf_index_symbolic(self.sts[buf_index], o)
        if isinstance(self.bufs[buf_index]._buf, CLImage):
          ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {self.image_idx(buf_index, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
        else:
@@ -125,10 +124,10 @@ class CLASTKernel(ASTKernel):
     # if there's images in the earlybufs, we have to make an axis the 4 loading one
     # shove the axis to the end and remove
     if any(isinstance(buf._buf, CLImage) for buf in self.earlybufs):
-      eb_valids = [True] * len(self.shapes[0])
+      eb_valids = [True] * self.shape_len
       for i in range(len(self.bufs)):
         if isinstance(self.bufs[i]._buf, CLImage) and self.bufs[i] in self.earlybufs:
-          valids = [self.shapes[i][j]%4 == 0 and self.strides[i][j] == 1 for j in range(len(self.shapes[i]))]
+          valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)]
          eb_valids = [x and y for x,y in zip(eb_valids, valids)]
       assert any(eb_valids), f"invalid op with images {eb_valids}"
       eb_valid = eb_valids.index(True)
@@ -146,9 +145,9 @@ class CLASTKernel(ASTKernel):
       self.simplify_ones()
 
    # are we grouping?
-    if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.shapes[0][:self.first_reduce]) <= 2048:
-      for sz in ([256, 16] if prod(self.shapes[0][:self.first_reduce]) <= 32 else [16]):
-        if all([x[self.first_reduce] % sz == 0 or x[self.first_reduce] == 1 for x in self.shapes]):
+    if self.buftokens[0].typ != Types.FLOAT4 and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
+      for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
+        if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
          self.group_for_reduce.append(sz)
          break
 
@@ -161,9 +160,9 @@ class CLASTKernel(ASTKernel):
    # if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape
    self.upcast_in_mid_reduce = False
    if any(isinstance(buf._buf, CLImage) for buf in self.bufs if buf not in self.earlybufs) and self.buftokens[0].typ != Types.FLOAT4:
-      lb_valids = [True] * len(self.shapes[0])
+      lb_valids = [True] * self.shape_len
       for i in range(len(self.bufs)):
-        valids = [self.shapes[i][j]%4 == 0 and (self.strides[i][j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(len(self.shapes[i]))]
+        valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not isinstance(self.bufs[i]._buf, CLImage) or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)]
        lb_valids = [x and y for x,y in zip(lb_valids, valids)]
       assert any(lb_valids), f"invalid op with images {lb_valids}"
       lb_valid = lb_valids.index(True)
@@ -186,11 +185,11 @@ class CLASTKernel(ASTKernel):
       self.simplify_ones()
 
    # split to 4 float4s
-    if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.shapes[0][:self.first_reduce]) >= 2048 and not self.group_for_reduce:
+    if self.buftokens[0].typ == Types.FLOAT4 and any(isinstance(buf._buf, CLImage) for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce:
       xb_choices = []
       for i in range(self.first_reduce):
-        if all(x[i]%4 == 0 for x in self.shapes):
-          xb_choices.append((sum(x[i]>0 for x in self.strides), sum(x[i] for x in self.strides), i))
+        if all(st.shape[i]%4 == 0 for st in self.sts):
+          xb_choices.append((sum(st.views[-1].strides[i]>0 for st in self.sts), sum(st.views[-1].strides[i] for st in self.sts), i))
 
       if len(xb_choices):
         xb_choice = sorted(xb_choices)[0][2]
@@ -210,7 +209,7 @@ class CLASTKernel(ASTKernel):
    # use more opencl indexing
    if self.first_reduce == 2 and isinstance(self.bufs[0]._buf, CLImage):
       base_shape = self.bufs[0]._base_shape
-      if all([(base_shape[0]*base_shape[1])%x[0] == 0 and x[0]//base_shape[0] != 0 for x in self.shapes]):
+      if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]):
        if DEBUG >= 3: print("split opencl", base_shape, self.shapes[0])
        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
        self.simplify_ones()
@@ -228,17 +227,24 @@ class CLASTKernel(ASTKernel):
   # group_for_reduce will have to be better first
   def codegen(self):
     if DEBUG >= 3:
-      print("old:", self.shapes)
-      print("old:", self.strides)
-
+      print("old:", [x.shape for x in self.sts])
+      print("old:", [x.views[-1].strides for x in self.sts])
+
     if not CUDA: self.hand_coded_optimizations()
 
-    self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce
+    # add a local buffer for multistage reduce
+    if len(self.group_for_reduce):
+      local_buffer = GPUBuffer([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce))
+      self.bufs.append(local_buffer)
+      self.sts.append(local_buffer.st.copy())
+      self.buftokens.append(Token("temp", Types.FLOAT, ptr=True))
+
+    self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce
     if DEBUG >= 3:
       print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
       print("output shape", self.output_shape)
       for i in range(len(self.bufs)):
-        print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.shapes[i], self.strides[i])
+        print(self.buftokens[i], f"early:{'T' if self.bufs[i] in self.earlybufs else 'F'} image:{'T' if isinstance(self.bufs[i]._buf, CLImage) else 'F'}", self.sts[i])
 
     self.bufs_to_delete : Set[int] = set()
     self.loaded_keys : Dict[Tuple[int,int], Token] = {}
@@ -261,8 +267,8 @@ class CLASTKernel(ASTKernel):
    # early ast
    accumulators : List[Token] = [Token("acc%d" % i, self.buftokens[0].typ) for i in range(self.buftokens[0].size())]
    if self.reduceop:
-      full_shape = [x for x in self.shapes if x != self.shapes[0]]
-      full_shape = self.shapes[0] if len(full_shape) == 0 else full_shape[0]
+      full_shape = [x.shape for x in self.sts if x.shape != self.sts[0].shape]
+      full_shape = self.sts[0].shape if len(full_shape) == 0 else full_shape[0]
 
       self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {CLASTKernel.start_for_op[self.reduceop.op]};\n" for accumulator in accumulators]
       self.kernel += [f"for (int idx{i} = 0; idx{i} < {full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
@@ -270,17 +276,17 @@ class CLASTKernel(ASTKernel):
 
    # middle
    if self.group_for_reduce:
-      self.kernel.append(f"__local {accumulators[0].decltype()} temp[{prod(self.group_for_reduce)}]; // second stage\n")
+      lidx, lvalid = self.compute_buf_index_symbolic(local_buffer.st)
+      assert str(lvalid) == "1", "local buffer must be valid"
+
+      self.kernel.append(f"__local {accumulators[0].decltype()} {self.buftokens[-1].tok}[{prod(self.group_for_reduce)}]; // second stage\n")
+      self.kernel.append(f"int mid_idx = {lidx.cl}; {self.buftokens[-1].tok}[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
 
       if self.upcast_in_mid_reduce:
        assert len(self.group_for_reduce) == 2
        # it should be the last dimension
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce}*{self.group_for_reduce[1]} + idx{self.first_reduce+1}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
        self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1])
        self.upcast()
-      else:
-        assert len(self.group_for_reduce) == 1
-        self.kernel.append(f"int mid_idx = idx{self.first_reduce}; temp[mid_idx] = {accumulators[0].tok}; barrier(CLK_LOCAL_MEM_FENCE);\n")
 
       self.kernel.append("if (mid_idx == 0) {\n")
       accumulators = [Token("output", self.buftokens[0].typ)]
@@ -304,7 +310,7 @@ class CLASTKernel(ASTKernel):
 
    # compile kernel
    self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops)
-    mem_estimate = sum(prod(x) for x in self.shapes)
+    mem_estimate = sum(prod(x.shape) for x in self.sts)
    if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
 
    def runner(*bufs):
diff --git a/tinygrad/shape/__init__.py b/tinygrad/shape/__init__.py
index 0e4beb37d3..1e1b989b92 100644
--- a/tinygrad/shape/__init__.py
+++ b/tinygrad/shape/__init__.py
@@ -45,8 +45,8 @@ class View:
     return 'idx=' + str(self.expr_node(Variable('idx', 0, prod([x[0] for x in self.shape_strides])-1)))
 
   # generate an expression if you have a variable or expression for each index
-  def expr_idxs(self, idxs):
-    return Variable.sum([Variable.num(self.offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
+  def expr_idxs(self, idxs, offset=0):
+    return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idxs[i], 0, sh-1)*st for i,(sh,st) in enumerate(zip(self.shape, self.strides)) if sh != 1 and st != 0])
 
 class ZeroView:
   def __init__(self, old_shape:Tuple[int, ...], arg):
@@ -95,6 +95,7 @@ class ShapeTracker:
   def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
     self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
   def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
+  def copy(self): return ShapeTracker(self.shape, self.views[:])
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
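
The heart of the change is that ASTKernel now keeps one ShapeTracker per buffer in self.sts and mutates those trackers in lockstep, instead of maintaining the parallel self.shapes / self.strides / self.offsets lists. Below is a minimal, self-contained sketch of that bookkeeping written purely for illustration; MiniView and MiniShapeTracker are toy stand-ins that only handle contiguous views and permutes, not tinygrad's actual classes.

# toy_shapetracker.py -- illustration only, not tinygrad's implementation
from typing import List, Tuple

def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
  # row-major strides for a contiguous buffer
  strides, acc = [], 1
  for s in reversed(shape):
    strides.insert(0, acc)
    acc *= s
  return tuple(strides)

class MiniView:
  def __init__(self, shape, strides, offset=0):
    self.shape, self.strides, self.offset = tuple(shape), tuple(strides), offset

class MiniShapeTracker:
  def __init__(self, shape: Tuple[int, ...]):
    self.views: List[MiniView] = [MiniView(shape, strides_for_shape(shape))]
  @property
  def shape(self): return self.views[-1].shape
  @property
  def offset(self): return self.views[-1].offset
  def copy(self):
    # same spirit as the ShapeTracker.copy() added in the patch: new list, shared views
    st = MiniShapeTracker(self.shape)
    st.views = self.views[:]
    return st
  def permute(self, *axis: int):
    v = self.views[-1]
    self.views[-1] = MiniView(tuple(v.shape[a] for a in axis), tuple(v.strides[a] for a in axis), v.offset)
  def reshape(self, *new_shape: int):
    # contiguous case only, enough to show the bookkeeping
    v = self.views[-1]
    assert v.strides == strides_for_shape(v.shape), "sketch handles contiguous views only"
    self.views[-1] = MiniView(new_shape, strides_for_shape(new_shape), v.offset)

# the pattern the kernel code follows after the patch: one tracker per buffer,
# all mutated together (compare reshape_and_permute / simplify_ones in tinygrad/ast.py)
sts = [MiniShapeTracker((4, 1, 16)), MiniShapeTracker((4, 1, 16))]
for st in sts: st.reshape(4, 16)   # drop the all-ones axis
for st in sts: st.permute(1, 0)    # reorder axes
print([st.shape for st in sts], [st.views[-1].strides for st in sts])

Because each buffer owns a full tracker, code like ASTKernel.upcast can read the last view's strides and offset directly, which is what the rewritten assertions in the patch do.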
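
The View.expr_idxs change has the same flavor: rather than constructing a throwaway View with a shifted offset (the old compute_buf_index_symbolic), the per-element offset is folded straight into the index expression. Assuming plain integer indices instead of tinygrad's symbolic Variable nodes, the arithmetic that expression encodes is just:

# toy flat-index arithmetic; tinygrad builds a symbolic Variable.sum instead
def flat_index(shape, strides, base_offset, idxs, offset=0):
  # dimensions with sh == 1 or st == 0 contribute nothing, as in View.expr_idxs
  return base_offset + offset + sum(i*st for i, (sh, st) in zip(idxs, zip(shape, strides)) if sh != 1 and st != 0)

# element (2, 0, 3) of a contiguous (4, 1, 16) view, shifted by one float
print(flat_index((4, 1, 16), (16, 16, 1), 0, (2, 0, 3), offset=1))  # -> 36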