From 06abbbfe7cf1a9a92745dafa21f7a11ba72026f0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 20 Mar 2023 15:45:12 -0700 Subject: [PATCH] remove the stupid register class (#721) * remove the stupid register class * touchups * colorful display name --- tinygrad/codegen/cstyle.py | 10 +- tinygrad/codegen/linearizer.py | 241 ++++++++++++++++----------------- tinygrad/codegen/llvmir.py | 2 +- tinygrad/ops.py | 6 +- 4 files changed, 126 insertions(+), 133 deletions(-) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 86634ad6e6..99a3cbecbd 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -1,4 +1,4 @@ -from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set +from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union import math, collections from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps @@ -56,7 +56,7 @@ code_for_op: Final[Dict[Op, Callable]] = { BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})" } -def uops_to_cstyle(uops:List[UOp], bufs:List[LazyBuffer], bufnames:List[str], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: +def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: def group_float4(grp:List[str]) -> str: if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0] else: return f"{lang.float4}({','.join(g for g in grp)})" @@ -67,6 +67,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[LazyBuffer], bufnames:List[str], la local_size = [] pend_close = None + bufnames = ["temp" if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)] + depth = 0 def kk(s): kernel.append(" "*depth+s) @@ -200,7 +202,7 @@ class CStyleCodegen(Linearizer): self.hand_coded_optimizations() self.linearize() - prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, [x.name for x in self.registers], self.lang) + prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, self.lang) # if we have local_sizes, we have to correct the global_size for i,s in enumerate(local_size): global_size[i] *= s @@ -215,4 +217,4 @@ class CStyleCodegen(Linearizer): return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), global_size[::-1] if len(global_size) else [1], local_size[::-1] if len(local_size) else None, - op_estimate=self.info.flops, mem_estimate=self.mem_estimate) + op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 65c75cfcab..a098f20805 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -3,11 +3,11 @@ import itertools, math from collections import defaultdict from enum import Enum, auto -from tinygrad.helpers import dedup, colored, all_same, ImageDType, DEBUG, prod, dtypes, mnum, DType +from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps from tinygrad.lazy import LazyBuffer from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps -from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape +from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape from tinygrad.shape.symbolic import Variable, SumNode, ModNode class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); LOAD4 = auto(); STORE4 = auto() # noqa: E702 @@ -28,32 +28,12 @@ class UOp(NamedTuple): arg: Any def __repr__(self): return f"{str(self.uop):20s}: {self.out if self.out is not None else '':10s} {str(self.vin):32s} {self.arg}" -def get_first_reduce(shapes): - for i in range(len(shapes[0])): - if not all_same([x[i] for x in shapes]): return i - return len(shapes[0]) # off the end - def check_no_mul(test, var): if test == var: return True if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay return False -class Register: - def __init__(self, name:str): - self.name = name - self.axis: List[Tuple[int, int, bool]] = [] - def array(self, length, stride, reduce): self.axis.append((length, stride, reduce)) - def size(self): return prod([x[0] for x in self.axis]) - def offsets(self): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.axis[::-1]])] if len(self.axis) else [0] - def can_float4(self): return any(a[0:2] == (4,1) for a in self.axis) - # TODO: this is sort of a hack, it gets the accumulator indices - def acc_offsets(self): - if len(self.axis) == 0: return [0] - acc_strides = [x*(1-self.axis[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.axis[::-1])))] - return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.axis[::-1])])] - def __repr__(self): return f"<{self.name}{f'{self.axis}'}>" - class Linearizer: supports_float4: bool = False @@ -97,21 +77,57 @@ class Linearizer: permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]) self.reshape_and_permute(None, permute) + # parameters + self.group_for_reduce: List[int] = [] + self.upcasted: int = 0 + # group simplifies self.simplify_ones() self.simplify_merge_adjacent() - # is this generic? - self.registers = [Register(f"data{i}") for i in range(len(self.bufs))] - self.group_for_reduce: List[int] = [] + # NOTE: this stride is only on the last view, and may not be real + def upcasted_axis(self, i): + return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:], + self.sts[i].strides[self.shape_len-self.upcasted:], + [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])])) + def offsets(self, i): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.upcasted_axis(i)[::-1]])] if self.upcasted > 0 else [0] + def can_float4(self, i): return any(a[0:2] == (4,1) for a in self.upcasted_axis(i)) + def acc_offsets(self, i): + if self.upcasted == 0: return [0] + acc_strides = [x*(1-self.upcasted_axis(i)[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.upcasted_axis(i)[::-1])))] + return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.upcasted_axis(i)[::-1])])] def can_merge_float4(self, i:int, idxs:List[Variable], offset:int) -> bool: if offset%4 != 0: return False float4_index = Variable("FLOAT4_INDEX", 0, 3) idxy_test, valid_test = self.sts[i].expr_idxs(float4_index+offset, idxs) - if DEBUG >= 4: print(f"attempting to fuse buf {i} :", check_no_mul(idxy_test, float4_index), idxy_test//4, valid_test//4) # float4_index must not be in after divide or in valid. NOTE: this forces it to always be aligned too, maybe not required? - return check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render() + ret = check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render() + if DEBUG >= 4: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test//4, valid_test//4) + return ret + + def global_buf(self, i, idxs:List[Variable], store=None): + should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16 + cache: Dict[int, str] = {} + store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))} # NOTE: for stores, these should be unique + def op(offset): + if offset in cache: return cache[offset] + will_merge = should_upcast and self.can_merge_float4(i, idxs, offset) + if store is not None: + if offset in store_offset: + offsets = [] + for j in range(0, 4 if will_merge else 1): + offsets.append(store[store_offset[offset+j]]) + del store_offset[offset+j] + self.uop(UOps.STORE4 if will_merge else UOps.STORE, None, offsets, MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) + else: + reg = self.uop(UOps.LOAD4 if will_merge else UOps.LOAD, f"val{mnum(i)}_{mnum(offset)}", [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) + if will_merge: + for j in range(0, 4): cache[offset+j] = reg+"."+"xyzw"[j] + else: + cache[offset] = reg + return cache[offset] + return [op(o) for o in self.offsets(i)] def linearize(self): # uops @@ -121,41 +137,15 @@ class Linearizer: if len(self.group_for_reduce): self.bufs.append(LocalBuffer()) # TODO: the strides of this can be controlled - st = ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.registers[0].axis])) - buftoken = Register("temp") - # manual upcast of the local - for _,_,is_reduce in self.registers[0].axis[::-1]: - buftoken.array(st.shape[-1], st.views[-1].strides[-1], is_reduce) - st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) - self.sts.append(st) - self.registers.append(buftoken) - self.uop(UOps.DEFINE_LOCAL, None, [], (self.registers[-1].name, self.sts[-1].size()*self.registers[-1].size())) + self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) + self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size())) # print if DEBUG >= 3: self.printbufs() - def global_buf(i, idxs:List[Variable], store=None): - should_upcast = self.supports_float4 and self.registers[i].can_float4() and self.bufs[i].dtype != dtypes.float16 - cache: Dict[int, str] = {} - store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.registers[i].offsets())} # NOTE: for stores, these should be unique - def op(offset): - if offset in cache: return cache[offset] - will_merge = should_upcast and self.can_merge_float4(i, idxs, offset) - if store is not None: - if offset in store_offset: - offsets = [] - for j in range(0, 4 if will_merge else 1): - offsets.append(store[store_offset[offset+j]]) - del store_offset[offset+j] - self.uop(UOps.STORE4 if will_merge else UOps.STORE, None, offsets, MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) - else: - reg = self.uop(UOps.LOAD4 if will_merge else UOps.LOAD, self.registers[i].name+"_"+mnum(offset), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs))) - if will_merge: - for j in range(0, 4): cache[offset+j] = reg+"."+"xyzw"[j] - else: - cache[offset] = reg - return cache[offset] - return [op(o) for o in self.registers[i].offsets()] + # kernel name (before late upcast) + self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape]) + self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'black', bright=True).join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) + " " * (23-len(self.function_name)) # parse AST loaded_buffers = {} @@ -184,24 +174,24 @@ class Linearizer: # reduce op if self.reduceop is not None: # define accumulator - acc = [self.uop(UOps.CONST, ssa('acc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.registers[0].offsets()] + acc = [self.uop(UOps.CONST, ssa('acc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(0)] # reduce loop - reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)] + reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce")) # load earlybufs - loaded_buffers.update({b:global_buf(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0}) + loaded_buffers.update({b:self.global_buf(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0}) # run early AST (with reduce) - self.ast_parse(self.reduceop, [acc[off] for off in self.registers[self.full_buf_index].acc_offsets()], loaded_buffers, ssa, do_reduce=True) + self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True) # end the reduce loop self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce")) # end the local loop, do the local reduce if self.group_for_reduce: - global_buf(-1, local_idxs, acc) # store accumulators + self.global_buf(-1, local_idxs, acc) # store accumulators self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local")) # this is a barrier on GPUs # if any group_for_reduce items aren't reduces, upcast them here @@ -213,40 +203,36 @@ class Linearizer: # NOTE: this structure is the same as the reduce op above # define late accumulator - acc = [self.uop(UOps.CONST, ssa('lacc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.registers[-1].offsets()] + acc = [self.uop(UOps.CONST, ssa('lacc'), [], {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)]) for _ in self.offsets(-1)] # late reduce loop end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce")) # load localbufs - loaded_buffers["LOCAL_BUFFER"] = global_buf(-1, end_local_idxs) + loaded_buffers["LOCAL_BUFFER"] = self.global_buf(-1, end_local_idxs) # there's no AST here (and there's no shape for the reduce LazyOp) - self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.registers[-1].acc_offsets()], loaded_buffers, ssa, do_reduce=True) + self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # end the late reduce loop self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce")) # load latebufs - loaded_buffers.update({b:global_buf(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)}) + loaded_buffers.update({b:self.global_buf(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)}) # run late AST val = self.ast_parse(self.ast, acc, loaded_buffers, ssa) # store - global_buf(0, global_idxs, val) + self.global_buf(0, global_idxs, val) # end the global loop self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global")) - # kernel function definition - self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape]) - - def uop(self, uop:UOps, out:Optional[str], vin:List[str], arg:Any): self.uops.append(UOp(uop, out, vin, arg)) - if DEBUG >= 3: print(self.uops[-1]) + if DEBUG >= 4: print(self.uops[-1]) return out def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[str]: @@ -263,29 +249,72 @@ class Linearizer: return [self.uop(UOps.ALU, ssa('alu'), list(val), x.op) for val in zip(*values)] @property - def first_reduce(self) -> int: return get_first_reduce([x.shape for i,x in enumerate(self.sts) if not isinstance(self.bufs[i], LocalBuffer)]) + def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) @property def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape + @property + def full_unupcasted_shape(self) -> Tuple[int, ...]: return self.full_shape[:self.shape_len-self.upcasted] + @property def shape_len(self) -> int: return len(self.sts[0].shape) @property def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] - def colorshape(self, pad=50) -> str: - axis = [(f"{rs:4d}", (("green" if i in self.upcast_in_mid_reduce_axes else "cyan") if i < self.first_reduce + len(self.group_for_reduce) else "red") if i >= self.first_reduce else "blue") for i, rs in enumerate(self.full_shape)] - axis += [(f"{s:4d}", 'magenta' if reduce else 'yellow') for s, _, reduce in self.registers[self.full_buf_index].axis[::-1]] - return ' '.join([colored(*x) for x in axis])+(" "*(pad-len(' '.join([x[0] for x in axis])))) + def colors(self) -> List[str]: + # up to first_reduce, they are all global (blue) + colors = ["blue"] * self.first_reduce + # between first_reduce and first_reduce + group_for_reduce, they are either local (cyan), or late upcasted (green) + colors += ["green" if i in self.upcast_in_mid_reduce_axes else "cyan" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] + # between first_reduce + group_for_reduce and upcasted, they are reduce (red) + colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce))) + # upcasted dimensions are reduce (magenta) or normal (yellow) + colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)] + assert len(colors) == self.shape_len, "colors size mismatch" + return colors def printbufs(self, prefix=""): for i in range(len(self.sts)): - print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i] is not None else 'FAKE':47s} {str(self.registers[i]):38s}", self.sts[i].views) - print(self.colorshape()) + print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i] is not None else 'FAKE':47s}", self.sts[i].views) + print(' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))) # ******************** base simplifiers ******************** + # apply reshape and permute to all shapetrackers + def reshape_and_permute(self, new_shape_fxn, axis): + for st in self.sts: + if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape))) + if axis is not None: st.permute(tuple(axis)) + + # drops the final dimension + def upcast(self): + assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1" + self.upcasted += 1 + + # axis : the axis to pull from + # amount : the amount to take + # top : if you want to pull that amount from the top + # insert_before : place to insert the new stuff + def shift_to(self, axis, amount, top=False, insert_before=None): + if insert_before is None: insert_before = self.shape_len + move_axis = axis if top else axis+1 + if move_axis < insert_before: insert_before += 1 + self.reshape_and_permute( + lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), + [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) + + # ******************** complex simplifiers ******************** + + def simplify_ones(self): + # remove places where the shape is all ones + # TODO: this should be factored in to multi shape stride + all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)] + # keep at least 1 one + if all(all_ones): all_ones[-1] = False + self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) + def simplify_merge_adjacent(self): if self.shape_len == 0: return shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts] @@ -308,56 +337,18 @@ class Linearizer: # do the reshapes for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x)) - def simplify_ones(self): - # remove places where the shape is all ones - # TODO: this should be factored in to multi shape stride - all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)] - # keep at least 1 one - if all(all_ones): all_ones[-1] = False - self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) - - # apply reshape and permute to all shapetrackers - def reshape_and_permute(self, new_shape_fxn, axis): - for st in self.sts: - if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape))) - if axis is not None: st.permute(tuple(axis)) - - # ******************** complex simplifiers ******************** - - # axis : the axis to pull from - # amount : the amount to take - # top : if you want to pull that amount from the top - # insert_before : place to insert the new stuff - def shift_to(self, axis, amount, top=False, insert_before=None): - if insert_before is None: insert_before = self.shape_len - move_axis = axis if top else axis+1 - if move_axis < insert_before: insert_before += 1 - self.reshape_and_permute( - lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), - [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) - - # drops the final dimension - def upcast(self): - upcasted = [x.shape[-1] for x in self.sts if x.shape[-1] != 1] - assert len(upcasted) >= 1 and all_same(upcasted), f"can't upcast mismatch {upcasted}" - for st,buftoken in zip(self.sts, self.registers): - # add last axis to the buftoken (if it's not a 1) - if st.shape[-1] == upcasted[0]: buftoken.array(st.shape[-1], st.views[-1].strides[-1], len(upcasted) != len(self.sts)) - # remove the last axis (unless it's the only dimension, then make it a 1) - st.views[-1] = View(st.shape[0:-1], st.views[-1].strides[0:-1], st.views[-1].offset) if len(st.shape) > 1 else View((1,), (0,), st.views[-1].offset) - # ******************** GPU simplifiers ******************** def required_optimizations(self, early_only=False): for buf_index,buf in enumerate(self.bufs): upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes] - if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType) and not (self.registers[buf_index].can_float4() or (buf not in self.earlybufs and (1 in upcast_strides))): + if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType) and not (self.can_float4(buf_index) or (buf not in self.earlybufs and (1 in upcast_strides))): axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1] assert len(axes) == 1, f"wrong number of stride 1 axis : {axes} on buf_index {buf_index}, {self.sts[buf_index]}" assert self.sts[buf_index].shape[axes[0]]%4 == 0, f"axis:{axes[0]} in buffer {buf_index} is not a multiple of 4, {self.sts[buf_index].shape}" self.shift_to(axes[0], 4) self.upcast() - assert self.registers[buf_index].can_float4() + assert self.can_float4(buf_index) def hand_coded_optimizations(self): # if there's images in the earlybufs, we have to make an axis the 4 loading one @@ -367,16 +358,16 @@ class Linearizer: self.simplify_ones() # are we grouping? (requires local shape support) - if not self.registers[0].can_float4() and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: + if not self.can_float4(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # TODO: use 1024 if it's allowed in a smarter way for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]): - self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce) + self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce)) self.group_for_reduce.append(sz) break # are we upcasting in mid reduce? (only for images) - if self.bufs[0].dtype.name.startswith('image') and not self.registers[0].can_float4() and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: + if self.bufs[0].dtype.name.startswith('image') and not self.can_float4(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1] assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" if self.sts[0].shape[axes[0]]%4 == 0: @@ -407,7 +398,7 @@ class Linearizer: xb_choices = [] for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce # if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken - if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].strides[axis] == 0 and not any(x[1] == 0 for x in self.registers[buf_index].axis) for buf_index in range(len(self.sts))): + if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))): xb_choices.append((sum(st.strides[axis]>0 for st in self.sts), sum(st.strides[axis] for st in self.sts), axis, upcast_amount)) if len(xb_choices): xb_choices = sorted(xb_choices) @@ -419,5 +410,5 @@ class Linearizer: break # if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS - if self.first_reduce < self.shape_len and self.full_shape[-1] <= 5 and (max([x.size() for i,x in enumerate(self.registers) if self.bufs[i] in self.earlybufs]) <= 4 or not any(r for _,_,r in self.registers[self.full_buf_index].axis)): + if self.first_reduce < (self.shape_len-self.upcasted) and self.full_unupcasted_shape[-1] <= 5 and (len(self.offsets(self.full_buf_index)) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))): self.upcast() diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py index 17dc1aa321..7f43bfb2f4 100644 --- a/tinygrad/codegen/llvmir.py +++ b/tinygrad/codegen/llvmir.py @@ -106,4 +106,4 @@ class LLVMIRCodegen(Linearizer): self.process() # no optimize, this doesn't support local self.linearize() - return ASTRunner('exec', uops_to_llvm_ir(self.uops, self.bufs), op_estimate=self.info.flops, mem_estimate=self.mem_estimate) + return ASTRunner('exec', uops_to_llvm_ir(self.uops, self.bufs), op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 61934afdd2..0e21e461c4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -76,9 +76,9 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex # **************** for Compiled Buffers **************** class ASTRunner: - def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0): + def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None): if DEBUG >= 4: print(prg) - self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate = name, prg, global_size, local_size, op_estimate, mem_estimate + self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name def build(self, runtime): self.clprg = runtime(self.name, self.prg) @@ -93,7 +93,7 @@ class ASTRunner: if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs) if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et if DEBUG >= 2: - print(f"*** {GlobalCounters.kernel_count:4d} {self.name:20s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + + print(f"*** {GlobalCounters.kernel_count:4d} {self.display_name if self.display_name is not None else self.name:20s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):6.2f} GB/s)")) GlobalCounters.kernel_count += 1 GlobalCounters.global_ops += self.op_estimate