diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index ef862c28e5..122e0ec602 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -40,3 +40,4 @@ from enum import Enum
 UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
+MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 3a891dad60..d4e71f7841 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -48,9 +48,10 @@ def reduce_op(op, inp, ret):
   else: raise Exception(f"{op} isn't supported")
   return ret
 
-def reshape(x, shape):
-  assert np.prod(x.shape) == np.prod(shape)
-  return x.reshape(shape)
+def reshape(x, ret):
+  assert np.prod(x.shape) == np.prod(ret.shape)
+  ret[:] = x.reshape(ret.shape)
+  return ret
 
 def perm_axis(x, order, ret):
   ret[:] = x.transpose(order)
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index f96cbc48d2..a289fd303d 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -126,11 +126,11 @@ class Reshape(Function):
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
     shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
-    return ctx.op.reshape(x, shape)  # NOTE: this is not a copy
+    return ctx.op.reshape(x, ctx.buffer(shape))
 
   def backward(ctx, grad_output):
     in_shape, = ctx.saved_tensors
-    return ctx.op.reshape(grad_output, in_shape)
+    return ctx.op.reshape(grad_output, ctx.buffer(in_shape))
 
 class Permute(Function):
   def forward(ctx, x, order=(1,0)):
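
The behavioral change in `ops_cpu.py` is worth spelling out: the old `reshape` returned a numpy view of `x` (the removed `# NOTE: this is not a copy`), while the new one copies into a caller-allocated `ret` buffer, matching the `ret`-passing convention of the other llops like `perm_axis`. A minimal sketch of the new contract, using plain numpy arrays to stand in for backend buffers:

```python
import numpy as np

# Minimal sketch of the new contract, assuming plain numpy arrays
# stand in for backend buffers.
def reshape(x, ret):
  assert np.prod(x.shape) == np.prod(ret.shape)
  ret[:] = x.reshape(ret.shape)  # copies into ret; ret does not alias x
  return ret

x = np.arange(6, dtype=np.float32)
out = reshape(x, np.empty((2, 3), dtype=np.float32))
x[0] = 100.0
assert out[0, 0] == 0.0  # the old view-returning version would see 100.0 here
```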
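
The `-1` handling in `Reshape.forward` relies on a sign trick: with exactly one `-1` in `shape`, `np.prod(shape)` is negative, so `-np.prod(x.shape) // np.prod(shape)` floor-divides two negative numbers and yields the positive inferred dimension. A worked example:

```python
import numpy as np

# shape (2, 3, 4) reshaped to (6, -1): np.prod((6, -1)) == -6,
# and -24 // -6 == 4, so the -1 resolves to 4.
x_shape, shape = (2, 3, 4), (6, -1)
resolved = tuple(-np.prod(x_shape) // np.prod(shape) if s == -1 else s for s in shape)
assert resolved == (6, 4)
```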
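
The diff adds the `MovementOps` enum but shows no consumer. A hypothetical sketch of how a backend could dispatch on it, mirroring the `reduce_op` style above; the `movement_op` name and its `arg` convention are assumptions, not part of this diff:

```python
from enum import Enum
import numpy as np

MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])

# Hypothetical dispatcher (name and arg convention assumed, not from this diff).
# SLICE here only handles in-bounds (start, end) pairs, one per axis.
def movement_op(op, x, ret, arg=None):
  if op == MovementOps.RESHAPE: ret[:] = x.reshape(ret.shape)
  elif op == MovementOps.PERMUTE: ret[:] = x.transpose(arg)
  elif op == MovementOps.SLICE: ret[:] = x[tuple(slice(s, e) for s, e in arg)]
  else: raise Exception(f"{op} isn't supported")
  return ret

out = movement_op(MovementOps.PERMUTE, np.ones((2, 3)), np.empty((3, 2)), (1, 0))
```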