support drawing graphs

This commit is contained in:
George Hotz
2022-01-16 10:45:58 -08:00
parent 6a5cb6842e
commit 2a10116bfa
2 changed files with 17 additions and 6 deletions

View File

@@ -15,18 +15,18 @@ class TinyConvNet:
conv = 3
inter_chan, out_chan = 8, 16 # for speed
self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
self.bn1 = BatchNorm2D(inter_chan)
#self.bn1 = BatchNorm2D(inter_chan)
self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
self.bn2 = BatchNorm2D(out_chan)
#self.bn2 = BatchNorm2D(out_chan)
self.l1 = Tensor.uniform(out_chan*6*6, classes)
def forward(self, x):
x = x.conv2d(self.c1).relu().max_pool2d()
x = self.bn1(x)
#x = self.bn1(x)
x = x.conv2d(self.c2).relu().max_pool2d()
x = self.bn2(x)
#x = self.bn2(x)
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1).logsoftmax()
return x.dot(self.l1)
if __name__ == "__main__":
IMAGENET = os.getenv("IMAGENET") is not None

View File

@@ -9,21 +9,32 @@ import numpy as np
DEBUG = os.getenv("DEBUG", None) is not None
if DEBUG:
G = None
if os.getenv("GRAPH", None) is not None:
import networkx as nx
G = nx.DiGraph()
import atexit, time
debug_counts, debug_times = defaultdict(int), defaultdict(float)
def print_debug_exit():
for name, _ in sorted(debug_times.items(), key=lambda x: -x[1]):
print(f"{name:>20} : {debug_counts[name]:>6} {debug_times[name]:>10.2f} ms")
if G is not None:
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
atexit.register(print_debug_exit)
class ProfileOp:
def __init__(self, name, x, backward=False):
self.name, self.x, self.output = f"back_{name}" if backward else name, x, None
self.name, self.x, self.output, self.backward = f"back_{name}" if backward else name, x, None, backward
def __enter__(self):
if DEBUG: self.st = time.time()
return self
def __exit__(self, *junk):
if DEBUG:
if G is not None:
for x in self.x:
for y in self.output:
G.add_edge(id(x.data), id(y.data), label=self.name, color="blue" if self.backward else "black")
G.nodes[id(x.data)]['label'], G.nodes[id(y.data)]['label'] = str(x.shape), str(y.shape)
self.output[0].data.toCPU()
et = (time.time()-self.st)*1000.
debug_counts[self.name] += 1