add ctx saved tensors to graph

This commit is contained in:
George Hotz
2022-01-16 11:18:36 -08:00
parent 2a10116bfa
commit 2f531e35be

View File

@@ -19,12 +19,16 @@ if DEBUG:
for name, _ in sorted(debug_times.items(), key=lambda x: -x[1]):
print(f"{name:>20} : {debug_counts[name]:>6} {debug_times[name]:>10.2f} ms")
if G is not None:
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
atexit.register(print_debug_exit)
def is_buffer(x):
  """Return True if ``x`` is an instance of any class in ``Device.buffers``.

  The candidate classes are the values of the ``Device.buffers`` mapping;
  an empty mapping yields False.
  """
  buffer_classes = tuple(Device.buffers.values())
  return isinstance(x, buffer_classes)
class ProfileOp:
def __init__(self, name, x, backward=False):
self.name, self.x, self.output, self.backward = f"back_{name}" if backward else name, x, None, backward
def __init__(self, ctx, name, x, backward=False):
self.ctx, self.name, self.x, self.output, self.backward = ctx, f"back_{name}" if backward else name, x, None, backward
def __enter__(self):
if DEBUG: self.st = time.time()
return self
@@ -35,6 +39,10 @@ class ProfileOp:
for y in self.output:
G.add_edge(id(x.data), id(y.data), label=self.name, color="blue" if self.backward else "black")
G.nodes[id(x.data)]['label'], G.nodes[id(y.data)]['label'] = str(x.shape), str(y.shape)
if self.backward:
for x in filter(is_buffer, self.ctx.saved_tensors):
for y in self.output:
G.add_edge(id(x), id(y.data), label=self.name, color="red")
self.output[0].data.toCPU()
et = (time.time()-self.st)*1000.
debug_counts[self.name] += 1
@@ -138,7 +146,7 @@ class Tensor:
if not any([x.requires_grad for x in t0._ctx.parents]):
continue
assert (t0.grad is not None)
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
with ProfileOp(t0._ctx, t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
po.output = grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
@@ -361,7 +369,7 @@ class Function:
# overwrite with passed params
for k, v in kwargs.items():
setattr(ctx, k, v)
with ProfileOp(ctx.__class__.__name__, x) as po:
with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
device=ctx.device, requires_grad=any([t.requires_grad for t in x]))
po.output = [ret]