From 6e6bcbe5f24499e69e0cd19dd0e552121a2166c9 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 10 Nov 2020 01:23:22 -0800 Subject: [PATCH] shapes on backward --- tinygrad/tensor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 00b54ee758..fb22a8907b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -104,16 +104,15 @@ class Tensor: if DEBUG: st = time.time() grads = self._ctx.backward(self._ctx, self.grad.data) + if len(self._ctx.parents) == 1: + grads = [grads] if DEBUG: global debug_counts, debug_times name = "back_"+self._ctx.__class__.__name__ et = (time.time()-st)*1000. debug_counts[name] += 1 debug_times[name] += et - print("%20s : %7.2f ms" % (name, et)) - - if len(self._ctx.parents) == 1: - grads = [grads] + print("%20s : %7.2f ms %s" % (name, et, [y.shape for y in grads])) for t,g in zip(self._ctx.parents, grads): if g is None: continue