mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
shapetracker touchups
This commit is contained in:
@@ -163,10 +163,8 @@ class ShapeTracker:
|
||||
self.__unsafe_resize(arg)
|
||||
|
||||
def _expand(self, new_shape: Tuple[int, ...]):
|
||||
assert all(isinstance(x, int) for x in new_shape), f"non ints for expand in {new_shape}"
|
||||
assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
|
||||
strides: Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
|
||||
self.views[-1] = View(new_shape, strides, self.offset)
|
||||
assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
|
||||
self.views[-1] = View(new_shape, self.strides, self.offset)
|
||||
|
||||
def _reshape(self, new_shape: Tuple[int, ...]):
|
||||
if self.shape == new_shape: return self
|
||||
@@ -195,7 +193,7 @@ class ShapeTracker:
|
||||
|
||||
# except for the negative case, you can build this from the others. invertible in the negative case
|
||||
def _stride(self, mul: Tuple[int, ...]):
|
||||
assert all(isinstance(x, int) for x in mul)
|
||||
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
|
||||
strides = tuple(z*m for z,m in zip(self.strides, mul))
|
||||
new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
|
||||
@@ -209,7 +207,7 @@ class ShapeTracker:
|
||||
return self
|
||||
|
||||
dispatch: Dict[MovementOps, Callable] = {MovementOps.RESHAPE: ShapeTracker._reshape, MovementOps.EXPAND: ShapeTracker._expand, MovementOps.PAD: ShapeTracker._pad,
|
||||
MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride}
|
||||
MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride}
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):
|
||||
|
||||
Reference in New Issue
Block a user