From 4e2d98638d2216b697e8875fcd4760f5fe8a51f9 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 17 Dec 2024 19:00:45 -0500 Subject: [PATCH] redundant shape simplify in __unsafe_resize [pr] (#8301) also done in View.create. --- tinygrad/shape/view.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 842a861b33..fe3926e360 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -1,5 +1,5 @@ from __future__ import annotations -import functools, operator, itertools, math +import functools, operator, itertools from dataclasses import dataclass from typing import Tuple, List, Optional, Dict, Set, cast, Sequence from tinygrad.dtype import dtypes @@ -215,13 +215,12 @@ class View: if not all_int([s1, newe[d1]]): bad = True continue - newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1)) + newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1)) newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1) # If any of vm1 was masked off, try again with that mask in place. - for b, e, s in zip(newb, newe, vm1.shape): - if (b, e) != (0, s): - return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))) + if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)): + return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))) # Otherwise if vm2's mask was violated, then cannot merge. if bad: return None @@ -246,8 +245,7 @@ class View: nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)]) # merge the masks if we have two mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask - shape = [y-x for x,y in arg] - return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask) + return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask) @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View: @@ -273,7 +271,6 @@ class View: # NOTE: does not check multiple of symbolic shape assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}" if 0 in self.shape: return View.create(new_shape) - # NOTE: can the mask ever be (0,0)? # TODO: this resolve may not be needed, but it's hard because vars need to be sorted mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \ for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None