From b57a16aa8963708620808e97325868cf93cdec4a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 30 Jan 2024 09:25:16 -0800 Subject: [PATCH] take merge views from corsix branch (#3273) * take merge views from corsix branch * better DEBUG * max views * remove view.py change * Revert "remove view.py change" This reverts commit f3025f4f393b4b9a9a1ac89ea488d82de448b78c. * only allow filter on non symbolic * oops, correct fix * comment to explain --- test/test_winograd.py | 8 +++- tinygrad/shape/shapetracker.py | 77 +++++++++++++++++++++++++++++++--- tinygrad/shape/view.py | 6 +++ 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/test/test_winograd.py b/test/test_winograd.py index 8f2e1a631c..cb69ce1609 100644 --- a/test/test_winograd.py +++ b/test/test_winograd.py @@ -1,6 +1,6 @@ import unittest from tinygrad import Tensor, GlobalCounters -from tinygrad.helpers import Timing, CI, Profiling, WINO +from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG from tinygrad.ops import LoadOps from tinygrad.codegen.linearizer import Linearizer @@ -28,6 +28,12 @@ class TestWinograd(unittest.TestCase): l = Linearizer(s.ast) l.hand_coded_optimizations() l.linearize() + if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views") + for st in l.sts: + assert len(st.views) <= 2, "too many views in winograd" + if DEBUG >= 3: + print(f"{len(st.views):3d} views") + for v in st.views: print(v) def test_profile(self): x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize() diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index d62e55b5ac..4775f5c8b5 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -1,19 +1,86 @@ # ShapeTracker allows movement operations to a buffer that don't require a copy to be made. from __future__ import annotations -import functools +import functools, math from dataclasses import dataclass from typing import Tuple, List, Optional, Dict, Set, cast, Iterable, Union from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint -from tinygrad.shape.view import View +from tinygrad.shape.view import View, strides_for_shape + +def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]: + strides = strides_for_shape(shape) + result = [] + for stride in strides: + here = offs // stride if stride else 0 + result.append(here) + offs -= here * stride + return result @functools.lru_cache(maxsize=None) def merge_views(vm2:View, vm1:View) -> Optional[View]: if vm1.contiguous and vm1.shape == vm2.shape: return vm2 if vm2.contiguous: return vm1 - if vm2.mask or vm1.offset != 0: return None # this isn't supported yet - if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None - return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask) + if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()): + return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask) + if vm1.mask: + for b,e in vm1.mask: + if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape)) + return (merged := merge_views(vm2, vm1.shrink(vm1.mask))) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape))) + + # Project vm1's offset and strides on to vm2. + origin = un1d(vm2.shape, vm1.offset) + terms: List[List[Tuple[int, sint]]] = [[] for _ in origin] + strides: List[sint] = [0] * len(vm1.shape) + for d1, st in enumerate(vm1.strides): + if st == 0: continue + for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))): + if (s1 := s1 - o) == 0: continue + terms[d2].append((d1, s1)) + strides[d1] += s1 * vm2.strides[d2] + + # Merge dimensions in vm2 if required. + # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required. + idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)] + merged_size, merged_term = 1, NumNode(0) + extents: List[Tuple[sint, Node]] = [] + for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)): + merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size + merged_size *= s + if not (merged_term >= merged_size) and not (merged_term < 0): + extents.append((merged_size, merged_term)) + merged_size, merged_term = 1, NumNode(0) + if merged_term: return None + if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape: + return (reshaped_vm2 := vm2.reshape(vm2_shape)) and merge_views(reshaped_vm2, vm1) + + if vm2.mask: + # Try to project vm2's mask on to vm1. + newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False + for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))): + if not (t.min < b or t.max >= e): continue + if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int): + bad = True + continue + term = terms[d2] + if len(term) != 1: + if not term and newe: newe[0] = 0 + else: bad = True + continue + d1, s1 = term[0] + if not isinstance(s1, int) or not isinstance(newe[d1], int): + bad = True + continue + newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1)) + newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1) + + # If any of vm1 was masked off, try again with that mask in place. + for b, e, s in zip(newb, newe, vm1.shape): + if b != 0 or e != s: + return merge_views(vm2, View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))) + # Otherwise if vm2's mask was violated, then cannot merge. + if bad: return None + + return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset) def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]: assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}" diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index e5f7a4eebe..df5d716e77 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -83,6 +83,12 @@ class View: def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None): strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape) contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape) + # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked + # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset + if mask and any(elim := [isinstance(b, int) and isinstance(e, int) and b+1 >= e for b,e in mask]): + if any(b >= e for b,e in mask): strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape) + offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim)) + strides = tuple(0 if e else st for st,e in zip(strides, elim)) return View(shape, strides, offset, mask, contiguous) @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none