delete old graph engine to reach line count

This commit is contained in:
George Hotz
2022-06-11 17:02:53 -07:00
parent a4d0d3f17a
commit 296f391403
2 changed files with 1 additions and 35 deletions

View File

@@ -8,7 +8,7 @@ ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
import os
DEBUG = int(os.getenv("PRINT_LLOPS", "0"))
GRAPH = int(os.getenv("GRAPH_LLOPS", "0"))
GRAPH = int(os.getenv("GRAPH", "0"))
if GRAPH:
import atexit
import networkx as nx

View File

@@ -6,15 +6,6 @@ from tinygrad.helpers import prod
# **** profiler ****
GRAPH = os.getenv("GRAPH", None) is not None
if GRAPH:
import networkx as nx
G = nx.DiGraph()
def save_graph_exit():
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, '/tmp/net.dot')
atexit.register(save_graph_exit)
DEBUG = os.getenv("DEBUG", None) is not None
if DEBUG:
debug_counts, debug_times = defaultdict(int), defaultdict(float)
@@ -31,31 +22,6 @@ class ProfileOp:
if DEBUG: self.st = time.time()
return self
def __exit__(self, *junk):
if GRAPH:
def nm(x):
global global_num_max
if getattr(x, 'global_num', None) is None:
setattr(x, 'global_num', global_num_max)
global_num_max += 1
return f"<<< {x.global_num} >>>"
# connect inputs to outputs
for x in self.x:
for y in self.output:
G.add_edge(nm(x.data), nm(y.data), label=self.name, color="blue" if self.backward else "black")
G.nodes[nm(x.data)]['label'], G.nodes[nm(y.data)]['label'] = str(x.shape), str(y.shape)
# which saved tensors does this backward depend on?
saved_tensors = filter(lambda x: any(isinstance(x, v) for v in Device.buffers.values()), self.ctx.saved_tensors)
if self.backward:
for x in saved_tensors:
for y in self.output:
G.add_edge(nm(x), nm(y.data), label=self.name, color="red")
# did this forward create any intermediate tensors?
if not self.backward:
x_data = [nm(x.data) for x in self.x] + [nm(x.data) for x in self.output]
for y in saved_tensors:
if nm(y) not in x_data: # if intermediate tensors are inputs they don't count
for x in self.x:
G.add_edge(nm(x.data), nm(y), label=self.name, color="purple")
if DEBUG:
self.output[0].data.toCPU()
et = (time.time()-self.st)*1000.