From 4e1a0de392eba10d3c21f23f1c8892e9a8f76186 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 8 Dec 2020 10:05:21 -0800 Subject: [PATCH] fix rsub --- test/test_ops.py | 5 +++++ tinygrad/tensor.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 394fba4147..af877883eb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -75,6 +75,11 @@ class TestOps(unittest.TestCase): def test_scalar_rmul(self): helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu) + def test_scalar_sub(self): + helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, gpu=self.gpu) + def test_scalar_rsub(self): + helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, gpu=self.gpu) + def test_broadcast_full(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f3286092d0..291344ec25 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -244,9 +244,9 @@ def register(name, fxn, gpu=False): else: Tensor.ops_cpu[name] = fxn def dispatch(*x, **kwargs): - assert isinstance(x[0], Tensor) - x = [Tensor(np.array([arg], dtype=x[0].dtype), gpu=x[0].gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x] - f = (Tensor.ops_gpu if x[0].gpu else Tensor.ops_cpu)[name] + tt = [arg for arg in x if isinstance(arg, Tensor)][0] + x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x] + f = (Tensor.ops_gpu if tt.gpu else Tensor.ops_cpu)[name] f.cl_ctx, f.cl_queue = cl_ctx, cl_queue return f.apply(f, *x, **kwargs) setattr(Tensor, name, dispatch) @@ -254,7 +254,7 @@ def register(name, fxn, gpu=False): if name in ['add', 'sub', 'mul', 'pow']: setattr(Tensor, f"__{name}__", dispatch) setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x))) - setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(self,x)) + setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self)) # this registers all the operations import tinygrad.ops_cpu