llops torch test passes

This commit is contained in:
George Hotz
2022-06-08 23:30:56 -07:00
parent 1e3db466cc
commit 5a533fc073
5 changed files with 17 additions and 7 deletions

View File

@@ -7,7 +7,7 @@ def binary_broadcast(x_shape, y_shape, extra=False):
shape_y[:len(y_shape)] = np.array(y_shape, dtype=np.int32)
if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
raise Exception(f"binary op unbroadcastable shape mismatch: {x_shape} vs {y_shape}")
shape_ret = np.maximum(shape_x, shape_y)
shape_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])
if extra:
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims

View File

@@ -37,10 +37,14 @@ def binary_op(op, x, y, ret):
return ret
def reduce_op(op, inp, ret):
if ret.shape == (1,): axis=None
if inp.shape == ret.shape: # this is just a copy
ret[:] = inp
return ret
if ret.shape == (1,): axis=tuple(range(len(inp.shape)))
else: axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
if op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")
return ret
def reshape(x, shape):

View File

@@ -5,6 +5,11 @@ from ..tensor import Function
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Buffer(torch.Tensor):
def __new__(cls, shape):
if isinstance(shape, torch.Tensor):
return super().__new__(cls, shape)
else:
return Buffer(torch.zeros(shape))
custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
@staticmethod
def fromCPU(data):
@@ -16,7 +21,7 @@ class Buffer(torch.Tensor):
# ************* unary+binary+reduce+movement ops *************
from tinygrad.llops.cpu import unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
from tinygrad.llops.cpu import unary_op, binary_op, reduce_op, reshape, perm_axis, inner_slice, matmul
# ************* processing ops *************

View File

@@ -79,7 +79,8 @@ def unbroadcast(out, in_sh):
class Add(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return ll.binary_op(BinaryOps.ADD, x, y, ll.Buffer(binary_broadcast(x.shape, y.shape)))
buf = ll.Buffer(binary_broadcast(x.shape, y.shape))
return ll.binary_op(BinaryOps.ADD, x, y, buf) #ll.Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
@@ -145,7 +146,7 @@ class Transpose(Function):
return ll.perm_axis(x, order, ret)
def backward(ctx, grad_output):
norder = np.argsort(ctx.order)
norder = np.argsort(ctx.order).tolist()
ret = ll.Buffer([grad_output.shape[i] for i in norder])
return ll.perm_axis(grad_output, norder, ret)

View File

@@ -1,2 +1,2 @@
from ..mlops import *
Buffer = select_llops("opencl")
#Buffer = select_llops("opencl")