mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
move assert, remove comment
This commit is contained in:
@@ -126,9 +126,9 @@ class Tensor:
|
||||
self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False)
|
||||
|
||||
for t0 in reversed(self.deepwalk()):
|
||||
assert (t0.grad is not None)
|
||||
if not any([x.requires_grad for x in t0._ctx.parents]):
|
||||
continue
|
||||
assert (t0.grad is not None)
|
||||
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
|
||||
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
|
||||
if len(t0._ctx.parents) == 1:
|
||||
@@ -366,7 +366,6 @@ def register(name, fxn, device=Device.CPU):
|
||||
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
|
||||
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
|
||||
f = Tensor.ops[tt.device][name]
|
||||
#f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
|
||||
f.device = tt.device
|
||||
return f.apply(f, *x, **kwargs)
|
||||
if getattr(Tensor, name, None) is not None:
|
||||
|
||||
Reference in New Issue
Block a user