From f1d21afe88d9c466d46068f364caee2bf2c3bfae Mon Sep 17 00:00:00 2001 From: adamritter <58403584+adamritter@users.noreply.github.com> Date: Thu, 12 Nov 2020 04:33:00 +0000 Subject: [PATCH] Somewhat more generic broadcasting (#105) * Somewhat more generic broadcasting * Add TODO * Set Torch to deterministic in test Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com> --- test/test_ops.py | 2 ++ tinygrad/opsgpu.py | 35 ++++++++++++++++++----------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4755d6af6c..87ed285dda 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -5,6 +5,8 @@ import timeit import functools from tinygrad.tensor import Tensor, GPU +torch._set_deterministic(True) + def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False): ts = [torch.rand(x, requires_grad=True) for x in shps] tst = [Tensor(x.detach().numpy()) for x in ts] diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index 0fa270890c..8d2ea02a8c 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -83,23 +83,24 @@ def supersample_op(ctx, input, out_shape, kernel_size, result_op, decls='', inpu return ret def binary_op(ctx, code, x, y): - if len(x.shape) != len(y.shape) and y.shape != (1,): - raise Exception("shape mismatch in binop %s: %r %r" % (code, x.shape, y.shape)) - xdiv = 1 - ydiv = 1 - if x.shape != y.shape: - # special case broadcasting - # TODO: make general - if len(y.shape) == 4 and x.shape[0:2] == y.shape[0:2] and y.shape[2] == 1 and y.shape[3] == 1: - ydiv = x.shape[2] * x.shape[3] - elif len(y.shape) == 4 and x.shape[0:2] == y.shape[0:2] and x.shape[2] == 1 and x.shape[3] == 1: - xdiv = y.shape[2] * y.shape[3] - elif len(x.shape) == 2 and x.shape[0] == y.shape[0] and y.shape[1] == 1: - ydiv = x.shape[1] - elif np.prod(y.shape) == 1: - ydiv = np.prod(x.shape) - else: - raise Exception("binary op shape mismatch: %r != %r" % (x.shape, y.shape)) + # TODO: Make broadcasting work when it's not at the most inner part of the arrays. + def get_xdiv(xs, ys): + if len(xs) != len(ys): + return None + r = 1 + for i in range(len(xs)): + if (xs[i] != 1) and (r > 1 or (xs[i] != ys[i])): + return None + r *= ys[i] / xs[i] + return r + if y.shape == (1,): + xdiv, ydiv=1, np.prod(x.shape) + else: + xdiv, ydiv=1, get_xdiv(y.shape, x.shape) + if ydiv is None: + xdiv, ydiv=get_xdiv(y.shape, x.shape), 1 + if xdiv is None: + raise Exception("shape mismatch in binop %s: %r %r" % (code, x.shape, y.shape)) ret = buffer_like(ctx, x if np.prod(x.shape) >= np.prod(y.shape) else y) prg = clbuild(ctx.cl_ctx, """ __kernel void binop(__global const float *a_g, __global const float *b_g, __global float *res_g,