mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 21:38:10 -05:00)
simpler dispatch logic since the mlops are universal now

@@ -85,7 +85,6 @@ class Device:
 class Tensor:
   did_float_warning = False
   training = False
-  ops = {}

   def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
     self.device, self.data = device, self._move_data(data, device)
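
The dropped `ops = {}` registry is the heart of the simplification: once the mlops are universal across devices, `register()` (see the third hunk) can close over the op class directly instead of looking it up in a per-name dict on every call. A minimal sketch of the before/after shape, not tinygrad's actual code; the names `register_old` and `register_new` are illustrative only:

class Tensor:
  pass  # stand-in for the real class

ops = {}  # the dict this commit deletes

def register_old(name, fxn):
  ops[name] = fxn
  def dispatch(*x, **kwargs):
    f = ops[name]  # extra registry lookup on every call
    return f.apply(f, *x, **kwargs)
  setattr(Tensor, name, dispatch)

def register_new(name, fxn):
  def dispatch(*x, **kwargs):
    return fxn.apply(fxn, *x, **kwargs)  # the closure already holds fxn
  setattr(Tensor, name, dispatch)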

@@ -382,8 +381,9 @@ class Function:
     if self.requires_grad:
       self.saved_tensors.extend(x)

-  def apply(self, *x, **kwargs):
+  def apply(self, device, *x, **kwargs):
     ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
+    ctx.device = device
     # use default params
     params = inspect.signature(self.forward).parameters
     for p in params.values():
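
The new `apply(self, device, *x, **kwargs)` signature threads the caller's device into the context before `forward` runs, so the op knows where its buffers live. A runnable toy sketch of that contract, assuming hypothetical `ToyFunction`/`ToyAdd` classes and plain numpy arrays in place of real tensors:

import numpy as np

class ToyFunction:
  def __init__(self, *parents):
    self.parents = parents  # what backward would need

  # mirrors the new signature from the diff above
  def apply(self, device, *x, **kwargs):
    ctx = self(*x)       # self is the op class here; build a context
    ctx.device = device  # <- the added line: pin the device on the context
    return self.forward(ctx, *x, **kwargs)

class ToyAdd(ToyFunction):
  @staticmethod
  def forward(ctx, a, b):
    # a real mlop would allocate its output on ctx.device
    return a + b

out = ToyAdd.apply(ToyAdd, "CPU", np.ones(3), np.ones(3))
print(out)  # [2. 2. 2.]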

@@ -402,19 +402,14 @@ class Function:
     return ret

 def register(name, fxn):
-  Tensor.ops[name] = fxn
   def dispatch(*x, **kwargs):
     # get first tensor in args to determine device
-    tt = [arg for arg in x if isinstance(arg, Tensor)][0]
+    tt = [arg for arg in x if isinstance(arg, Tensor)]
+    assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
     # create tensors from number arguments
-    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    f = Tensor.ops[name] # get the function by device and name
-    f.device = tt.device
-    return f.apply(f, *x, **kwargs)
-  if getattr(Tensor, name, None) is not None:
-    setattr(Tensor, "_"+name, dispatch)
-  else:
-    setattr(Tensor, name, dispatch)
+    x = [Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
+    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
+  setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
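
Putting the pieces together, a hedged end-to-end sketch of the new dispatch path (`MiniTensor` and `ToyMul` are illustrative stand-ins, not tinygrad classes). It exercises the three behaviors the rewritten `register()` provides: the same-device assert, promotion of plain numbers to tensors, and the `__mul__`/`__imul__` operator wiring:

import numpy as np

class MiniTensor:
  def __init__(self, data, device="CPU", requires_grad=True):
    self.data = np.asarray(data)
    self.device, self.requires_grad = device, requires_grad
    self.dtype = self.data.dtype
  def assign(self, other):
    self.data = other.data
    return self

class ToyMul:
  @staticmethod
  def apply(fxn, device, *x, **kwargs):
    # a real mlop would also build the backward graph; this just multiplies
    return MiniTensor(x[0].data * x[1].data, device=device)

def register(name, fxn):
  def dispatch(*x, **kwargs):
    tt = [arg for arg in x if isinstance(arg, MiniTensor)]
    assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
    x = [MiniTensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False)
         if not isinstance(arg, MiniTensor) else arg for arg in x]
    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
  setattr(MiniTensor, "_"+name if (getattr(MiniTensor, name, None) is not None) else name, dispatch)
  if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
    setattr(MiniTensor, f"__{name}__", dispatch)
    setattr(MiniTensor, f"__i{name}__", lambda self, x: self.assign(dispatch(self, x)))

register('mul', ToyMul)
a, b = MiniTensor([2.0, 3.0]), MiniTensor([4.0, 5.0])
print((a * b).data)  # [ 8. 15.]
print((a * 2).data)  # the scalar 2 is promoted to a MiniTensor first: [4. 6.]
# c = MiniTensor([1.0], device="GPU")
# a * c would raise: AssertionError: All tensors are not on the same device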