simpler dispatch logic since the mlops are universal now

This commit is contained in:
George Hotz
2022-06-09 10:01:10 -07:00
parent 273af8d732
commit e58d344759

View File

@@ -85,7 +85,6 @@ class Device:
class Tensor:
did_float_warning = False
training = False
ops = {}
def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
self.device, self.data = device, self._move_data(data, device)
@@ -382,8 +381,9 @@ class Function:
if self.requires_grad:
self.saved_tensors.extend(x)
def apply(self, *x, **kwargs):
def apply(self, device, *x, **kwargs):
ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
ctx.device = device
# use default params
params = inspect.signature(self.forward).parameters
for p in params.values():
@@ -402,19 +402,14 @@ class Function:
return ret
def register(name, fxn):
Tensor.ops[name] = fxn
def dispatch(*x, **kwargs):
# get first tensor in args to determine device
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
tt = [arg for arg in x if isinstance(arg, Tensor)]
assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
# create tensors from number arguments
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = Tensor.ops[name] # get the function by device and name
f.device = tt.device
return f.apply(f, *x, **kwargs)
if getattr(Tensor, name, None) is not None:
setattr(Tensor, "_"+name, dispatch)
else:
setattr(Tensor, name, dispatch)
x = [Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
return fxn.apply(fxn, tt[0].device, *x, **kwargs)
setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
setattr(Tensor, f"__{name}__", dispatch)
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))