reshape makes a copy

This commit is contained in:
George Hotz
2022-06-10 19:49:04 -07:00
parent c8bacd0d8e
commit e5d694490f
3 changed files with 7 additions and 5 deletions

View File

@@ -40,3 +40,4 @@ from enum import Enum
UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])

View File

@@ -48,9 +48,10 @@ def reduce_op(op, inp, ret):
else: raise Exception(f"{op} isn't supported")
return ret
def reshape(x, shape):
assert np.prod(x.shape) == np.prod(shape)
return x.reshape(shape)
def reshape(x, ret):
assert np.prod(x.shape) == np.prod(ret.shape)
ret[:] = x.reshape(ret.shape)
return ret
def perm_axis(x, order, ret):
ret[:] = x.permute(order)

View File

@@ -126,11 +126,11 @@ class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
return ctx.op.reshape(x, shape) # NOTE: this is not a copy
return ctx.op.reshape(x, ctx.buffer(shape))
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.op.reshape(grad_output, in_shape)
return ctx.op.reshape(grad_output, ctx.buffer(in_shape))
class Permute(Function):
def forward(ctx, x, order=(1,0)):