shapes on backward

This commit is contained in:
George Hotz
2020-11-10 01:23:22 -08:00
parent 56f71ae8e5
commit 6e6bcbe5f2

View File

@@ -104,16 +104,15 @@ class Tensor:
if DEBUG:
st = time.time()
grads = self._ctx.backward(self._ctx, self.grad.data)
if len(self._ctx.parents) == 1:
grads = [grads]
if DEBUG:
global debug_counts, debug_times
name = "back_"+self._ctx.__class__.__name__
et = (time.time()-st)*1000.
debug_counts[name] += 1
debug_times[name] += et
print("%20s : %7.2f ms" % (name, et))
if len(self._ctx.parents) == 1:
grads = [grads]
print("%20s : %7.2f ms %s" % (name, et, [y.shape for y in grads]))
for t,g in zip(self._ctx.parents, grads):
if g is None:
continue