Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-22 21:38:10 -05:00.
Commit: "reshape makes a copy"
This commit is contained in:
@@ -40,3 +40,4 @@ from enum import Enum
|
||||
# Op categories used by the backends to dispatch kernel implementations.
# Unary: elementwise, one input.
UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
# Binary: elementwise, two inputs.
# NOTE(review): the "A" member looks like a garbled/placeholder name from the
# scrape — confirm against the upstream file before relying on it.
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
# Reduce: collapse one or more axes to size 1.
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
# Movement: reinterpret or rearrange data without elementwise math.
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
|
||||
|
||||
@@ -48,9 +48,10 @@ def reduce_op(op, inp, ret):
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
return ret
|
||||
|
||||
def reshape(x, shape):
  # View-style reshape: reinterpret x with the new shape. The total element
  # count must be preserved; no data is copied.
  n_src, n_dst = np.prod(x.shape), np.prod(shape)
  assert n_src == n_dst
  out = x.reshape(shape)
  return out
|
||||
def reshape(x, ret):
  # Copying reshape: fill the preallocated output buffer `ret` with x's
  # elements laid out in ret's shape, then return the buffer.
  src_elems = np.prod(x.shape)
  dst_elems = np.prod(ret.shape)
  assert src_elems == dst_elems
  ret[:] = x.reshape(ret.shape)
  return ret
|
||||
|
||||
def perm_axis(x, order, ret):
|
||||
ret[:] = x.permute(order)
|
||||
|
||||
@@ -126,11 +126,11 @@ class Reshape(Function):
|
||||
def forward(ctx, x, shape):
  """Reshape x into `shape` (a single -1 wildcard dimension is allowed).

  The result is written into a fresh buffer via the backend's copying
  reshape op, so the output does not alias x ("reshape makes a copy").
  """
  ctx.save_for_backward(x.shape)
  # Resolve a -1 wildcard: with one -1 present, np.prod(shape) is negative,
  # so -prod(x.shape) // prod(shape) is the inferred positive dimension.
  shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
  # NOTE(review): the flattened diff showed a stale no-copy
  # `return ctx.op.reshape(x, shape)` before this line, which made the
  # buffer-allocating return unreachable; only the copying version is kept.
  return ctx.op.reshape(x, ctx.buffer(shape))
|
||||
|
||||
def backward(ctx, grad_output):
  """Gradient of reshape: lay grad_output back out in the input's shape.

  Copies into a fresh buffer of the saved forward-input shape, matching
  forward's copying reshape.
  """
  in_shape, = ctx.saved_tensors
  # NOTE(review): the flattened diff showed a stale no-copy
  # `return ctx.op.reshape(grad_output, in_shape)` before this line; only
  # the copying version is kept.
  return ctx.op.reshape(grad_output, ctx.buffer(in_shape))
|
||||
|
||||
class Permute(Function):
|
||||
def forward(ctx, x, order=(1,0)):
|
||||
|
||||
Reference in New Issue
Block a user