mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
stupid, but the tests should pass
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from ..mlops import *
|
||||
#Buffer = select_llops("opencl")
|
||||
Buffer = select_llops("opencl")
|
||||
|
||||
@@ -370,6 +370,16 @@ class Function:
|
||||
self.requires_grad = any(t.requires_grad for t in tensors)
|
||||
self.saved_tensors = []
|
||||
|
||||
@property
|
||||
def op(f):
|
||||
# uhhh, obviously stupid
|
||||
if f.device == 0:
|
||||
return importlib.import_module(f".cpu", f"tinygrad.llops")
|
||||
elif f.device == 1:
|
||||
return importlib.import_module(f".opencl", f"tinygrad.llops")
|
||||
elif f.device == 2:
|
||||
return importlib.import_module(f".torch", f"tinygrad.llops")
|
||||
|
||||
def save_for_backward(self, *x):
|
||||
if self.requires_grad:
|
||||
self.saved_tensors.extend(x)
|
||||
@@ -400,7 +410,6 @@ def register(name, fxn, device=Device.CPU):
|
||||
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
|
||||
f = Tensor.ops[tt.device][name]
|
||||
f.device = tt.device
|
||||
f.op = importlib.import_module(f".cpu", f"tinygrad.llops")
|
||||
return f.apply(f, *x, **kwargs)
|
||||
if getattr(Tensor, name, None) is not None:
|
||||
setattr(Tensor, "_"+name, dispatch)
|
||||
|
||||
Reference in New Issue
Block a user