From a77d792affd0ffa73a0b61f9c0af9a2bc4edfe23 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 4 Mar 2023 15:31:51 -0800 Subject: [PATCH] Codegen gpu cleanups (#640) * cleanups * fixups * handle pre upcasted global buffers * early is just required * delete junk from hand coded opt * implicit upcast_in_mid_reduce * speedup * fix exec w validhacks * reorder opt * only need to check the output for that * return total runtime from kernels if debugging --- extra/thneed.py | 1 + tinygrad/codegen/ast.py | 28 +++++-- tinygrad/codegen/gpu.py | 164 +++++++++++++++++----------------------- 3 files changed, 89 insertions(+), 104 deletions(-) diff --git a/extra/thneed.py b/extra/thneed.py index 8a3f745baf..a5e621d3b0 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -288,6 +288,7 @@ class Thneed: print(prg.prg) total_runtime += runtime print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms") + return total_runtime/1e9 return et def optimize_local_workgroup(self): diff --git a/tinygrad/codegen/ast.py b/tinygrad/codegen/ast.py index 0d23e00785..b34b717f16 100644 --- a/tinygrad/codegen/ast.py +++ b/tinygrad/codegen/ast.py @@ -122,6 +122,12 @@ class ASTKernel: @property def shape_len(self) -> int: return len(self.sts[0].shape) + @property + def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape + + @property + def upcast_in_mid_reduce_axes(self): return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] + def simplify_ones(self): # remove places where the shape is all ones # TODO: this should be factored in to multi shape stride @@ -164,18 +170,24 @@ class ASTKernel: if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape))) if axis is not None: st.permute(tuple(axis)) - def shift_to_last(self, axis, amount): + # axis : the axis to pull from + # amount : the amount to take + # top : if you want to pull that amount from the top + # insert_before : place to insert the new stuff + def shift_to(self, axis, amount, top=False, insert_before=None): + if insert_before is None: insert_before = self.shape_len + move_axis = axis if top else axis+1 + if move_axis < insert_before: insert_before += 1 self.reshape_and_permute( - lambda x: list(x[0:axis]) + ([x[axis]//amount, amount] if x[axis] > 1 else [1,1]) + list(x[axis+1:]), - [i for i in range(self.shape_len+1) if i != axis+1] + [axis+1]) + lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), + [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) # drops the final dimension def upcast(self): upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1] assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}" for st,buftoken in zip(self.sts, self.buftokens): - if st.shape[-1] == upcasted[0]: - buftoken.array(upcasted[0], st.views[-1].strides[-1], len(upcasted) != len(self.sts)) - - # remove the last dimension (unless it's the only dimension, then make it a 1) - for st in self.sts: st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) if len(st.shape) > 1 else View((1,), (0,), st.views[-1].offset) + # add last axis to the buftoken (if it's not a 1) + if st.shape[-1] == upcasted[0]: buftoken.array(st.shape[-1], st.views[-1].strides[-1], len(upcasted) != len(self.sts)) + # remove the last axis (unless it's the only dimension, then make it a 1) + st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) if len(st.shape) > 1 else View((1,), (0,), st.views[-1].offset) diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py index 241323fdd1..20c8de3467 100644 --- a/tinygrad/codegen/gpu.py +++ b/tinygrad/codegen/gpu.py @@ -4,8 +4,8 @@ from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner from tinygrad.codegen.ast import ASTKernel, Token, Types from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, Variable, render_python -from tinygrad.shape import ShapeTracker -from tinygrad.helpers import getenv, DEBUG, prod, partition, colored, mnum +from tinygrad.shape import ShapeTracker, View +from tinygrad.helpers import getenv, DEBUG, prod, partition, colored, mnum, all_same # div is different in cl than python render_cl = render_python.copy() @@ -36,8 +36,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F idx = Variable.sum(idx_nodes) unfactored = (Variable.sum(unfactored) // base_shape[1]) idy += unfactored - # ugh really... - if idx.min >= base_shape[1]//2: + # ugh really...handtuned garbage + if idx.min >= (base_shape[1]*3)//4: idx -= base_shape[1] idy += 1 else: @@ -97,6 +97,7 @@ class GPUCodegen(ASTKernel): const = Token(f"({val}f)", Types.FLOAT) should_upcast = self.lang.float4 and const is None and self.buftokens[buf_index].can_float4() tokens = [] + test_idy = [] for o in self.buftokens[buf_index].offsets(): key = f"val{mnum(buf_index)}_{mnum(o)}" if (buf_index, o) not in self.loaded_keys: @@ -114,6 +115,7 @@ class GPUCodegen(ASTKernel): assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}" idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS) ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4) + test_idy.append(idy.render(render_cl)) elif should_upcast and can_merge: ldr = Token(f"(({self.lang.buffer_prefix if self.bufs[buf_index] is not None else self.lang.smem_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4) else: @@ -129,6 +131,7 @@ class GPUCodegen(ASTKernel): else: self.loaded_keys[(buf_index,o)] = Token(key, Types.FLOAT) tokens.append(self.loaded_keys[(buf_index,o)]) + assert not VALIDHACKS or all_same(test_idy), f"idy changed! {test_idy}" return tokens def ast_parse(self, x, acc:List[Token], do_reduce=False) -> List[Token]: @@ -142,61 +145,56 @@ class GPUCodegen(ASTKernel): else: return [Token(code.replace("A", a.tok), a.typ) for a in values[0]] + def required_optimizations(self, early_only=False): + for buf_index,buf in enumerate(self.bufs): + upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes] + if (not early_only or buf in self.earlybufs) and hasattr(buf._buf, "IMAGE") and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))): + axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1] + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + self.shift_to(axes[0], 4) + self.upcast() + assert self.buftokens[buf_index].can_float4() + def hand_coded_optimizations(self): # if there's images in the earlybufs, we have to make an axis the 4 loading one - # shove the axis to the end and remove - if any(hasattr(buf._buf, "IMAGE") for buf in self.earlybufs): - eb_valids = [True] * self.shape_len - for i in range(len(self.bufs)): - if hasattr(self.bufs[i]._buf, "IMAGE") and self.bufs[i] in self.earlybufs: - valids = [self.sts[i].shape[j]%4 == 0 and self.sts[i].views[-1].strides[j] == 1 for j in range(self.shape_len)] - eb_valids = [x and y for x,y in zip(eb_valids, valids)] - assert any(eb_valids), f"invalid op with images {eb_valids}" - eb_valid = eb_valids.index(True) - if DEBUG >= 4: print(f"early merging axis {eb_valid} from {eb_valids}") - - # no change, we added a dimension - self.shift_to_last(eb_valid, 4) - - # drop the last dimension - self.upcast() + self.required_optimizations(early_only=True) # simplify (sets first_reduce) self.simplify_ones() - # are we grouping? + # are we grouping? what does this have to do with float4? if self.lang.float4 and not self.buftokens[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # TODO: use 1024 if it's allowed in a smarter way for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]): + self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce) self.group_for_reduce.append(sz) break - # if there's images in the latebufs, we have to make an axis the 4 storing one. this affects the kernel shape - if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf not in self.earlybufs) and not self.buftokens[0].can_float4(): - lb_valids = [True] * self.shape_len - for i in range(len(self.bufs)): - valids = [self.sts[i].shape[j]%4 == 0 and (self.sts[i].views[-1].strides[j] == 1 or not hasattr(self.bufs[i]._buf, "IMAGE") or self.bufs[i] in self.earlybufs) for j in range(self.shape_len)] - lb_valids = [x and y for x,y in zip(lb_valids, valids)] - assert any(lb_valids), f"invalid op with images {lb_valids}" - lb_valid = lb_valids.index(True) - assert lb_valid < self.first_reduce, f"can't be in the reduce {lb_valid}" - if DEBUG >= 4: print(f"late merging axis {lb_valid} from {lb_valids}") + # are we upcasting in mid reduce? + if hasattr(self.bufs[0]._buf, "IMAGE") and not self.buftokens[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2: + axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1] + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce)) # insert at the end of the grouped axis + self.group_for_reduce.append(4) - # no change, we added a dimension - self.shift_to_last(lb_valid, 4) - - if self.group_for_reduce and self.first_reduce <= 2: - self.upcast_in_mid_reduce = True - self.group_for_reduce.append(4) - else: - # drop the last dimension - self.upcast() + # now do everything required + self.required_optimizations() # simplify (sets first_reduce) self.simplify_ones() - # split to 4 float4s + # use more opencl indexing if the output buffer is an image and we have room + if hasattr(self.bufs[0]._buf, "IMAGE") and self.first_reduce+len(self.group_for_reduce) < 3: + base_shape = self.bufs[0]._base_shape + if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0: + if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape) + self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None) + self.simplify_ones() + + # **** below this line need to be optional and benchmarked **** + + # split to 4 float4s based on a heuristic if self.buftokens[0].can_float4() and any(hasattr(buf._buf, "IMAGE") for buf in self.earlybufs) and prod(self.sts[0].shape[:self.first_reduce]) >= 2048 and not self.group_for_reduce: xb_choices = [] for i in range(self.first_reduce): @@ -208,7 +206,7 @@ class GPUCodegen(ASTKernel): if DEBUG >= 4: print(f"float4 merging axis {xb_choice} : {xb_choices}") # this leaves the last axis in place - self.shift_to_last(xb_choice, 4) + self.shift_to(xb_choice, 4) # drop the last dimension self.upcast() @@ -216,86 +214,61 @@ class GPUCodegen(ASTKernel): # re-simplify self.simplify_ones() - # use more opencl indexing if the output buffer is an image - if self.first_reduce == 2 and hasattr(self.bufs[0]._buf, "IMAGE"): - base_shape = self.bufs[0]._base_shape - if all([(base_shape[0]*base_shape[1])%st.shape[0] == 0 and st.shape[0]//base_shape[0] != 0 for st in self.sts]): - if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape) - self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None) - self.simplify_ones() - - # group for reduce - if len(self.group_for_reduce): - # with permute for memory coalesing - if len(self.group_for_reduce) == 2: - permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.shape_len, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len)) - else: - permute_axis = list(range(0, self.first_reduce)) + [self.first_reduce+1, self.first_reduce] + list(range(self.first_reduce+2, self.shape_len+1)) - self.reshape_and_permute(lambda x: list(x[0:self.first_reduce]) + [max(1, x[self.first_reduce]//self.group_for_reduce[0]), min(x[self.first_reduce], self.group_for_reduce[0])] + list(x[self.first_reduce+1:]), permute_axis) - - # if last dim <= 3 and it's a reduce dim, upcast (loop unrolling) - end_dimension = max([st.shape[-1] for st in self.sts]) - if self.first_reduce < self.shape_len and end_dimension > 1 and end_dimension <= 3 and max([x.size() for i,x in enumerate(self.buftokens) if self.bufs[i] in self.earlybufs]) <= 4: + # if last dim <= 3 and it's a reduce dim, upcast (loop unrolling). no simplify needed since it's just an upcast + # NOTE: careful, this can break VALIDHACKS + if not self.group_for_reduce and self.first_reduce < self.shape_len and self.full_shape[-1] > 1 and self.full_shape[-1] <= 3 and (max([x.size() for i,x in enumerate(self.buftokens) if self.bufs[i] in self.earlybufs]) <= 4 or not any(r for _,_,r in self.buftokens[self.full_buf_index].axis)): self.upcast() - def required_optimizations(self): - for buf_index,buf in enumerate(self.bufs): - if hasattr(buf._buf, "IMAGE") and not (self.buftokens[buf_index].can_float4() or (buf not in self.earlybufs and self.upcast_in_mid_reduce)): - axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1] - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - self.shift_to_last(axes[0], 4) - self.upcast() - assert self.buftokens[buf_index].can_float4() - # STOP WASTING TIME WITH DOING THE RESHAPES AND PERMUTES BY HAND. KERNEL SEARCH IS THE ONLY WAY IT WILL EVER BE GOOD # group_for_reduce will have to be better first def codegen(self) -> ASTRunner: self.process() - self.upcast_in_mid_reduce = False + if DEBUG >= 4: self.printbufs("old:", DEBUG>=5) + self.hand_coded_optimizations() - # this shouldn't do anything if you ran the hand coded optimizations - self.required_optimizations() - - # there's sometimes ones here - self.simplify_ones() - # fancy colored shape printer if DEBUG >= 3: - axis = [(f"{rs:4d}", ("green" if i < self.first_reduce + len(self.group_for_reduce) else "red") if i >= self.first_reduce else "blue") for i, rs in enumerate(self.sts[self.full_buf_index].shape)] + axis = [(f"{rs:4d}", (("green" if i in self.upcast_in_mid_reduce_axes else "cyan") if i < self.first_reduce + len(self.group_for_reduce) else "red") if i >= self.first_reduce else "blue") for i, rs in enumerate(self.full_shape)] axis += [(f"{s:4d}", 'magenta' if reduce else 'yellow') for s, _, reduce in self.buftokens[self.full_buf_index].axis[::-1]] print(' '.join([colored(*x) for x in axis])+(" "*(50-len(' '.join([x[0] for x in axis])))), end="") - self.prekernel : Set[str] = set() - self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs) else [] - # add a local buffer for multistage reduce if len(self.group_for_reduce): - self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce)))) - self.buftokens.append(Token("temp", Types.FLOAT, ptr=True)) self.bufs.append(None) - self.kernel.append(self.lang.smem_prefix + f"float {self.buftokens[-1].tok}[{self.sts[-1].size()}];\n") + # TODO: the strides of this can be controlled + st = ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.buftokens[0].axis])) + buftoken = Token("temp", Types.FLOAT, ptr=True) + # manual upcast of the local + for _,_,r in self.buftokens[0].axis[::-1]: + buftoken.array(st.shape[-1], st.views[-1].strides[-1], r) + st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) + self.sts.append(st) + self.buftokens.append(buftoken) - self.output_shape = list(self.sts[0].shape[:self.first_reduce]) + self.group_for_reduce + self.output_shape : Tuple[int, ...] = self.sts[0].shape[:self.first_reduce] + tuple(self.group_for_reduce) + assert self.full_shape[:len(self.output_shape)] == self.output_shape, f"output shape mismatch : {self.full_shape[:len(self.output_shape)]} != {self.output_shape}" if DEBUG >= 4: print("output shape", self.output_shape) self.printbufs("new:", DEBUG>=5) self.bufs_to_delete : Set[int] = set() self.loaded_keys : Dict[Tuple[int,int], Token] = {} + self.prekernel : Set[str] = set() + self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else [] - # output_shape[-1] is get_global_id(0) if len(self.lang.gid) == 0: self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))] else: + # output_shape[-1] is get_global_id(0) self.kernel += [f"int idx{len(self.output_shape)-1-i} = {self.lang.gid[i]}; /* {self.output_shape[-1-i]} */\n" for i in range(min(len(self.lang.gid), len(self.output_shape))) if self.output_shape[-1-i] != 1] if len(self.output_shape) > len(self.lang.gid): # sometimes, there's more dimensions. compact all the dimensions into the first one - # TODO: these compactions should be searchable + # TODO: these compactions should be searchable (they sort of are with reshapes and permutes) final_dimension = len(self.output_shape)-len(self.lang.gid) for i in range(final_dimension-1, -1, -1): self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"] - self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:]) + self.output_shape = (prod(self.output_shape[0:final_dimension+1]), ) + self.output_shape[final_dimension+1:] if DEBUG >= 3: print(f"replaced output shape with {self.output_shape}") # early ast @@ -304,12 +277,13 @@ class GPUCodegen(ASTKernel): acc_offsets = self.buftokens[self.bufs.index(self.earlybufs[0])].acc_offsets() assert self.reduceopop is not None self.kernel += [f"{accumulator.decltype()} {accumulator.tok} = {GPUCodegen.start_for_op[self.reduceopop]};\n" for accumulator in accumulators] - self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.sts[self.full_buf_index].shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)] + self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)] self.kernel += [f"{x.tok};\n" for x in self.ast_parse(self.reduceop, [accumulators[off] for off in acc_offsets], do_reduce=True)] self.kernel += ["}\n"] * (self.shape_len - (self.first_reduce + len(self.group_for_reduce))) # middle if self.group_for_reduce: + self.kernel.append(self.lang.smem_prefix + f"float {self.buftokens[-1].tok}[{self.sts[-1].size()*self.buftokens[-1].size()}];\n") self.store(-1, accumulators) # TODO: this is assuming the local size = global size. should use lidxs self.kernel.append(self.lang.barrier+"\n") @@ -318,12 +292,10 @@ class GPUCodegen(ASTKernel): lidx, lvalid = self.sts[-1].expr_idxs() assert lvalid.min == 1, "local buffer must always be valid" - if self.upcast_in_mid_reduce: - assert len(self.group_for_reduce) == 2 - # it should be the last dimension - self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != self.first_reduce+1] + [self.first_reduce+1]) + # if any group_for_reduce items aren't reduces, upcast them here + for j in self.upcast_in_mid_reduce_axes: + self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j]) self.upcast() - if DEBUG>=4: self.printbufs("upcast:", DEBUG>=5) assert self.reduceopop is not None self.kernel.append(f"if ({lidx.render(render_cl)} == 0) {{\n") @@ -362,6 +334,6 @@ class GPUCodegen(ASTKernel): GPUCodegen.kernel_name_cache[prg] = function_name return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete, - self.output_shape[::-1] if len(self.output_shape) > 0 else [1], + list(self.output_shape[::-1]) if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, op_estimate=self.info.flops, mem_estimate=sum(prod(x._base_shape) for x in self.bufs if x is not None))