mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
update test_real_world configs (#2557)
This commit is contained in:
@@ -25,6 +25,7 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jit
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
|
||||
# TODO: jit should expose this correctly with graph
|
||||
kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
|
||||
print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
|
||||
assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
|
||||
@@ -40,15 +41,17 @@ class TestRealWorld(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
Tensor.default_type = self.old_type
|
||||
|
||||
@unittest.skipUnless(not CI, "too big for CI")
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(CI, "too big for CI")
|
||||
def test_stable_diffusion(self):
|
||||
model = UNetModel()
|
||||
derandomize_model(model)
|
||||
@TinyJit
|
||||
def test(t, t2): return model(t, 801, t2).realize()
|
||||
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
|
||||
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953)
|
||||
|
||||
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM")
|
||||
def test_llama(self):
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
@@ -60,7 +63,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
# TODO: test first token vs rest properly, also memory test is broken with CacheCollector
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)
|
||||
|
||||
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM" and CI, "too long on CI LLVM")
|
||||
def test_gpt2(self):
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
@@ -71,7 +74,8 @@ class TestRealWorld(unittest.TestCase):
|
||||
def test(t, v): return model(t, v).realize()
|
||||
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 180 if CI else 516, all_jitted=True)
|
||||
|
||||
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")
|
||||
def test_train_cifar(self):
|
||||
# TODO: with default device
|
||||
#old_default = Device.DEFAULT
|
||||
@@ -92,7 +96,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 154) # it's 154 on metal
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 142 if CI else 154) # it's 154 on metal
|
||||
|
||||
# reset device
|
||||
#Device.DEFAULT = old_default
|
||||
|
||||
Reference in New Issue
Block a user