mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
fix rsub
This commit is contained in:
@@ -75,6 +75,11 @@ class TestOps(unittest.TestCase):
|
||||
def test_scalar_rmul(self):
|
||||
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu)
|
||||
|
||||
def test_scalar_sub(self):
|
||||
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, gpu=self.gpu)
|
||||
def test_scalar_rsub(self):
|
||||
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, gpu=self.gpu)
|
||||
|
||||
def test_broadcast_full(self):
|
||||
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
|
||||
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
|
||||
|
||||
@@ -244,9 +244,9 @@ def register(name, fxn, gpu=False):
|
||||
else:
|
||||
Tensor.ops_cpu[name] = fxn
|
||||
def dispatch(*x, **kwargs):
|
||||
assert isinstance(x[0], Tensor)
|
||||
x = [Tensor(np.array([arg], dtype=x[0].dtype), gpu=x[0].gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
|
||||
f = (Tensor.ops_gpu if x[0].gpu else Tensor.ops_cpu)[name]
|
||||
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
|
||||
x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
|
||||
f = (Tensor.ops_gpu if tt.gpu else Tensor.ops_cpu)[name]
|
||||
f.cl_ctx, f.cl_queue = cl_ctx, cl_queue
|
||||
return f.apply(f, *x, **kwargs)
|
||||
setattr(Tensor, name, dispatch)
|
||||
@@ -254,7 +254,7 @@ def register(name, fxn, gpu=False):
|
||||
if name in ['add', 'sub', 'mul', 'pow']:
|
||||
setattr(Tensor, f"__{name}__", dispatch)
|
||||
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
|
||||
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(self,x))
|
||||
setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self))
|
||||
|
||||
# this registers all the operations
|
||||
import tinygrad.ops_cpu
|
||||
|
||||
Reference in New Issue
Block a user