This commit is contained in:
George Hotz
2020-12-08 10:05:21 -08:00
parent c4540f1b8c
commit 4e1a0de392
2 changed files with 9 additions and 4 deletions

View File

@@ -75,6 +75,11 @@ class TestOps(unittest.TestCase):
def test_scalar_rmul(self):
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu)
def test_scalar_sub(self):
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, gpu=self.gpu)
def test_scalar_rsub(self):
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, gpu=self.gpu)
def test_broadcast_full(self):
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:

View File

@@ -244,9 +244,9 @@ def register(name, fxn, gpu=False):
else:
Tensor.ops_cpu[name] = fxn
def dispatch(*x, **kwargs):
assert isinstance(x[0], Tensor)
x = [Tensor(np.array([arg], dtype=x[0].dtype), gpu=x[0].gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = (Tensor.ops_gpu if x[0].gpu else Tensor.ops_cpu)[name]
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = (Tensor.ops_gpu if tt.gpu else Tensor.ops_cpu)[name]
f.cl_ctx, f.cl_queue = cl_ctx, cl_queue
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)
@@ -254,7 +254,7 @@ def register(name, fxn, gpu=False):
if name in ['add', 'sub', 'mul', 'pow']:
setattr(Tensor, f"__{name}__", dispatch)
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(self,x))
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
# this registers all the operations
import tinygrad.ops_cpu