From d30005b64596295d7044fd3a6a0d9df5fb26a2fc Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 11 Mar 2023 12:07:52 -0800 Subject: [PATCH] shapetracker touchups --- tinygrad/shape/shapetracker.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index b94054e893..064c29f179 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -163,10 +163,8 @@ class ShapeTracker: self.__unsafe_resize(arg) def _expand(self, new_shape: Tuple[int, ...]): - assert all(isinstance(x, int) for x in new_shape), f"non ints for expand in {new_shape}" - assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}" - strides: Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape))) - self.views[-1] = View(new_shape, strides, self.offset) + assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}" + self.views[-1] = View(new_shape, self.strides, self.offset) def _reshape(self, new_shape: Tuple[int, ...]): if self.shape == new_shape: return self @@ -195,7 +193,7 @@ class ShapeTracker: # except for the negative case, you can build this from the others. invertible in the negative case def _stride(self, mul: Tuple[int, ...]): - assert all(isinstance(x, int) for x in mul) + assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}" strides = tuple(z*m for z,m in zip(self.strides, mul)) new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)) offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0]) @@ -209,7 +207,7 @@ class ShapeTracker: return self dispatch: Dict[MovementOps, Callable] = {MovementOps.RESHAPE: ShapeTracker._reshape, MovementOps.EXPAND: ShapeTracker._expand, MovementOps.PAD: ShapeTracker._pad, - MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride} + MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride} # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):