diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 196fd05cba..79029bfe01 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -85,7 +85,6 @@ class Device:
 class Tensor:
   did_float_warning = False
   training = False
-  ops = {}
 
   def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
     self.device, self.data = device, self._move_data(data, device)
@@ -382,8 +381,9 @@ class Function:
     if self.requires_grad:
       self.saved_tensors.extend(x)
 
-  def apply(self, *x, **kwargs):
+  def apply(self, device, *x, **kwargs):
     ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
+    ctx.device = device
     # use default params
     params = inspect.signature(self.forward).parameters
     for p in params.values():
@@ -402,19 +402,14 @@ class Function:
     return ret
 
 def register(name, fxn):
-  Tensor.ops[name] = fxn
   def dispatch(*x, **kwargs):
-    # get first tensor in args to determine device
-    tt = [arg for arg in x if isinstance(arg, Tensor)][0]
+    tt = [arg for arg in x if isinstance(arg, Tensor)]
+    assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
     # create tensors from number arguments
-    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    f = Tensor.ops[name] # get the function by device and name
-    f.device = tt.device
-    return f.apply(f, *x, **kwargs)
-  if getattr(Tensor, name, None) is not None:
-    setattr(Tensor, "_"+name, dispatch)
-  else:
-    setattr(Tensor, name, dispatch)
+    x = [Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
+    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
+
+  setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
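
For context, a minimal self-contained sketch (not part of the patch; the toy `Tensor`/`Function` classes below are stand-ins for tinygrad's) of how the patched `register`/`dispatch` path behaves after this change: `dispatch` collects every `Tensor` argument, asserts they all live on the same device, wraps plain numbers as tensors on that device, and passes the device into `Function.apply`, which records it on the op context instead of mutating the op class via the old `Tensor.ops` registry.

```python
# Illustrative sketch only: simplified stand-ins mirroring the patched flow.
import numpy as np

class Tensor:
  def __init__(self, data, device="CPU", requires_grad=True):
    self.data, self.device, self.requires_grad = np.asarray(data), device, requires_grad
  @property
  def dtype(self): return self.data.dtype

class Function:
  def __init__(self, *tensors): self.parents = tensors
  def apply(self, device, *x, **kwargs):
    ctx = self(*x)        # instantiate the op, as in the patch
    ctx.device = device   # device travels on the context, not on the op class
    return Tensor(self.forward(ctx, *[t.data for t in x], **kwargs), device=device)

def register(name, fxn):
  def dispatch(*x, **kwargs):
    tt = [arg for arg in x if isinstance(arg, Tensor)]
    assert all(tt[0].device == t.device for t in tt), "All tensors are not on the same device"
    # wrap python numbers so ops only ever see Tensors on a single device
    x = [arg if isinstance(arg, Tensor) else
         Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) for arg in x]
    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
  setattr(Tensor, "_"+name if getattr(Tensor, name, None) is not None else name, dispatch)

class Add(Function):
  @staticmethod
  def forward(ctx, a, b): return a + b

register("add", Add)
print(Tensor([1., 2.]).add(Tensor([3., 4.])).data)  # [4. 6.]
```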