diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 3b698c1653..198501d380 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -7,7 +7,7 @@ def binary_broadcast(x_shape, y_shape, extra=False):
   shape_y[:len(y_shape)] = np.array(y_shape, dtype=np.int32)
   if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
     raise Exception(f"binary op unbroadcastable shape mismatch: {x_shape} vs {y_shape}")
-  shape_ret = np.maximum(shape_x, shape_y)
+  shape_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])
 
   if extra:
     dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
diff --git a/tinygrad/llops/cpu.py b/tinygrad/llops/cpu.py
index 7cacefcc68..28a713b16f 100644
--- a/tinygrad/llops/cpu.py
+++ b/tinygrad/llops/cpu.py
@@ -37,10 +37,14 @@ def binary_op(op, x, y, ret):
   return ret
 
 def reduce_op(op, inp, ret):
-  if ret.shape == (1,): axis=None
+  if inp.shape == ret.shape:  # this is just a copy
+    ret[:] = inp
+    return ret
+  if ret.shape == (1,): axis=tuple(range(len(inp.shape)))
   else: axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
   if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
-  if op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
+  elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
+  else: raise Exception(f"{op} isn't supported")
   return ret
 
 def reshape(x, shape):
diff --git a/tinygrad/llops/torch.py b/tinygrad/llops/torch.py
index d0dcebee85..fc7b6e1fde 100644
--- a/tinygrad/llops/torch.py
+++ b/tinygrad/llops/torch.py
@@ -5,6 +5,11 @@ from ..tensor import Function
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 class Buffer(torch.Tensor):
+  def __new__(cls, shape):
+    if isinstance(shape, torch.Tensor):
+      return super().__new__(cls, shape)
+    else:
+      return Buffer(torch.zeros(shape))
   custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
@@ -16,7 +21,7 @@ class Buffer(torch.Tensor):
 
 # ************* unary+binary+reduce+movement ops *************
 
-from tinygrad.llops.cpu import unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
+from tinygrad.llops.cpu import unary_op, binary_op, reduce_op, reshape, perm_axis, inner_slice, matmul
 
 # ************* processing ops *************
 
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 96913322fc..987fb9d54a 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -79,7 +79,8 @@ def unbroadcast(out, in_sh):
 class Add(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return ll.binary_op(BinaryOps.ADD, x, y, ll.Buffer(binary_broadcast(x.shape, y.shape)))
+    buf = ll.Buffer(binary_broadcast(x.shape, y.shape))
+    return ll.binary_op(BinaryOps.ADD, x, y, buf) #ll.Buffer(binary_broadcast(x.shape, y.shape)))
 
   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
@@ -145,7 +146,7 @@ class Transpose(Function):
     return ll.perm_axis(x, order, ret)
 
   def backward(ctx, grad_output):
-    norder = np.argsort(ctx.order)
+    norder = np.argsort(ctx.order).tolist()
     ret = ll.Buffer([grad_output.shape[i] for i in norder])
     return ll.perm_axis(grad_output, norder, ret)
 
diff --git a/tinygrad/ops/ops_gpu.py b/tinygrad/ops/ops_gpu.py
index c3cf5ed06b..9fb52b2f83 100644
--- a/tinygrad/ops/ops_gpu.py
+++ b/tinygrad/ops/ops_gpu.py
@@ -1,2 +1,2 @@
 from ..mlops import *
-Buffer = select_llops("opencl")
+#Buffer = select_llops("opencl")
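
For reference, a rough standalone sketch (not part of the patch) of the axis-inference behavior that the reduce_op change in tinygrad/llops/cpu.py now follows: same-shape reduces short-circuit to a copy, a (1,) target collapses every axis, and any other target reduces only the mismatched axes. The helper name reduce_sum_like is hypothetical and used only for illustration.

import numpy as np

def reduce_sum_like(inp, out_shape):
  # Same-shape reduce is just a copy (the new early-return path).
  out_shape = tuple(out_shape)
  if inp.shape == out_shape:
    return inp.copy()
  # A (1,) target collapses every axis; otherwise reduce only the axes
  # where the input and output shapes disagree.
  if out_shape == (1,):
    axis = tuple(range(len(inp.shape)))
  else:
    axis = tuple(i for i, (a, b) in enumerate(zip(inp.shape, out_shape)) if a != b)
  return inp.sum(axis, keepdims=True)

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(reduce_sum_like(x, (3, 1)).shape)  # (3, 1): reduced along axis 1 only
print(reduce_sum_like(x, (1,)).shape)    # (1, 1): full reduce with keepdims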