rename transpose to permute

This commit is contained in:
George Hotz
2022-06-10 19:41:50 -07:00
parent 462f1ce0da
commit c8bacd0d8e
3 changed files with 5 additions and 2 deletions

View File

@@ -122,7 +122,7 @@ mlops are mid level ops, there's 13 of them. They understand memory allocation a
Relu, Log, Exp # unary ops
Sum, Max # reduce ops (with axis argument)
Add, Sub, Mul, Pow # binary ops (with broadcasting)
Reshape, Transpose, Slice # movement ops
Reshape, Permute, Slice # movement ops
Conv2D(NCHW) # processing op (Matmul is also Conv2D)
```

View File

@@ -132,7 +132,7 @@ class Reshape(Function):
in_shape, = ctx.saved_tensors
return ctx.op.reshape(grad_output, in_shape)
class Transpose(Function):
class Permute(Function):
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
ret = ctx.buffer([x.shape[i] for i in order])

View File

@@ -257,6 +257,9 @@ class Tensor:
dot = matmul
def transpose(self, order=(1,0)):
return self.permute(order=order)
def _canonicalize_reduce_axis(self, axis):
if axis is None: axis = range(len(self.shape))
if isinstance(axis, int): axis = [axis]