stupid, but the tests should pass

This commit is contained in:
George Hotz
2022-06-08 23:43:10 -07:00
parent 214fb8c974
commit d841fc4392
2 changed files with 11 additions and 2 deletions

View File

@@ -1,2 +1,2 @@
from ..mlops import *
#Buffer = select_llops("opencl")
Buffer = select_llops("opencl")

View File

@@ -370,6 +370,16 @@ class Function:
self.requires_grad = any(t.requires_grad for t in tensors)
self.saved_tensors = []
@property
def op(f):
# uhhh, obviously stupid
if f.device == 0:
return importlib.import_module(f".cpu", f"tinygrad.llops")
elif f.device == 1:
return importlib.import_module(f".opencl", f"tinygrad.llops")
elif f.device == 2:
return importlib.import_module(f".torch", f"tinygrad.llops")
def save_for_backward(self, *x):
if self.requires_grad:
self.saved_tensors.extend(x)
@@ -400,7 +410,6 @@ def register(name, fxn, device=Device.CPU):
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = Tensor.ops[tt.device][name]
f.device = tt.device
f.op = importlib.import_module(f".cpu", f"tinygrad.llops")
return f.apply(f, *x, **kwargs)
if getattr(Tensor, name, None) is not None:
setattr(Tensor, "_"+name, dispatch)