mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
shapes on backward
This commit is contained in:
@@ -104,16 +104,15 @@ class Tensor:
|
||||
if DEBUG:
|
||||
st = time.time()
|
||||
grads = self._ctx.backward(self._ctx, self.grad.data)
|
||||
if len(self._ctx.parents) == 1:
|
||||
grads = [grads]
|
||||
if DEBUG:
|
||||
global debug_counts, debug_times
|
||||
name = "back_"+self._ctx.__class__.__name__
|
||||
et = (time.time()-st)*1000.
|
||||
debug_counts[name] += 1
|
||||
debug_times[name] += et
|
||||
print("%20s : %7.2f ms" % (name, et))
|
||||
|
||||
if len(self._ctx.parents) == 1:
|
||||
grads = [grads]
|
||||
print("%20s : %7.2f ms %s" % (name, et, [y.shape for y in grads]))
|
||||
for t,g in zip(self._ctx.parents, grads):
|
||||
if g is None:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user