diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index f9c8e67417..640715186f 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -1,87 +1,10 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
-import functools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
 from tinygrad.helpers import merge_dicts, getenv
 from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
-from tinygrad.shape.view import View, strides_for_shape
-
-def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
-  strides = strides_for_shape(shape)
-  result = []
-  for stride in strides:
-    here = offs // stride if stride else 0
-    result.append(here)
-    offs -= here * stride
-  return result
-
-@functools.lru_cache(maxsize=None)
-def merge_views(vm2:View, vm1:View) -> Optional[View]:
-  if vm2.contiguous: return vm1
-  if vm1.contiguous and vm1.shape == vm2.shape: return vm2
-  if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
-  if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()):
-    return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask)
-  if vm1.mask:
-    for b,e in vm1.mask:
-      if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
-    return (merged := merge_views(vm2, vm1.shrink(vm1.mask))) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
-
-  # Project vm1's offset and strides on to vm2.
-  origin = un1d(vm2.shape, vm1.offset)
-  terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
-  strides: List[sint] = [0] * len(vm1.shape)
-  for d1, st in enumerate(vm1.strides):
-    if st == 0: continue
-    for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
-      if (s1 := s1 - o) == 0: continue
-      terms[d2].append((d1, s1))
-      strides[d1] += s1 * vm2.strides[d2]
-
-  # Merge dimensions in vm2 if required.
-  # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-  idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
-  merged_size, merged_term = 1, NumNode(0)
-  extents: List[Tuple[sint, Node]] = []
-  for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
-    merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
-    merged_size *= s
-    if not (merged_term >= merged_size) and not (merged_term < 0):
-      extents.append((merged_size, merged_term))
-      merged_size, merged_term = 1, NumNode(0)
-  if merged_term: return None
-  if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
-    return (reshaped_vm2 := vm2.reshape(vm2_shape)) and merge_views(reshaped_vm2, vm1)
-
-  if vm2.mask:
-    # Try to project vm2's mask on to vm1.
-    newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
-    for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
-      if not (t.min < b or t.max >= e): continue
-      if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
-        bad = True
-        continue
-      term = terms[d2]
-      if len(term) != 1:
-        if not term and newe: newe[0] = 0
-        else: bad = True
-        continue
-      d1, s1 = term[0]
-      if not isinstance(s1, int) or not isinstance(newe[d1], int):
-        bad = True
-        continue
-      newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
-      newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
-
-    # If any of vm1 was masked off, try again with that mask in place.
-    for b, e, s in zip(newb, newe, vm1.shape):
-      if b != 0 or e != s:
-        return merge_views(vm2, View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))))
-    # Otherwise if vm2's mask was violated, then cannot merge.
-    if bad: return None
-
-  return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
+from tinygrad.shape.view import View
 
 def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
   assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
@@ -170,7 +93,7 @@ class ShapeTracker:
     return f'idx{axis}' in [v.expr for v in valid.vars()]
 
   def simplify(self) -> ShapeTracker:
-    if len(self.views) >= 2 and (new_view := merge_views(self.views[-2], self.views[-1])) is not None:
+    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
       return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
     return self
 
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 17d85fb843..60c4e641a0 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import functools, operator, itertools
+import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.helpers import prod, all_int, argsort
@@ -67,6 +67,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
 
   return tuple(reversed(new_mask)), False
 
+def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
+  strides = strides_for_shape(shape)
+  result = []
+  for stride in strides:
+    here = offs // stride if stride else 0
+    result.append(here)
+    offs -= here * stride
+  return result
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
@@ -115,6 +124,72 @@ class View:
                       b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
 
+  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  def __add__(self, vm1:View) -> Optional[View]:
+    vm2 = self
+    if vm2.contiguous: return vm1
+    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
+    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
+    if vm1.mask:
+      for b,e in vm1.mask:
+        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
+      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+
+    # Project vm1's offset and strides on to vm2.
+    origin = un1d(vm2.shape, vm1.offset)
+    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
+    strides: List[sint] = [0] * len(vm1.shape)
+    for d1, st in enumerate(vm1.strides):
+      if st == 0: continue
+      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+        if (s1 := s1 - o) == 0: continue
+        terms[d2].append((d1, s1))
+        strides[d1] += s1 * vm2.strides[d2]
+
+    # Merge dimensions in vm2 if required.
+    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
+    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, NumNode(0)
+    extents: List[Tuple[sint, Node]] = []
+    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
+      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
+      merged_size *= s
+      if not (merged_term >= merged_size) and not (merged_term < 0):
+        extents.append((merged_size, merged_term))
+        merged_size, merged_term = 1, NumNode(0)
+    if merged_term: return None
+    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
+      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
+
+    if vm2.mask:
+      # Try to project vm2's mask on to vm1.
+      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
+      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
+        if not (t.min < b or t.max >= e): continue
+        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
+          bad = True
+          continue
+        term = terms[d2]
+        if len(term) != 1:
+          if not term and newe: newe[0] = 0
+          else: bad = True
+          continue
+        d1, s1 = term[0]
+        if not isinstance(s1, int) or not isinstance(newe[d1], int):
+          bad = True
+          continue
+        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+
+      # If any of vm1 was masked off, try again with that mask in place.
+      for b, e, s in zip(newb, newe, vm1.shape):
+        if b != 0 or e != s:
+          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      # Otherwise if vm2's mask was violated, then cannot merge.
+      if bad: return None
+
+    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
+
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
     ret = View.create(self.shape)
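
Quick illustration of the new composition path (editor's sketch, not part of the patch): the view shapes, strides, and the names base/flat/merged/st below are arbitrary, and it assumes this revision of tinygrad is importable. Composing two views with + now plays the role of merge_views, and ShapeTracker.simplify relies on it to collapse stacked views.

    from tinygrad.shape.shapetracker import ShapeTracker
    from tinygrad.shape.view import View

    base = View.create((3, 4), (4, 1), 4)   # rows 1..3 of a contiguous (4, 4) buffer, i.e. offset 4, not contiguous
    flat = View.create((12,))               # a contiguous flatten stacked on top of it
    merged = base + flat                    # previously merge_views(base, flat)
    # merged should be a single View with shape (12,), strides (1,), offset 4

    st = ShapeTracker((base, flat)).simplify()
    assert len(st.views) == 1 and st.views[0] == merged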