mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tighter test_real_world mem and kernel count bounds (#13573)
also check if actual usage is within 20% of set limit, the old limits are too big to be useful
This commit is contained in:
@@ -28,13 +28,16 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
|
||||
model(*early_gen)
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
mem_used = GlobalCounters.mem_used - global_mem_used
|
||||
mem_used = (GlobalCounters.mem_used - global_mem_used) / 1e9
|
||||
|
||||
# TODO: jit should expose this correctly with graph
|
||||
kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
|
||||
print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
|
||||
assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB - {mem_used/1e9:.2} GB used"
|
||||
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels, it used {kernels_used}"
|
||||
assert mem_used < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.3f} GB - {mem_used:.3} GB used"
|
||||
assert (max_memory_allowed - mem_used) / max_memory_allowed < 0.2, f"{max_memory_allowed:.3f} GB is too far from {mem_used:.3} GB used"
|
||||
if kernels_used:
|
||||
assert kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels, it used {kernels_used}"
|
||||
assert (max_kernels_allowed - kernels_used) / max_kernels_allowed < 0.2, f"{max_kernels_allowed=} is too far from {kernels_used=} used"
|
||||
if all_jitted:
|
||||
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
|
||||
|
||||
@@ -61,7 +64,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
derandomize_model(model)
|
||||
@TinyJit
|
||||
def test(t, t2): return model(t, Tensor([801]), t2).realize()
|
||||
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 32, 32),Tensor.randn(1, 77, params["ctx_dim"])), test, 18.0, 515)
|
||||
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 32, 32), Tensor.randn(1, 77, params["ctx_dim"])), test, 0.011, 515)
|
||||
|
||||
def test_unet_resblock(self):
|
||||
model = [ResBlock(16, 24, 16) for _ in range(4)]
|
||||
@@ -70,7 +73,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
def test(t, t2):
|
||||
for l in model: t = l(t, t2)
|
||||
return t.realize()
|
||||
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 37)
|
||||
helper_test("test_unet_resblock", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.0002, 37)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_llama(self):
|
||||
@@ -82,7 +85,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
@TinyJit
|
||||
def test(t): return model(t, 0).realize()
|
||||
# TODO: test first token vs rest properly
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27, 168, all_jitted=True)
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.23, 118, all_jitted=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_gpt2(self):
|
||||
@@ -112,7 +115,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 103)
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.017, 103)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
|
||||
def test_forward_cifar(self):
|
||||
@@ -122,7 +125,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||
@TinyJit
|
||||
def run(X): return model(X)
|
||||
helper_test("forward_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), run, (1.0/48)*BS, 126)
|
||||
helper_test("forward_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), run, 0.033, 27)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
|
||||
def test_train_cifar(self):
|
||||
@@ -139,7 +142,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 126)
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, 0.12, 126)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_train_cifar_hyp(self):
|
||||
@@ -176,7 +179,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
for v in data.values(): v.to_(Device.DEFAULT)
|
||||
|
||||
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 427)
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.31, 400)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user