mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
print gflops avg with DEBUG=2
This commit is contained in:
3
test/external/external_test_opt.py
vendored
3
test/external/external_test_opt.py
vendored
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad import nn
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps
|
||||
from tinygrad.lazy import PUSH_PERMUTES
|
||||
@@ -49,7 +50,7 @@ class TestInferenceMinKernels(unittest.TestCase):
|
||||
model(img).realize()
|
||||
|
||||
def test_enet(self):
|
||||
model = EfficientNet(has_se=False)
|
||||
model = EfficientNet(getenv("ENET_NUM", 0), has_se=False)
|
||||
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
with CLCache(51):
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections import defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from tinygrad.helpers import getenv, DEBUG, GlobalCounters
|
||||
|
||||
GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
|
||||
|
||||
@@ -15,6 +15,12 @@ GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), gete
|
||||
|
||||
G = nx.DiGraph() if nx is not None else None
|
||||
cnts: Dict[OpType, int] = defaultdict(int)
|
||||
if DEBUG >= 2:
|
||||
def print_globalcounters():
|
||||
if GlobalCounters.time_sum_s == 0: return
|
||||
print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s",
|
||||
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms")
|
||||
atexit.register(print_globalcounters)
|
||||
if GRAPH:
|
||||
def save_graph_exit():
|
||||
for k,v in cnts.items(): print(k, v)
|
||||
|
||||
Reference in New Issue
Block a user