print gflops avg with DEBUG=2

This commit is contained in:
George Hotz
2023-03-23 16:07:08 -07:00
parent de04208247
commit e88b9bfe1e
2 changed files with 9 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ import numpy as np
import unittest
from tinygrad.tensor import Tensor, Device
from tinygrad import nn
from tinygrad.helpers import getenv
from tinygrad.nn import optim
from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
from tinygrad.lazy import PUSH_PERMUTES
@@ -49,7 +50,7 @@ class TestInferenceMinKernels(unittest.TestCase):
model(img).realize()
def test_enet(self):
model = EfficientNet(has_se=False)
model = EfficientNet(getenv("ENET_NUM", 0), has_se=False)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(51):

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
from typing import Dict, List, Optional
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import getenv, DEBUG
from tinygrad.helpers import getenv, DEBUG, GlobalCounters
GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
@@ -15,6 +15,12 @@ GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), gete
G = nx.DiGraph() if nx is not None else None
cnts: Dict[OpType, int] = defaultdict(int)
if DEBUG >= 2:
def print_globalcounters():
if GlobalCounters.time_sum_s == 0: return
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
atexit.register(print_globalcounters)
if GRAPH:
def save_graph_exit():
for k,v in cnts.items(): print(k, v)