mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
redundant shape simplify in __unsafe_resize [pr] (#8301)
also done in View.create.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import functools, operator, itertools, math
|
||||
import functools, operator, itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast, Sequence
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -215,13 +215,12 @@ class View:
|
||||
if not all_int([s1, newe[d1]]):
|
||||
bad = True
|
||||
continue
|
||||
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
|
||||
newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
|
||||
newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
|
||||
|
||||
# If any of vm1 was masked off, try again with that mask in place.
|
||||
for b, e, s in zip(newb, newe, vm1.shape):
|
||||
if (b, e) != (0, s):
|
||||
return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
|
||||
if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
|
||||
return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
|
||||
# Otherwise if vm2's mask was violated, then cannot merge.
|
||||
if bad: return None
|
||||
|
||||
@@ -246,8 +245,7 @@ class View:
|
||||
nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
|
||||
# merge the masks if we have two
|
||||
mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
shape = [y-x for x,y in arg]
|
||||
return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
|
||||
return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
|
||||
@@ -273,7 +271,6 @@ class View:
|
||||
# NOTE: does not check multiple of symbolic shape
|
||||
assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
|
||||
if 0 in self.shape: return View.create(new_shape)
|
||||
# NOTE: can the mask ever be (0,0)?
|
||||
# TODO: this resolve may not be needed, but it's hard because vars need to be sorted
|
||||
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
|
||||
for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
|
||||
|
||||
Reference in New Issue
Block a user