mirror of https://github.com/tinygrad/tinygrad.git
llops torch test passes
@@ -7,7 +7,7 @@ def binary_broadcast(x_shape, y_shape, extra=False):
   shape_y[:len(y_shape)] = np.array(y_shape, dtype=np.int32)
   if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
     raise Exception(f"binary op unbroadcastable shape mismatch: {x_shape} vs {y_shape}")
-  shape_ret = np.maximum(shape_x, shape_y)
+  shape_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])

   if extra:
     dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
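For context, a minimal sketch (not part of the commit) of what the changed shape_ret line produces: np.maximum of the two padded shape arrays is a numpy array of numpy ints, and the tuple([int(x) for x in ...]) wrapper turns it into a plain tuple of Python ints, presumably so downstream Buffer constructors get an ordinary shape tuple.

import numpy as np

# padded shapes as binary_broadcast builds them
shape_x = np.array([3, 1], dtype=np.int32)
shape_y = np.array([1, 4], dtype=np.int32)

old_ret = np.maximum(shape_x, shape_y)                            # ndarray([3, 4], dtype=int32)
new_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])   # plain (3, 4)

print(type(old_ret), old_ret)   # <class 'numpy.ndarray'> [3 4]
print(type(new_ret), new_ret)   # <class 'tuple'> (3, 4)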
@@ -37,10 +37,14 @@ def binary_op(op, x, y, ret):
   return ret

 def reduce_op(op, inp, ret):
-  if ret.shape == (1,): axis=None
+  if inp.shape == ret.shape: # this is just a copy
+    ret[:] = inp
+    return ret
+  if ret.shape == (1,): axis=tuple(range(len(inp.shape)))
   else: axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, ret.shape)) if a != b])
   if op == ReduceOps.SUM: ret[:] = inp.sum(axis, keepdims=True)
-  if op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
+  elif op == ReduceOps.MAX: ret[:] = inp.amax(axis, keepdims=True)
+  else: raise Exception(f"{op} isn't supported")
   return ret

 def reshape(x, shape):
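A minimal numpy sketch (not the commit's code) of the axis selection the reworked reduce_op performs: axes where the input and output shapes disagree get reduced with keepdims=True, and a full reduce to shape (1,) now uses an explicit tuple of all axes instead of axis=None.

import numpy as np

def reduce_axes(in_shape, out_shape):
  # full reduce: every axis, so the keepdims result can still be assigned into ret
  if out_shape == (1,): return tuple(range(len(in_shape)))
  # partial reduce: only the axes where the two shapes differ
  return tuple([i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b])

inp = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
print(reduce_axes(inp.shape, (2, 1, 4)))                                 # (1,)
print(inp.sum(reduce_axes(inp.shape, (2, 1, 4)), keepdims=True).shape)   # (2, 1, 4)
print(reduce_axes(inp.shape, (1,)))                                      # (0, 1, 2)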
@@ -5,6 +5,11 @@ from ..tensor import Function

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class Buffer(torch.Tensor):
+  def __new__(cls, shape):
+    if isinstance(shape, torch.Tensor):
+      return super().__new__(cls, shape)
+    else:
+      return Buffer(torch.zeros(shape))
   custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
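For context, a minimal sketch (not from the commit) of what the custompad lambda's list comprehension does: torch.nn.functional.pad expects a flat list that starts with the (before, after) pair for the last dimension, so the per-axis padding pairs are reversed and then flattened.

import torch

padding = [(1, 2), (3, 4)]   # (before, after) for dim 0 and dim 1, numpy-style
flat = [item for sublist in padding[::-1] for item in sublist]
print(flat)                  # [3, 4, 1, 2] -- last dimension first, as F.pad wants

x = torch.zeros(2, 3)
print(torch.nn.functional.pad(x, flat).shape)   # torch.Size([5, 10])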
@@ -16,7 +21,7 @@ class Buffer(torch.Tensor):

 # ************* unary+binary+reduce+movement ops *************

-from tinygrad.llops.cpu import unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
+from tinygrad.llops.cpu import unary_op, binary_op, reduce_op, reshape, perm_axis, inner_slice, matmul

 # ************* processing ops *************

@@ -79,7 +79,8 @@ def unbroadcast(out, in_sh):
 class Add(Function):
   def forward(ctx, x, y):
     ctx.save_for_backward(x.shape, y.shape)
-    return ll.binary_op(BinaryOps.ADD, x, y, ll.Buffer(binary_broadcast(x.shape, y.shape)))
+    buf = ll.Buffer(binary_broadcast(x.shape, y.shape))
+    return ll.binary_op(BinaryOps.ADD, x, y, buf) #ll.Buffer(binary_broadcast(x.shape, y.shape)))

   def backward(ctx, grad_output):
     shape_x, shape_y = ctx.saved_tensors
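The forward now allocates the broadcast-shaped output buffer in its own buf variable before running the op; backward still has to undo that broadcast. A minimal numpy sketch (not the commit's unbroadcast) of the idea, assuming both shapes already have the same number of dimensions: the gradient is summed over every axis that was broadcast on the way forward.

import numpy as np

def unbroadcast_sketch(grad, in_shape):
  # sum over axes where the input had size 1 but the broadcast output did not
  axes = tuple([i for i,(g,s) in enumerate(zip(grad.shape, in_shape)) if s == 1 and g != 1])
  return grad.sum(axis=axes, keepdims=True) if axes else grad

grad = np.ones((3, 4))                          # grad of x+y with x:(3,1), y:(3,4)
print(unbroadcast_sketch(grad, (3, 1)).shape)   # (3, 1)
print(unbroadcast_sketch(grad, (3, 4)).shape)   # (3, 4)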
@@ -145,7 +146,7 @@ class Transpose(Function):
     return ll.perm_axis(x, order, ret)

   def backward(ctx, grad_output):
-    norder = np.argsort(ctx.order)
+    norder = np.argsort(ctx.order).tolist()
     ret = ll.Buffer([grad_output.shape[i] for i in norder])
     return ll.perm_axis(grad_output, norder, ret)

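A minimal sketch (not from the commit) of why np.argsort(order) is the right permutation for the transpose backward, and what .tolist() changes: argsort of a permutation is its inverse, and .tolist() hands perm_axis plain Python ints instead of numpy int64 scalars, which is presumably what the torch-backed buffer code prefers.

import numpy as np

order = (2, 0, 1)                 # forward axis permutation
norder = np.argsort(order)        # array([1, 2, 0]) -- the inverse permutation
print(norder, type(norder[0]))    # elements are numpy.int64
print(np.argsort(order).tolist()) # [1, 2, 0] as plain Python ints

# applying order and then its inverse restores the original layout
x = np.zeros((4, 5, 6))
print(np.transpose(np.transpose(x, order), norder).shape)   # (4, 5, 6)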
@@ -1,2 +1,2 @@
 from ..mlops import *
-Buffer = select_llops("opencl")
+#Buffer = select_llops("opencl")