diff --git a/README.md b/README.md index 806d3b466a..218a7cc4ab 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ mlops are mid level ops, there's 13 of them. They understand memory allocation a Relu, Log, Exp # unary ops Sum, Max # reduce ops (with axis argument) Add, Sub, Mul, Pow # binary ops (with broadcasting) -Reshape, Transpose, Slice # movement ops +Reshape, Permute, Slice # movement ops Conv2D(NCHW) # processing op (Matmul is also Conv2D) ``` diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index b692e134d6..f96cbc48d2 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -132,7 +132,7 @@ class Reshape(Function): in_shape, = ctx.saved_tensors return ctx.op.reshape(grad_output, in_shape) -class Transpose(Function): +class Permute(Function): def forward(ctx, x, order=(1,0)): ctx.save_for_backward(order) ret = ctx.buffer([x.shape[i] for i in order]) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 4d8cd82aa3..97b39f49fd 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -257,6 +257,9 @@ class Tensor: dot = matmul + def transpose(self, order=(1,0)): + return self.permute(order=order) + def _canonicalize_reduce_axis(self, axis): if axis is None: axis = range(len(self.shape)) if isinstance(axis, int): axis = [axis]