Somewhat more generic broadcasting (#105)

* Somewhat more generic broadcasting

* Add TODO

* Set Torch to deterministic in test

Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com>
This commit is contained in:
adamritter
2020-11-12 04:33:00 +00:00
committed by GitHub
parent 8827a536e0
commit f1d21afe88
2 changed files with 20 additions and 17 deletions

View File

@@ -5,6 +5,8 @@ import timeit
import functools
from tinygrad.tensor import Tensor, GPU
torch._set_deterministic(True)
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False):
ts = [torch.rand(x, requires_grad=True) for x in shps]
tst = [Tensor(x.detach().numpy()) for x in ts]

View File

@@ -83,23 +83,24 @@ def supersample_op(ctx, input, out_shape, kernel_size, result_op, decls='', inpu
return ret
def binary_op(ctx, code, x, y):
if len(x.shape) != len(y.shape) and y.shape != (1,):
raise Exception("shape mismatch in binop %s: %r %r" % (code, x.shape, y.shape))
xdiv = 1
ydiv = 1
if x.shape != y.shape:
# special case broadcasting
# TODO: make general
if len(y.shape) == 4 and x.shape[0:2] == y.shape[0:2] and y.shape[2] == 1 and y.shape[3] == 1:
ydiv = x.shape[2] * x.shape[3]
elif len(y.shape) == 4 and x.shape[0:2] == y.shape[0:2] and x.shape[2] == 1 and x.shape[3] == 1:
xdiv = y.shape[2] * y.shape[3]
elif len(x.shape) == 2 and x.shape[0] == y.shape[0] and y.shape[1] == 1:
ydiv = x.shape[1]
elif np.prod(y.shape) == 1:
ydiv = np.prod(x.shape)
else:
raise Exception("binary op shape mismatch: %r != %r" % (x.shape, y.shape))
# TODO: Make broadcasting work when the broadcast (size-1) dimensions are not the innermost (trailing) ones.
def get_xdiv(xs, ys):
    """Return how many elements of shape ``ys`` map onto one element of ``xs``.

    ``xs`` is treated as the (possibly) broadcast operand. Broadcasting is
    only accepted when both shapes have the same rank and every size-1
    dimension of ``xs`` that gets expanded is trailing: once an expanded
    dimension has been seen, every later dimension of ``xs`` must also be 1.

    Args:
        xs: shape (sequence of ints) of the operand being broadcast.
        ys: shape (sequence of ints) of the operand it is broadcast against.

    Returns:
        The integer product of the dimensions of ``ys`` that ``xs`` is
        expanded over (1 when the shapes are identical), or ``None`` when
        the shapes cannot be matched under the rule above.
    """
    if len(xs) != len(ys):
        return None
    factor = 1
    for x_dim, y_dim in zip(xs, ys):
        # A non-1 dimension must match exactly, and may not follow a
        # dimension that was already expanded (factor > 1).
        if x_dim != 1 and (factor > 1 or x_dim != y_dim):
            return None
        # Here x_dim is either 1 or equal to y_dim, so the ratio is a whole
        # number; floor division keeps the factor an int rather than a float.
        factor *= y_dim // x_dim
    return factor
if y.shape == (1,):
xdiv, ydiv=1, np.prod(x.shape)
else:
xdiv, ydiv=1, get_xdiv(y.shape, x.shape)
if ydiv is None:
xdiv, ydiv=get_xdiv(y.shape, x.shape), 1
if xdiv is None:
raise Exception("shape mismatch in binop %s: %r %r" % (code, x.shape, y.shape))
ret = buffer_like(ctx, x if np.prod(x.shape) >= np.prod(y.shape) else y)
prg = clbuild(ctx.cl_ctx, """
__kernel void binop(__global const float *a_g, __global const float *b_g, __global float *res_g,