move assert, remove comment

This commit is contained in:
George Hotz
2022-01-15 21:36:58 -08:00
parent 845bb1fc34
commit d541e2a8e5

View File

@@ -126,9 +126,9 @@ class Tensor:
self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
assert (t0.grad is not None)
if not any([x.requires_grad for x in t0._ctx.parents]):
continue
assert (t0.grad is not None)
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
if len(t0._ctx.parents) == 1:
@@ -366,7 +366,6 @@ def register(name, fxn, device=Device.CPU):
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = Tensor.ops[tt.device][name]
#f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
f.device = tt.device
return f.apply(f, *x, **kwargs)
if getattr(Tensor, name, None) is not None: