From 8b7ecd63bb0ff7e8a42a156a55113ab5d5f1f21b Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 17 Apr 2023 08:21:46 -0700
Subject: [PATCH] Remove Zeroview (#748)

* no zeroview start
* closer
* stride mask
* st tests pass, delete ZeroView
* byebye zv
* close to working
* not contiguous with mask
* subtract, don't add
* mask on view
* ugh, that shouldn't have been in there
* shape merge
* bugfixes
* fuzzer + 4 fuzzer failures
* fuzzer for symbolic
* more fuzzing and nothing
* that fuzzer doesn't hit either
* fixes padding...ugh
* no more offsets
* working
* rewrite load and store
* all checks
* fix idxs
* progress
* bugfix
* float4_axis
* works
* cleanups
* complex valids_okay
---
 test/external/fuzz_shapetracker.py |  60 ++++++++++
 test/external/fuzz_symbolic.py     |  50 +++++++++
 test/unit/test_shapetracker.py     | 115 +++++++++++++------
 test/unit/test_symbolic.py         |   3 +
 tinygrad/codegen/cstyle.py         |   2 +
 tinygrad/codegen/linearizer.py     | 158 ++++++++++++++------------
 tinygrad/lazy.py                   |   2 +-
 tinygrad/shape/shapetracker.py     | 172 +++++++++++++++++------------
 tinygrad/shape/symbolic.py         |  10 +-
 9 files changed, 391 insertions(+), 181 deletions(-)
 create mode 100644 test/external/fuzz_shapetracker.py
 create mode 100644 test/external/fuzz_symbolic.py

diff --git a/test/external/fuzz_shapetracker.py b/test/external/fuzz_shapetracker.py
new file mode 100644
index 0000000000..1e82dd5bca
--- /dev/null
+++ b/test/external/fuzz_shapetracker.py
@@ -0,0 +1,60 @@
+import random
+from test.unit.test_shapetracker import CheckingShapeTracker
+
+def do_permute(st):
+  perm = list(range(0, len(st.shape)))
+  random.shuffle(perm)
+  perm = tuple(perm)
+  print("st.permute(", perm, ")")
+  st.permute(perm)
+
+def do_pad(st):
+  c = random.randint(0, len(st.shape)-1)
+  pad = tuple((random.randint(0,2), random.randint(0,2)) if i==c else (0,0) for i in range(len(st.shape)))
+  print("st.pad(", pad, ")")
+  st.pad(pad)
+
+def do_reshape_split_one(st):
+  c = random.randint(0, len(st.shape)-1)
+  poss = [n for n in [1,2,3,4,5] if st.shape[c]%n == 0]
+  spl = random.choice(poss)
+  shp = st.shape[0:c] + (st.shape[c]//spl, spl) + st.shape[c+1:]
+  print("st.reshape(", shp, ")")
+  st.reshape(shp)
+
+def do_reshape_combine_two(st):
+  if len(st.shape) < 2: return
+  c = random.randint(0, len(st.shape)-2)
+  shp = st.shape[:c] + (st.shape[c] * st.shape[c+1], ) + st.shape[c+2:]
+  print("st.reshape(", shp, ")")
+  st.reshape(shp)
+
+def do_shrink(st):
+  c = random.randint(0, len(st.shape)-1)
+  while 1:
+    shrink = tuple((random.randint(0,s), random.randint(0,s)) if i == c else (0,s) for i,s in enumerate(st.shape))
+    if all(x<y for x,y in shrink): break

[...patch truncated here: the rest of fuzz_shapetracker.py and the fuzz_symbolic.py, test_shapetracker.py, test_symbolic.py, and cstyle.py hunks are missing from this copy...]

diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
     if DEBUG >= 5: self.printbufs("early")
 
-  # NOTE: this stride is only on the last view, and may not be real
+  def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
+  def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x] == 4]
+
+  # TODO: this stride is only on the last view, and may not be real
   def upcasted_axis(self, i): return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
-                                              self.sts[i].strides[self.shape_len-self.upcasted:],
+                                              self.sts[i].views[-1].strides[self.shape_len-self.upcasted:],  # WRONG
                                               [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
-  def offsets(self, i): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.upcasted_axis(i)[::-1]])] if self.upcasted > 0 else [0]
-  def can_float4(self, i): return any(a[0:2] == (4,1) for a in self.upcasted_axis(i))
+  # TODO: is there a better way to write this?
   def acc_offsets(self, i):
     if self.upcasted == 0: return [0]
     acc_strides = [x*(1-self.upcasted_axis(i)[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.upcasted_axis(i)[::-1])))]
     return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.upcasted_axis(i)[::-1])])]
 
-  def can_merge_float4(self, i:int, idxs:List[Variable], offset:int) -> bool:
-    if offset%4 != 0: return False
-    float4_index = Variable("FLOAT4_INDEX", 0, 3)
-    idxy_test, valid_test = self.sts[i].expr_idxs(float4_index+offset, idxs)
-    # float4_index must not be in after divide or in valid. NOTE: this forces it to always be aligned too, maybe not required?
-    ret = check_no_mul(idxy_test, float4_index) and "FLOAT4_INDEX" not in (idxy_test//4).render() and "FLOAT4_INDEX" not in (valid_test//4).render()
-    if DEBUG >= 5: print(f"fuse buf {i} {ret} :", check_no_mul(idxy_test, float4_index), idxy_test, idxy_test//4, valid_test//4)
-    return ret
+  def _group_float4(self, i, store_offset):
+    store_offset_float4 = {}
+    float4_axis = (self.upcasted-1) - self.float4_axis(i)[0]
+    for uidxs, var in store_offset.items():
+      if uidxs[float4_axis] == 0:
+        store_offset_float4[uidxs] = [var]
+      else:
+        uidxs2 = list(uidxs)
+        uidxs2[float4_axis] = 0
+        store_offset_float4[tuple(uidxs2)].append(var)
+    return store_offset_float4
 
   def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
-    should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
-    cache: Dict[int, Token] = {}
-    def op(offset):
-      if offset in cache: return cache[offset]
-      will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
-      if const is not None:
-        reg = self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], const)
+    load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(LocalTypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
+
+    # float4 grouping (optional)
+    should_upcast = self.supports_float4 and self.bufs[i].dtype != dtypes.float16 and len(self.float4_axis(i)) == 1
+    if should_upcast:
+      load_offset_new = {}
+      for k,out_tokens in self._group_float4(i, load_offset).items():
+        idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
+        valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
+        if any(idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
+          # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
+          for x in out_tokens: load_offset_new[x[1]] = x
+        else:
+          load_offset_new[k] = (LocalTypes.float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
+      load_offset = load_offset_new
+
+    # do loads
+    cache: Dict[str, Token] = {}
+    loaded = {}
+    for uidxs, (localtype, uidx_list, idx, valid) in load_offset.items():
+      key = f"{localtype}{idx.render()}{valid.render()}"
+      if key not in cache:
+        cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
+      if localtype == LocalTypes.float4:
+        for j,uidx in enumerate(uidx_list):
+          loaded[uidx] = Token(cache[key].name, LocalTypes.float4, j)
       else:
-        assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
-        reg = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{mnum(offset)}", LocalTypes.float4 if will_merge else LocalTypes.float), [], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
-      if will_merge:
-        for j in range(0, 4): cache[offset+j] = Token(reg.name, LocalTypes.float4, j)
-      else:
-        cache[offset] = reg
-      return cache[offset]
-    return [op(o) for o in self.offsets(i)]
+        loaded[uidxs] = cache[key]
+    return [loaded[uidxs] for uidxs in self.shape_offsets(i)]
 
   def global_store(self, i, idxs:List[Variable], store=List[Token]) -> None:
-    should_upcast = self.supports_float4 and self.can_float4(i) and self.bufs[i].dtype != dtypes.float16
-    store_offset: Dict[int, int] = {y:x for x,y in enumerate(self.offsets(i))}  # NOTE: for stores, these should be unique
-    def op(offset):
-      if offset not in store_offset: return
-      will_merge = should_upcast and self.can_merge_float4(i, idxs, offset)
-      assert will_merge or not isinstance(self.bufs[i].dtype, ImageDType), "image must merge float4"
-      if will_merge:
-        out_tokens = [store[store_offset[offset+j]] for j in range(4)]
+    store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
+
+    # float4 grouping (optional)
+    should_upcast = self.supports_float4 and self.bufs[i].dtype != dtypes.float16 and len(self.float4_axis(i)) == 1
+    if should_upcast:
+      store_offset_new = {}
+      for k,out_tokens in self._group_float4(i, store_offset).items():
         if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
-          var = Token(store[store_offset[offset]].name, LocalTypes.float4)
+          store_offset_new[k] = Token(out_tokens[0].name, LocalTypes.float4)
         else:
-          var = self.uop(UOps.CAST, Token(store[store_offset[offset]].name+"_f4", LocalTypes.float4), out_tokens)
-      else:
-        var = store[store_offset[offset]]
-      for j in range(0, 4 if will_merge else 1): del store_offset[offset+j]
-      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(offset, idxs)))
-    for o in self.offsets(i): op(o)
+          store_offset_new[k] = self.uop(UOps.CAST, Token(out_tokens[0].name+"_f4", LocalTypes.float4), out_tokens)
+      store_offset = store_offset_new
+
+    # do stores
+    for uidxs, var in store_offset.items():
+      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]])))
 
   def linearize(self):
     # uops
@@ -204,12 +215,17 @@ class Linearizer:
     gl_idxs = global_idxs
 
     # reduce op
+    fake_reduce_idxs = []
+    removed = len(global_idxs)
    if self.reduceop is not None:
+      # define indexes
+      reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
+      fake_reduce_idxs = [x*0 for x in reduce_idxs]
+
       # define accumulator
-      acc = self.global_load(0, gl_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
+      acc = self.global_load(0, gl_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
       # reduce loop
-      reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
       self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))
 
       # load earlybufs
@@ -223,7 +239,7 @@ class Linearizer:
 
       # end the local loop, do the local reduce
       if self.group_for_reduce:
-        self.global_store(-1, local_idxs, acc)  # store accumulators
+        self.global_store(-1, local_idxs+fake_reduce_idxs, acc)  # store accumulators
         self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))  # this is a barrier on GPUs
 
         # if any group_for_reduce items aren't reduces, upcast them here
@@ -231,18 +247,19 @@ class Linearizer:
           self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
           self.upcast()
           self.group_for_reduce.pop()
+          removed -= 1
 
         # NOTE: this structure is the same as the reduce op above
 
         # define late accumulator
-        acc = self.global_load(-1, local_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
+        acc = self.global_load(-1, local_idxs[:removed]+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])
 
         # late reduce loop
         end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
         self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))
 
         # load localbufs
-        loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs)
+        loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)
 
         # there's no AST here (and there's no shape for the reduce LazyOp)
         self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)
@@ -251,13 +268,13 @@ class Linearizer:
         self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))
 
     # load latebufs
-    loaded_buffers.update({b:self.global_load(i, global_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
+    loaded_buffers.update({b:self.global_load(i, global_idxs[:removed]+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})
 
     # run late AST
     val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)
 
     # store
-    self.global_store(0, global_idxs, val)
+    self.global_store(0, global_idxs[:removed]+fake_reduce_idxs, val)
 
     # end the global loop
     self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
@@ -376,14 +393,12 @@ class Linearizer:
 
   def required_optimizations(self, early_only=False):
     for buf_index,buf in enumerate(self.bufs):
-      upcast_strides = [self.sts[buf_index].strides[i] for i in self.upcast_in_mid_reduce_axes]
-      if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType) and not (self.can_float4(buf_index) or (buf not in self.earlybufs and (1 in upcast_strides))):
-        axes = [i for i,x in enumerate(self.sts[buf_index].strides) if x == 1]
-        assert len(axes) == 1, f"wrong number of stride 1 axis : {axes} on buf_index {buf_index}, {self.sts[buf_index]}"
-        assert self.sts[buf_index].shape[axes[0]]%4 == 0, f"axis:{axes[0]} in buffer {buf_index} is not a multiple of 4, {self.sts[buf_index].shape}"
-        self.shift_to(axes[0], 4)
-        self.upcast()
-        assert self.can_float4(buf_index)
+      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
+      if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType):
+        assert len(unit_stride_axes_mul_4) >= 1, "needs a unit stride axis"
+        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
+          self.shift_to(unit_stride_axes_mul_4[0], 4)
+          self.upcast()
 
   def hand_coded_optimizations(self):
     # if there's images in the earlybufs, we have to make an axis the 4 loading one
@@ -393,7 +408,7 @@ class Linearizer:
     self.simplify_ones()
 
     # are we grouping? (requires local shape support)
-    if not self.can_float4(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
+    if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
       # TODO: use 1024 if it's allowed in a smarter way
       for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
         if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
@@ -402,8 +417,8 @@ class Linearizer:
           break
 
     # are we upcasting in mid reduce? (only for images)
-    if self.bufs[0].dtype.name.startswith('image') and not self.can_float4(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
-      axes = [i for i,x in enumerate(self.sts[0].strides) if x == 1]
+    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
+      axes = self.sts[0].unit_stride_axes()
       assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
       if self.sts[0].shape[axes[0]]%4 == 0:
         self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce))  # insert at the end of the grouped axis
@@ -433,8 +448,9 @@ class Linearizer:
     xb_choices = []
     for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]):  # consider all the non reduce axes, and a 3 or 4 reduce
       # if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken
-      if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
-        xb_choices.append((sum(st.strides[axis]>0 for st in self.sts), sum(st.strides[axis] for st in self.sts), axis, upcast_amount))
+      # NOTE: this is using views[-1]
+      if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
+        xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
     if len(xb_choices):
       xb_choices = sorted(xb_choices)
       if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
@@ -445,5 +461,5 @@ class Linearizer:
       break
 
     # if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-    if self.first_reduce < (self.shape_len-self.upcasted) and self.full_unupcasted_shape[-1] <= 5 and (len(self.offsets(self.full_buf_index)) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
+    if self.first_reduce < (self.shape_len-self.upcasted) and self.full_unupcasted_shape[-1] <= 5 and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
      self.upcast()
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 34fdb23181..0efd005ff1 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -131,7 +131,7 @@ class LazyBuffer:
       for x in get_buffers(self.op): x.realize()
 
       # HACK: image shape can be wrong, hot cast it back to a normal float
-      if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or self.shape[self.st.strides.index(1)]%4 != 0):
+      if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
         if self.op.op == MovementOps.RESHAPE:
           # put CAST before the final RESHAPE
           self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 446a371e3b..c5505b8fa3 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -2,13 +2,19 @@ from __future__ import annotations
 import functools
 from enum import Enum, auto
-from typing import Tuple, Union, List, Optional, cast, Dict, Callable
+from typing import Tuple, Union, List, Optional, Dict, Callable
 from tinygrad.helpers import prod, DEBUG
-from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
+from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, ModNode
 
 # these ops live here
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()  # noqa: E702
 
+def check_no_mul(test, var):
+  if test == var: return True
+  if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes)  # in a sum is okay
+  if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var)  # removing a mod is okay
+  return False
+
 @functools.lru_cache(maxsize=None)
 def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
   assert len(shape) == len(strides)
@@ -24,16 +30,37 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
 def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
 
 class View:
-  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
+  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, mask:Optional[Tuple[Tuple[int, int], ...]]=None):
     self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
+    self.mask = mask
     self.shape_strides = to_shape_strides(self.shape, self.strides)
-    self.contiguous: bool = self.offset == 0 and is_contiguous(self.shape, self.strides)
+    self.contiguous: bool = self.offset == 0 and is_contiguous(self.shape, self.strides) and mask is None
 
-  def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"
+  def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset}, {self.mask})"
 
-  def expr_node(self, idx=None, offset:Union[Node, int]=0):
+  def expr_node_mask(self, idx, valid=None) -> Node:
+    expr = [valid] if valid is not None else []
+    if self.mask is not None:
+      acc = 1
+      for ns,(x,y) in list(zip(self.shape, self.mask))[::-1]:
+        base = ((idx//acc) % ns)
+        expr += [base >= x, base < y]
+        acc *= ns
+    return Variable.ands(expr)
+
+  def idxs_to_idx(self, idxs):
+    assert len(idxs) == len(self.shape), "need an idx for all dimensions"
+    acc = 1
+    ret = []
+    for tidx,d in list(zip(idxs, self.shape))[::-1]:
+      ret.append(tidx * acc)
+      acc *= d
+    return Variable.sum(ret)
+
+  # generate an expression if you have a single idx variable
+  def expr_node(self, idx=None) -> Node:
     if idx is None: idx = Variable('idx', 0, prod(self.shape))
-    ret = [Variable.num(self.offset)+offset]
+    ret = [Variable.num(self.offset)]
     acc = 1
     for d,s in self.shape_strides[::-1]:
       ret.append(((idx//acc)%d)*s)
@@ -41,30 +68,9 @@ class View:
     return Variable.sum(ret)
 
   # generate an expression if you have a variable or expression for each index
-  def expr_idxs(self, idxs, offset:Union[Node, int]=0):
-    return Variable.sum([Variable.num(self.offset)+offset] + [(idx if isinstance(idx, Variable) else Variable(idx, 0, sh-1))*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
-
-class ZeroView:
-  def __init__(self, old_shape:Tuple[int, ...], arg):
-    self.old_shape, self.arg = old_shape, arg
-    self.shape: Tuple[int, ...] = tuple([y-x for x,y in self.arg])
-    # fake properties
-    self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0
-
-  def __repr__(self): return f"ZeroView({self.old_shape}, {self.arg})"
-
-  def expr_node(self, idx=None, valid=None):
-    if idx is None: idx = Variable('idx', 0, prod([y-x for x,y in self.arg]))
-    expr, acc = [valid] if valid is not None else [], 1
-    for s,ns,(x,y) in list(zip(self.old_shape, self.shape, self.arg))[::-1]:
-      base = ((idx//acc) % ns) + x
-      expr += ([base >= 0] if x < 0 else []) + ([base < s] if y > s else [])
-      acc *= ns
-    return Variable.ands(expr)
-
-  def expr_idxs(self, idxs, offset=0): raise NotImplementedError("ZeroView doesn't support expr_idxs")
-
-ViewTypes = Union[View, ZeroView]
+  def expr_idxs(self, idxs):
+    assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
+    return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
 
 @functools.lru_cache(maxsize=None)
 def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -79,6 +85,7 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
 
 @functools.lru_cache(maxsize=None)
 def merge_views(vm2:View, vm1:View) -> Optional[View]:
+  if vm2.mask: return None  # this isn't supported yet
   new_strides, new_offset = [], vm2.expr_node(Variable.num(vm1.offset))
   assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
   for s,st in zip(vm1.shape, vm1.strides):
@@ -94,11 +101,11 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
     else:
       if DEBUG >= 4: print("can't simplify", s, this_dim.render())
       break
-  return View(vm1.shape, tuple(new_strides), new_offset.b) if len(new_strides) == len(vm1.strides) else None
+  return View(vm1.shape, tuple(new_strides), new_offset.b, vm1.mask) if len(new_strides) == len(vm1.strides) else None
 
 class ShapeTracker:
-  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
-    self.views: List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
+  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
+    self.views: List[View] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
   def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
 
   def copy(self) -> ShapeTracker: return ShapeTracker(self.shape, self.views[:])
@@ -108,63 +115,72 @@ class ShapeTracker:
   @property
   def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
 
-  @property
-  def strides(self) -> Tuple[int, ...]: return self.views[-1].strides
+  # this is the real size (ish)
+  def size(self): return prod([s for s,st in zip(self.shape, self.views[-1].strides) if st != 0])
 
-  @property
-  def offset(self) -> int: return self.views[-1].offset
+  def unit_stride_axes(self) -> List[int]:
+    ret, acc = [], 1
+    for j,s in list(enumerate(self.shape))[::-1]:
+      if s == 1: continue
+      var = Variable('idx', 0, s-1)
+      this_dim = self.expr_node(var*acc)
+      acc *= s
+      if check_no_mul(this_dim[0], var): ret.append(j)
+    return ret
 
-  # this is the real size
-  def size(self): return prod([s for s,st in zip(self.shape, self.strides) if st != 0])
-
-  def _expr_idx(self, idx):
-    valid = Variable.num(1)
+  def _expr_idx(self, idx, valid):
     for v in self.views[0:-1][::-1]:
-      if isinstance(v, ZeroView): valid = v.expr_node(idx, valid)
-      else: idx = v.expr_node(idx)
+      valid = v.expr_node_mask(idx, valid)
+      idx = v.expr_node(idx)
     return idx, valid
 
   def simplify(self):
-    if len(self.views) >= 2 and isinstance(self.views[-2], View) and isinstance(self.views[-1], View):
+    if len(self.views) >= 2:
       new_view = merge_views(self.views[-2], self.views[-1])
       if new_view:
         if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
         self.views = self.views[:-2] + [new_view]
         self.simplify()
 
-  # TODO: arg order is reversed here
-  def expr_idxs(self, offset=0, idxs=None):
-    if idxs is None: idxs = [f"idx{i}" for i in range(len(self.shape))]
-    return self._expr_idx(self.views[-1].expr_idxs(idxs, offset))
+  def expr_idxs(self, idxs=None):
+    if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
+    idx = self.views[-1].expr_idxs(idxs)
+    valid = self.views[-1].expr_node_mask(self.views[-1].idxs_to_idx(idxs))
+    return self._expr_idx(idx, valid)
 
-  def expr_node(self, idx='idx', offset=0):
-    return self._expr_idx(self.views[-1].expr_node(Variable(idx, 0, prod(self.shape)-1), offset))
+  def expr_node(self, idx='idx'):
+    if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1)
+    return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
 
   def needs_valid(self) -> bool:
-    return any(isinstance(v, ZeroView) for v in self.views)
+    return any(v.mask is not None for v in self.views)
 
   # *** under this line are the movement ops ***
 
-  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...]):
-    offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
-    self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)
+  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
+    offset = sum([self.views[-1].strides[i]*x for i,(x,_) in enumerate(arg)])
+    if self.views[-1].mask:
+      # move the old mask
+      nmask = tuple((max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg))
+      # merge the masks if we have two
+      mask = tuple((max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)) if mask is not None else nmask
+    self.views[-1] = View(tuple(y-x for x,y in arg), self.views[-1].strides, self.views[-1].offset+offset, mask)
 
   def pad(self, arg: Tuple[Tuple[int, int], ...]):
     assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
-    if all(b==0 and e==0 for b,e in arg): return self  # ZeroView is expensive if we don't need it
+    if all(b==0 and e==0 for b,e in arg): return self
     zvarg = tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg))
-    zeroview = ZeroView(self.shape, zvarg)
-    self.__unsafe_resize(zvarg)
-    # if we add a ZeroView, we add another (stock) view also for modding
-    self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]
+    self.__unsafe_resize(zvarg, mask=tuple((b,s+b) for s,(b,_) in zip(self.shape, arg)))
 
   def shrink(self, arg: Tuple[Tuple[int, int], ...]):
     assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
     self.__unsafe_resize(arg)
 
   def expand(self, new_shape: Tuple[int, ...]):
-    assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
-    self.views[-1] = View(new_shape, self.strides, self.offset)
+    assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
+    # NOTE: can the mask ever be (0,0)?
+    mask = tuple((((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)) if self.views[-1].mask else None
+    self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
 
   def reshape(self, new_shape: Tuple[int, ...]):
     if self.shape == new_shape: return self
@@ -172,32 +188,42 @@ class ShapeTracker:
     assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
 
     # check if this is adding or removing 1s (only)
-    # NOTE: this is optional, but removes most calls to (expensive!) merge_views
+    # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
     if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
-      old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
+      old_strides = [y for x,y in zip(self.shape, self.views[-1].strides) if x != 1]
       new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
-      self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
+      new_mask_tuple = None
+      if self.views[-1].mask:
+        if any(y!=(0,1) for x,y in zip(self.shape, self.views[-1].mask) if x == 1):
+          # mask it all out!
+          new_mask_tuple = tuple((0,0) for _ in new_shape)
+        else:
+          old_mask = [y for x,y in zip(self.shape, self.views[-1].mask) if x != 1]
+          new_mask_tuple = tuple((0,1) if x == 1 else old_mask.pop(0) for x in new_shape)
+      self.views[-1] = View(new_shape, new_strides_tuple, self.views[-1].offset, new_mask_tuple)
       return self
 
     view = View(new_shape, strides_for_shape(new_shape))
     if self.contiguous:
       self.views[-1] = view  # NOTE: if it's contiguous it can't have an offset
     else:
-      # NOTE: the last view in self.views is never a ZeroView
-      if (merged_view := merge_views(cast(View, self.views[-1]), view)) is not None: self.views[-1] = merged_view
-      else: self.views.append(view)
+      if (merged_view := merge_views(self.views[-1], view)) is not None: self.views[-1] = merged_view
+      else:
+        if DEBUG >= 4: print(f"WARNING: creating new view with reshape {self} -> {new_shape}")
+        self.views.append(view)
 
   def permute(self, axis: Tuple[int, ...]):
     assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
     assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
-    self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset)
+    self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.views[-1].strides[a] for a in axis), self.views[-1].offset, tuple(self.views[-1].mask[a] for a in axis) if self.views[-1].mask is not None else None)
 
   # except for the negative case, you can build this from the others. invertible in the negative case
   def stride(self, mul: Tuple[int, ...]):
     assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
-    strides = tuple(z*m for z,m in zip(self.strides, mul))
+    strides = tuple(z*m for z,m in zip(self.views[-1].strides, mul))
     new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
-    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    self.views[-1] = View(new_shape, strides, self.offset + offset)
+    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.views[-1].strides, mul) if m < 0])
+    mask = tuple((((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.shape, mul)) if self.views[-1].mask is not None else None
+    self.views[-1] = View(new_shape, strides, self.views[-1].offset + offset, mask)
 
   # *** entry point for external ***
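This is the core of the ZeroView removal: `pad` no longer appends a ZeroView plus a stock modding view; it resizes the last View in place and records where the real data lives in `mask`, from which `expr_node_mask` derives the valid expression. A worked example of what the new code produces, traced by hand through `pad`, `__unsafe_resize`, and `expr_node_mask` (the exact rendered strings may differ):

    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker((3,))
    st.pad(((1, 2),))    # one zero-padding element before the data, two after
    print(st.views[-1])  # View((6,), (1,), -1, ((1, 4),)): only 1 <= idx < 4 is real data
    idx, valid = st.expr_node()
    # idx is shifted by the -1 offset into the underlying buffer; valid renders
    # as something like ((idx>=1) and (idx<4)), and codegen substitutes 0 for
    # out-of-mask loads instead of touching memory

Because the mask lives on the View itself, a later shrink, permute, or stride just transforms the mask tuples (see `__unsafe_resize`, `permute`, and `stride` above) instead of stacking extra views.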
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 4a0a14e817..0509b9ec9d 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -28,6 +28,12 @@ class Node:
   def __mul__(self, b:int):
     if b == 0: return NumNode(0)
     elif b == 1: return self
+
+    # this is a hack to make div work with boolean nodes. TODO: make generic
+    if isinstance(self, GeNode): return (self.a*b) >= (self.b*b)
+    if isinstance(self, LtNode): return (self.a*b) < (self.b*b)
+    if isinstance(self, AndNode): return Variable.ands([x*b for x in self.nodes])
+
     if isinstance(self, MulNode): return self.a*(self.b*b)  # two muls is one mul
     if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes])  # distribute mul into sum
     return create_opnode(MulNode, self, b)
@@ -47,7 +53,7 @@ class Node:
     if isinstance(self, ModNode) and self.b % b == 0: return (self.a//b) % (self.b//b)  # put the div inside mod
     if isinstance(self, DivNode): return self.a//(self.b*b)  # two divs is one div
     if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
-    if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
+    if isinstance(self, MulNode) and b % self.b == 0 and self.b > 0: return self.a//(b//self.b)  # NOTE: mod negative isn't handled right
     if isinstance(self, SumNode) and factoring_allowed:
       factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
       nofactor = []
@@ -94,7 +100,7 @@ class Node:
     else:
       a = self
     if a.min >= 0 and a.max < b: return a
-    if a.min < 0: return (a + ((a.min//b)*b)) % b
+    if a.min < 0: return (a - ((a.min//b)*b)) % b
     return create_opnode(ModNode, a, b)
 
   @staticmethod
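That last hunk is the "subtract, don't add" bullet from the commit message: to reduce `a % b` when `a` can be negative, `a` must be shifted up by a multiple of `b` until its range is non-negative, and adding `(a.min//b)*b` (a negative quantity, since `//` floors) shifted it the wrong way. A plain-integer sketch of the corrected identity (not tinygrad code):

    def mod_shift(a_val, a_min, b):
      # mirror of: if a.min < 0: return (a - ((a.min//b)*b)) % b
      # Python's // floors, so for a_min = -3, b = 4: a_min//b = -1,
      # and a - (-1*4) = a + 4 makes the whole range non-negative
      return (a_val - ((a_min // b) * b)) % b

    assert all(mod_shift(v, -3, 4) == v % 4 for v in range(-3, 8))

Shifting by a multiple of `b` leaves the residue unchanged, which is presumably the kind of identity the new `test/external/fuzz_symbolic.py` exercises.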