View.__add__ to merge_view (#3686)

Verified that the merge cases which relied on real_strides are redundant; the new View.__add__ drops them.
Author: chenyu
Date: 2024-03-11 15:52:34 -04:00
Committed by: GitHub
Parent: 76ade20b89
Commit: b68fbd7d81
2 changed files with 78 additions and 80 deletions
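
For orientation, what shapetracker.py used to call merge_views(vm2, vm1) is now spelled vm2 + vm1 on View itself. A minimal sketch of the new call pattern, assuming only the public View.create constructor (the shapes below are illustrative, not taken from this commit):

from tinygrad.shape.view import View

base = View.create((4, 4))                  # contiguous base view
flat = View.create((16,))                   # a flat view stacked on top of it
print(base + flat)                          # vm2 is contiguous, so the top view comes back unchanged

transposed = View.create((2, 3), (1, 2))    # column-major strides, i.e. a transposed buffer
print(transposed + View.create((6,)))       # None: a transpose followed by a flatten has no single-view form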

View File: tinygrad/shape/shapetracker.py

@@ -1,87 +1,10 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
from tinygrad.shape.view import View, strides_for_shape
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  strides = strides_for_shape(shape)
  result = []
  for stride in strides:
    here = offs // stride if stride else 0
    result.append(here)
    offs -= here * stride
  return result

@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
  if vm2.contiguous: return vm1
  if vm1.contiguous and vm1.shape == vm2.shape: return vm2
  if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
  if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()):
    return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask)
  if vm1.mask:
    for b,e in vm1.mask:
      if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
    return (merged := merge_views(vm2, vm1.shrink(vm1.mask))) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

  # Project vm1's offset and strides on to vm2.
  origin = un1d(vm2.shape, vm1.offset)
  terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
  strides: List[sint] = [0] * len(vm1.shape)
  for d1, st in enumerate(vm1.strides):
    if st == 0: continue
    for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
      if (s1 := s1 - o) == 0: continue
      terms[d2].append((d1, s1))
      strides[d1] += s1 * vm2.strides[d2]

  # Merge dimensions in vm2 if required.
  # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
  idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
  merged_size, merged_term = 1, NumNode(0)
  extents: List[Tuple[sint, Node]] = []
  for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
    merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
    merged_size *= s
    if not (merged_term >= merged_size) and not (merged_term < 0):
      extents.append((merged_size, merged_term))
      merged_size, merged_term = 1, NumNode(0)
  if merged_term: return None
  if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
    return (reshaped_vm2 := vm2.reshape(vm2_shape)) and merge_views(reshaped_vm2, vm1)

  if vm2.mask:
    # Try to project vm2's mask on to vm1.
    newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
    for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
      if not (t.min < b or t.max >= e): continue
      if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
        bad = True
        continue
      term = terms[d2]
      if len(term) != 1:
        if not term and newe: newe[0] = 0
        else: bad = True
        continue
      d1, s1 = term[0]
      if not isinstance(s1, int) or not isinstance(newe[d1], int):
        bad = True
        continue
      newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
      newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

    # If any of vm1 was masked off, try again with that mask in place.
    for b, e, s in zip(newb, newe, vm1.shape):
      if b != 0 or e != s:
        return merge_views(vm2, View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))))
    # Otherwise if vm2's mask was violated, then cannot merge.
    if bad: return None

  return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

from tinygrad.shape.view import View
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
  assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
@@ -170,7 +93,7 @@ class ShapeTracker:
    return f'idx{axis}' in [v.expr for v in valid.vars()]

  def simplify(self) -> ShapeTracker:
    if len(self.views) >= 2 and (new_view := merge_views(self.views[-2], self.views[-1])) is not None:
    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
      return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
    return self
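
simplify is the only caller that changes in this file: the explicit merge_views call becomes the + operator on the last two views. A small sketch of the behaviour, assuming two hand-built views whose bottom view is contiguous (the shapes are again illustrative):

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

st = ShapeTracker((View.create((4, 4)), View.create((16,))))
print(len(st.views))             # 2 stacked views
print(len(st.simplify().views))  # 1: views[-2] + views[-1] merged, so the pair collapses into a single view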

View File: tinygrad/shape/view.py

@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, operator, itertools
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.helpers import prod, all_int, argsort
@@ -67,6 +67,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
  return tuple(reversed(new_mask)), False

def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  strides = strides_for_shape(shape)
  result = []
  for stride in strides:
    here = offs // stride if stride else 0
    result.append(here)
    offs -= here * stride
  return result

@dataclass(frozen=True)
class View:
  shape:Tuple[sint, ...]
@@ -115,6 +124,72 @@ class View:
                     b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def __add__(self, vm1:View) -> Optional[View]:
    vm2 = self
    if vm2.contiguous: return vm1
    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
    if vm1.mask:
      for b,e in vm1.mask:
        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

    # Project vm1's offset and strides on to vm2.
    origin = un1d(vm2.shape, vm1.offset)
    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
    strides: List[sint] = [0] * len(vm1.shape)
    for d1, st in enumerate(vm1.strides):
      if st == 0: continue
      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
        if (s1 := s1 - o) == 0: continue
        terms[d2].append((d1, s1))
        strides[d1] += s1 * vm2.strides[d2]

    # Merge dimensions in vm2 if required.
    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
    merged_size, merged_term = 1, NumNode(0)
    extents: List[Tuple[sint, Node]] = []
    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
      merged_size *= s
      if not (merged_term >= merged_size) and not (merged_term < 0):
        extents.append((merged_size, merged_term))
        merged_size, merged_term = 1, NumNode(0)
    if merged_term: return None
    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1

    if vm2.mask:
      # Try to project vm2's mask on to vm1.
      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
        if not (t.min < b or t.max >= e): continue
        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
          bad = True
          continue
        term = terms[d2]
        if len(term) != 1:
          if not term and newe: newe[0] = 0
          else: bad = True
          continue
        d1, s1 = term[0]
        if not isinstance(s1, int) or not isinstance(newe[d1], int):
          bad = True
          continue
        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

      # If any of vm1 was masked off, try again with that mask in place.
      for b, e, s in zip(newb, newe, vm1.shape):
        if b != 0 or e != s:
          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
      # Otherwise if vm2's mask was violated, then cannot merge.
      if bad: return None

    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
    ret = View.create(self.shape)
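
un1d, which moves here next to __add__, turns a flat offset into per-dimension indices for a given shape; the projection code above uses it to express vm1's offset and strides in vm2's coordinates. A self-contained sketch of the same computation with a worked value (the helper name and shape are illustrative, and plain row-major strides stand in for strides_for_shape):

def un1d_sketch(shape, offs):
  # row-major strides, e.g. (3, 4) -> (4, 1)
  strides = [1] * len(shape)
  for i in range(len(shape) - 2, -1, -1): strides[i] = strides[i + 1] * shape[i + 1]
  result = []
  for stride in strides:
    here = offs // stride if stride else 0  # index along this dimension
    result.append(here)
    offs -= here * stride                   # remainder left for the remaining dimensions
  return result

assert un1d_sketch((3, 4), 7) == [1, 3]     # offset 7 in a (3, 4) buffer is row 1, column 3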