fix shapetracker test

This commit is contained in:
George Hotz
2023-03-12 22:33:25 -07:00
parent 153cce0f7e
commit ce1564b05e
3 changed files with 11 additions and 12 deletions

View File

@@ -23,31 +23,31 @@ class CheckingShapeTracker:
def simplify(self): self.st.simplify()
def reshape(self, new_shape):
self.st._reshape(new_shape)
self.st.reshape(new_shape)
self.t = self.t.reshape(new_shape)
def permute(self, axis):
self.st._permute(axis)
self.st.permute(axis)
self.t = np.transpose(self.t, axis)
def expand(self, new_shape):
self.st._expand(new_shape)
self.st.expand(new_shape)
self.t = np.broadcast_to(self.t, new_shape)
def flip(self, axis):
self.st._stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.t = np.flip(self.t, axis)
def shrink(self, arg):
self.st._shrink(arg)
self.st.shrink(arg)
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
def pad(self, arg):
self.st._pad(arg)
self.st.pad(arg)
self.t = np.pad(self.t, arg, constant_values=-1)
def stride(self, arg):
self.st._stride(arg)
self.st.stride(arg)
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
def __getitem__(self, val):