dispatcher: now this is nice

George Hotz
2022-06-09 10:28:46 -07:00
parent 70bb3a7976
commit 8c084b8c12


@@ -369,9 +369,11 @@ class Function:
     cls.backward = staticmethod(cls.backward)
     return super().__new__(cls)
-  def __init__(self, *tensors):
+  def __init__(self, device, *tensors):
+    self.device = device
     self.parents = tensors
-    self.requires_grad = any(t.requires_grad for t in tensors)
+    self.needs_input_grad = [t.requires_grad for t in tensors]
+    self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors = []
+  buffer = property(lambda self: Device.buffers[self.device])
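
The first hunk moves device bookkeeping into Function itself: the constructor now takes the device up front, records a per-input needs_input_grad list, and derives requires_grad from it, while the new buffer property looks the backing buffer class up in Device.buffers. A minimal sketch of the new constructor contract, with a stub Tensor standing in for tinygrad's real one (illustrative only):

class Tensor:  # stub for illustration, not tinygrad's Tensor
  def __init__(self, data, device="CPU", requires_grad=True):
    self.data, self.device, self.requires_grad = data, device, requires_grad

class Function:
  def __init__(self, device, *tensors):
    self.device = device                                        # device fixed at construction
    self.parents = tensors
    self.needs_input_grad = [t.requires_grad for t in tensors]  # per-input flags, kept for backward
    self.requires_grad = any(self.needs_input_grad)             # derived once, no longer recomputed in apply
    self.saved_tensors = []

ctx = Function("CPU", Tensor([1.0]), Tensor([2.0], requires_grad=False))
print(ctx.needs_input_grad, ctx.requires_grad)  # [True, False] True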
@@ -381,26 +383,25 @@ class Function:
     if self.requires_grad:
       self.saved_tensors.extend(x)
-  def apply(self, device, *x, **kwargs):
-    ctx = self(*x) # self - operation i.e 'add', 'sub', etc.
-    ctx.device = device
-    ctx.needs_input_grad = [t.requires_grad for t in x]
+  @classmethod
+  def apply(cls, *x, **kwargs):
+    tt = [arg for arg in x if isinstance(arg, Tensor)][0]  # this is the prototype tensor
+    # create tensors from number arguments
+    x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
+    assert all([tt.device == t.device for t in x]), "All tensors are not on the same device"
+    ctx = cls(tt.device, *x)
     with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
-      ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
-                   device=ctx.device, requires_grad=any(ctx.needs_input_grad))
+      ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
+                   device=ctx.device, requires_grad=ctx.requires_grad)
       po.output = [ret]
     if ret.requires_grad:
       ret._ctx = ctx  # used by autograd engine
     return ret

 def register(name, fxn):
-  def dispatch(*x, **kwargs):
-    tt = [arg for arg in x if isinstance(arg, Tensor)]
-    assert all([tt[0].device == t.device for t in tt]), "All tensors are not on the same device"
-    # create tensors from number arguments
-    x = [Tensor(np.array([arg], dtype=tt[0].dtype), device=tt[0].device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
-    return fxn.apply(fxn, tt[0].device, *x, **kwargs)
+  def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)  # TODO: there's probably a very pythonic thing to replace this with
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
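
Taken together, the payoff is in the second hunk: apply becomes a classmethod, so register's dispatch collapses to a one-liner, and the prototype-tensor lookup, number-to-Tensor promotion, and same-device assert all move out of the per-op glue into Function.apply. A self-contained sketch of the new flow under the same stub-Tensor assumption (ProfileOp and the real device buffers are omitted):

import numpy as np

class Tensor:  # stub: just data, device, requires_grad
  def __init__(self, data, device="CPU", requires_grad=True):
    self.data, self.device, self.requires_grad = data, device, requires_grad
    self._ctx = None
  dtype = property(lambda self: self.data.dtype)

class Function:
  def __init__(self, device, *tensors):
    self.device, self.parents = device, tensors
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = any(self.needs_input_grad)
  @classmethod
  def apply(cls, *x, **kwargs):
    tt = [arg for arg in x if isinstance(arg, Tensor)][0]  # prototype tensor donates device and dtype
    # promote bare numbers to Tensors on the prototype's device
    x = [arg if isinstance(arg, Tensor) else Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) for arg in x]
    assert all(tt.device == t.device for t in x), "All tensors are not on the same device"
    ctx = cls(tt.device, *x)
    ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
                 device=ctx.device, requires_grad=ctx.requires_grad)
    if ret.requires_grad: ret._ctx = ctx  # used by the autograd engine
    return ret

class Add(Function):
  @staticmethod
  def forward(ctx, a, b): return a + b

def register(name, fxn):
  def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)  # the entire dispatcher now
  setattr(Tensor, "_"+name if getattr(Tensor, name, None) is not None else name, dispatch)
  if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
    setattr(Tensor, f"__{name}__", dispatch)

register('add', Add)
t = Tensor(np.array([1.0, 2.0]))
print((t + 3).data)  # -> [4. 5.]; the bare 3 was promoted to a Tensor on t's device

With the real classes, registering an op this way is what routes Tensor.__add__ and friends through the autograd machinery; a scalar operand always inherits the prototype tensor's device and dtype.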