mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix shapetracker test
This commit is contained in:
@@ -23,31 +23,31 @@ class CheckingShapeTracker:
|
||||
def simplify(self): self.st.simplify()
|
||||
|
||||
def reshape(self, new_shape):
|
||||
self.st._reshape(new_shape)
|
||||
self.st.reshape(new_shape)
|
||||
self.t = self.t.reshape(new_shape)
|
||||
|
||||
def permute(self, axis):
|
||||
self.st._permute(axis)
|
||||
self.st.permute(axis)
|
||||
self.t = np.transpose(self.t, axis)
|
||||
|
||||
def expand(self, new_shape):
|
||||
self.st._expand(new_shape)
|
||||
self.st.expand(new_shape)
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
|
||||
def flip(self, axis):
|
||||
self.st._stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.t = np.flip(self.t, axis)
|
||||
|
||||
def shrink(self, arg):
|
||||
self.st._shrink(arg)
|
||||
self.st.shrink(arg)
|
||||
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
|
||||
|
||||
def pad(self, arg):
|
||||
self.st._pad(arg)
|
||||
self.st.pad(arg)
|
||||
self.t = np.pad(self.t, arg, constant_values=-1)
|
||||
|
||||
def stride(self, arg):
|
||||
self.st._stride(arg)
|
||||
self.st.stride(arg)
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
|
||||
def __getitem__(self, val):
|
||||
|
||||
Reference in New Issue
Block a user