hotfix: don't spam BEAM debug in speed_v_theoretical

This commit is contained in:
George Hotz
2024-12-04 18:47:16 +08:00
parent 09b00b1b04
commit ea65c79ba2

View File

@@ -41,12 +41,15 @@ class TestKernelSpeed(unittest.TestCase):
if K is None: K = M
tms = []
with Context(BEAM=3):
for _ in range(10):
for i in range(10):
a = self._get_tensor(M, K)
b = self._get_tensor(K, N)
GlobalCounters.time_sum_s = 0
with Context(DEBUG=max(DEBUG, 2)): c = f(a, b)
tms.append(GlobalCounters.time_sum_s)
if i >= 3:
GlobalCounters.time_sum_s = 0
with Context(DEBUG=max(DEBUG, 2)): c = f(a, b)
tms.append(GlobalCounters.time_sum_s)
else:
c = f(a, b)
ops = 2 * M * N * K
mems = a.dtype.itemsize * M * K + b.dtype.itemsize * K * N + c.dtype.itemsize * M * N
@@ -65,16 +68,18 @@ class TestKernelSpeed(unittest.TestCase):
Tensor.realize(*get_parameters(conv))
with Context(BEAM=2):
for _ in range(10):
for i in range(10):
x = self._get_tensor(BS, CIN, H, W)
GlobalCounters.time_sum_s = 0
with Context(DEBUG=max(DEBUG, 2)): _c = f(conv, x)
tms.append(GlobalCounters.time_sum_s)
if i >= 3:
GlobalCounters.time_sum_s = 0
with Context(DEBUG=max(DEBUG, 2)): _c = f(conv, x)
tms.append(GlobalCounters.time_sum_s)
else:
_c = f(conv, x)
# naive algo
ops = 2 * BS * CIN * COUT * K * K * H * W
# TODO: what should this be?
mems = 0
mems = x.nbytes() + conv.weight.nbytes() + conv.bias.nbytes() + _c.nbytes()
tm = min(tms)
tflops = ops / tm / 1e12
gbs = mems / tm / 1e9