redundant shape simplify in __unsafe_resize [pr] (#8301)

The shape simplification is also done in View.create, so repeating it in __unsafe_resize is redundant.
This commit is contained in:
chenyu
2024-12-17 19:00:45 -05:00
committed by GitHub
parent a9f46ebf70
commit 4e2d98638d

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, operator, itertools, math
import functools, operator, itertools
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Sequence
from tinygrad.dtype import dtypes
@@ -215,13 +215,12 @@ class View:
if not all_int([s1, newe[d1]]):
bad = True
continue
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
# If any of vm1 was masked off, try again with that mask in place.
for b, e, s in zip(newb, newe, vm1.shape):
if (b, e) != (0, s):
return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
# Otherwise if vm2's mask was violated, then cannot merge.
if bad: return None
@@ -246,8 +245,7 @@ class View:
nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
# merge the masks if we have two
mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
shape = [y-x for x,y in arg]
return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
@@ -273,7 +271,6 @@ class View:
# NOTE: does not check multiple of symbolic shape
assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
if 0 in self.shape: return View.create(new_shape)
# NOTE: can the mask ever be (0,0)?
# TODO: this resolve may not be needed, but it's hard because vars need to be sorted
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None