Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-27 15:58:10 -05:00.
Commit: View.__add__ to merge_view (#3686) — verified that the cases that used real_strides are redundant.
This commit is contained in:
@@ -1,87 +1,10 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
|
||||
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  """Unflatten the linear offset `offs` into one index per dimension of `shape`."""
  out: List[sint] = []
  for st in strides_for_shape(shape):
    # a zero stride marks a broadcast/size-1 dimension; its index is always 0
    idx = offs // st if st != 0 else 0
    out.append(idx)
    offs -= idx * st
  return out
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
  """Try to collapse two stacked views (vm2 applied first, then vm1) into a single View.

  Returns the merged View, or None when the composition cannot be expressed as one view.
  Cached: Views are frozen dataclasses, so (vm2, vm1) is a valid lru_cache key.
  """
  # Trivial cases: a contiguous view on either side adds no indexing of its own.
  if vm2.contiguous: return vm1
  if vm1.contiguous and vm1.shape == vm2.shape: return vm2
  if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
  # If the stacked pair has concrete strides everywhere (no None), build the merged view directly.
  if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()):
    return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask)
  if vm1.mask:
    # An empty mask interval (b >= e) means nothing is visible: return an all-zero masked view.
    for b,e in vm1.mask:
      if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
    # Merge the unmasked core, then re-apply vm1's mask as padding on the result.
    return (merged := merge_views(vm2, vm1.shrink(vm1.mask))) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

  # Project vm1's offset and strides on to vm2.
  origin = un1d(vm2.shape, vm1.offset)
  # terms[d2] collects (vm1-dim, step) pairs: how moving one unit along vm1-dim d1 moves along vm2-dim d2.
  terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
  strides: List[sint] = [0] * len(vm1.shape)
  for d1, st in enumerate(vm1.strides):
    if st == 0: continue
    # Compare coordinates at offset and offset+st to see which vm2 dims this vm1 stride touches.
    for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
      if (s1 := s1 - o) == 0: continue
      terms[d2].append((d1, s1))
      strides[d1] += s1 * vm2.strides[d2]

  # Merge dimensions in vm2 if required.
  # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
  idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
  merged_size, merged_term = 1, NumNode(0)
  extents: List[Tuple[sint, Node]] = []
  # Walk vm2's dims innermost-first, accumulating a symbolic index expression until it is
  # provably in-bounds for the accumulated size; then close that group as one extent.
  for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
    merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
    merged_size *= s
    if not (merged_term >= merged_size) and not (merged_term < 0):
      extents.append((merged_size, merged_term))
      merged_size, merged_term = 1, NumNode(0)
  # Leftover term means an index expression could not be bounded: merge fails.
  if merged_term: return None
  if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
    # Retry against vm2 reshaped to the merged-dimension shape (reshape may itself fail -> None).
    return (reshaped_vm2 := vm2.reshape(vm2_shape)) and merge_views(reshaped_vm2, vm1)

  if vm2.mask:
    # Try to project vm2's mask on to vm1.
    newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
    for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
      # If the symbolic index t always stays inside [b, e), this mask dim is never violated.
      if not (t.min < b or t.max >= e): continue
      # Symbolic (non-int) bounds can't be projected; remember the violation instead.
      if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
        bad = True
        continue
      term = terms[d2]
      # Only a single-term dependency maps cleanly to one vm1 dimension.
      if len(term) != 1:
        if not term and newe: newe[0] = 0  # no dependency at all: the whole view is masked off
        else: bad = True
        continue
      d1, s1 = term[0]
      if not isinstance(s1, int) or not isinstance(newe[d1], int):
        bad = True
        continue
      # Tighten vm1-dim d1's mask so that o + idx*s1 stays within [b, e); sign of s1 picks the bound.
      newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
      newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

    # If any of vm1 was masked off, try again with that mask in place.
    for b, e, s in zip(newb, newe, vm1.shape):
      if b != 0 or e != s:
        return merge_views(vm2, View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))))
    # Otherwise if vm2's mask was violated, then cannot merge.
    if bad: return None

  return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
||||
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
||||
@@ -170,7 +93,7 @@ class ShapeTracker:
|
||||
return f'idx{axis}' in [v.expr for v in valid.vars()]
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2 and (new_view := merge_views(self.views[-2], self.views[-1])) is not None:
|
||||
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
||||
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
|
||||
return self
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import functools, operator, itertools
|
||||
import functools, operator, itertools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast
|
||||
from tinygrad.helpers import prod, all_int, argsort
|
||||
@@ -67,6 +67,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
|
||||
|
||||
return tuple(reversed(new_mask)), False
|
||||
|
||||
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  """Convert the flat offset `offs` into per-dimension coordinates for `shape`."""
  coords: List[sint] = []
  remaining = offs
  for stride in strides_for_shape(shape):
    # stride 0 (size-1 / broadcast dim) contributes nothing; its coordinate is 0
    coord = remaining // stride if stride else 0
    coords.append(coord)
    remaining = remaining - coord * stride
  return coords
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class View:
|
||||
shape:Tuple[sint, ...]
|
||||
@@ -115,6 +124,72 @@ class View:
|
||||
b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
|
||||
return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
|
||||
|
||||
  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
  def __add__(self, vm1:View) -> Optional[View]:
    """Compose this view (applied first) with vm1 (applied on top) into a single View.

    Returns the merged View, or None when the composition cannot be expressed as one view.
    Cached on (self, vm1): View is a frozen dataclass, so both are hashable.
    NOTE(review): lru_cache on a method keeps `self` alive in the cache for the process lifetime;
    presumably acceptable here since Views are small and interned by value — confirm.
    """
    vm2 = self
    # Trivial cases: a contiguous view on either side adds no indexing of its own.
    if vm2.contiguous: return vm1
    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
    if vm1.mask:
      # An empty mask interval (b >= e) means nothing is visible: return an all-zero masked view.
      for b,e in vm1.mask:
        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
      # Merge the unmasked core, then re-apply vm1's mask as padding on the result.
      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

    # Project vm1's offset and strides on to vm2.
    origin = un1d(vm2.shape, vm1.offset)
    # terms[d2] collects (vm1-dim, step) pairs: how moving one unit along vm1-dim d1 moves along vm2-dim d2.
    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
    strides: List[sint] = [0] * len(vm1.shape)
    for d1, st in enumerate(vm1.strides):
      if st == 0: continue
      # Compare coordinates at offset and offset+st to see which vm2 dims this vm1 stride touches.
      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
        if (s1 := s1 - o) == 0: continue
        terms[d2].append((d1, s1))
        strides[d1] += s1 * vm2.strides[d2]

    # Merge dimensions in vm2 if required.
    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
    merged_size, merged_term = 1, NumNode(0)
    extents: List[Tuple[sint, Node]] = []
    # Walk vm2's dims innermost-first, accumulating a symbolic index expression until it is
    # provably in-bounds for the accumulated size; then close that group as one extent.
    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
      merged_size *= s
      if not (merged_term >= merged_size) and not (merged_term < 0):
        extents.append((merged_size, merged_term))
        merged_size, merged_term = 1, NumNode(0)
    # Leftover term means an index expression could not be bounded: merge fails.
    if merged_term: return None
    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
      # Retry against vm2 reshaped to the merged-dimension shape (reshape may itself fail -> None).
      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1

    if vm2.mask:
      # Try to project vm2's mask on to vm1.
      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
        # If the symbolic index t always stays inside [b, e), this mask dim is never violated.
        if not (t.min < b or t.max >= e): continue
        # Symbolic (non-int) bounds can't be projected; remember the violation instead.
        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
          bad = True
          continue
        term = terms[d2]
        # Only a single-term dependency maps cleanly to one vm1 dimension.
        if len(term) != 1:
          if not term and newe: newe[0] = 0  # no dependency at all: the whole view is masked off
          else: bad = True
          continue
        d1, s1 = term[0]
        if not isinstance(s1, int) or not isinstance(newe[d1], int):
          bad = True
          continue
        # Tighten vm1-dim d1's mask so that o + idx*s1 stays within [b, e); sign of s1 picks the bound.
        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

      # If any of vm1 was masked off, try again with that mask in place.
      for b, e, s in zip(newb, newe, vm1.shape):
        if b != 0 or e != s:
          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
      # Otherwise if vm2's mask was violated, then cannot merge.
      if bad: return None

    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
|
||||
ret = View.create(self.shape)
|
||||
|
||||
Reference in New Issue
Block a user