mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Somewhat more generic broadcasting (#105)
* Somewhat more generic broadcasting * Add TODO * Set Torch to deterministic in test Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,8 @@ import timeit
|
||||
import functools
|
||||
from tinygrad.tensor import Tensor, GPU
|
||||
|
||||
torch._set_deterministic(True)
|
||||
|
||||
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False):
|
||||
ts = [torch.rand(x, requires_grad=True) for x in shps]
|
||||
tst = [Tensor(x.detach().numpy()) for x in ts]
|
||||
|
||||
@@ -83,23 +83,24 @@ def supersample_op(ctx, input, out_shape, kernel_size, result_op, decls='', inpu
|
||||
return ret
|
||||
|
||||
def binary_op(ctx, code, x, y):
|
||||
if len(x.shape) != len(y.shape) and y.shape != (1,):
|
||||
raise Exception("shape mismatch in binop %s: %r %r" % (code, x.shape, y.shape))
|
||||
xdiv = 1
|
||||
ydiv = 1
|
||||
if x.shape != y.shape:
|
||||
# special case broadcasting
|
||||
# TODO: make general
|
||||
if len(y.shape) == 4 and x.shape[0:2] == y.shape[0:2] and y.shape[2] == 1 and y.shape[3] == 1:
|
||||
ydiv = x.shape[2] * x.shape[3]
|
||||
elif len(y.shape) == 4 and x.shape[0:2] == y.shape[0:2] and x.shape[2] == 1 and x.shape[3] == 1:
|
||||
xdiv = y.shape[2] * y.shape[3]
|
||||
elif len(x.shape) == 2 and x.shape[0] == y.shape[0] and y.shape[1] == 1:
|
||||
ydiv = x.shape[1]
|
||||
elif np.prod(y.shape) == 1:
|
||||
ydiv = np.prod(x.shape)
|
||||
else:
|
||||
raise Exception("binary op shape mismatch: %r != %r" % (x.shape, y.shape))
|
||||
# TODO: Make broadcasting work when it's not at the most inner part of the arrays.
|
||||
def get_xdiv(xs, ys):
|
||||
if len(xs) != len(ys):
|
||||
return None
|
||||
r = 1
|
||||
for i in range(len(xs)):
|
||||
if (xs[i] != 1) and (r > 1 or (xs[i] != ys[i])):
|
||||
return None
|
||||
r *= ys[i] / xs[i]
|
||||
return r
|
||||
if y.shape == (1,):
|
||||
xdiv, ydiv=1, np.prod(x.shape)
|
||||
else:
|
||||
xdiv, ydiv=1, get_xdiv(y.shape, x.shape)
|
||||
if ydiv is None:
|
||||
xdiv, ydiv=get_xdiv(y.shape, x.shape), 1
|
||||
if xdiv is None:
|
||||
raise Exception("shape mismatch in binop %s: %r %r" % (code, x.shape, y.shape))
|
||||
ret = buffer_like(ctx, x if np.prod(x.shape) >= np.prod(y.shape) else y)
|
||||
prg = clbuild(ctx.cl_ctx, """
|
||||
__kernel void binop(__global const float *a_g, __global const float *b_g, __global float *res_g,
|
||||
|
||||
Reference in New Issue
Block a user