shapetracker touchups

This commit is contained in:
George Hotz
2023-03-11 12:07:52 -08:00
parent d41ac5f5f1
commit d30005b645

View File

@@ -163,10 +163,8 @@ class ShapeTracker:
self.__unsafe_resize(arg)
def _expand(self, new_shape: Tuple[int, ...]):
assert all(isinstance(x, int) for x in new_shape), f"non ints for expand in {new_shape}"
assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
strides: Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
self.views[-1] = View(new_shape, strides, self.offset)
assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
self.views[-1] = View(new_shape, self.strides, self.offset)
def _reshape(self, new_shape: Tuple[int, ...]):
if self.shape == new_shape: return self
@@ -195,7 +193,7 @@ class ShapeTracker:
# except for the negative case, you can build this from the others. invertible in the negative case
def _stride(self, mul: Tuple[int, ...]):
assert all(isinstance(x, int) for x in mul)
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
strides = tuple(z*m for z,m in zip(self.strides, mul))
new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
@@ -209,7 +207,7 @@ class ShapeTracker:
return self
dispatch: Dict[MovementOps, Callable] = {MovementOps.RESHAPE: ShapeTracker._reshape, MovementOps.EXPAND: ShapeTracker._expand, MovementOps.PAD: ShapeTracker._pad,
MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride}
MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride}
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):