mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Run test_real_world in METAL test (#3760)
* clean up test_real_world * skip that * JIT=2 for metal * all device
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -300,6 +300,8 @@ jobs:
|
||||
METAL=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
|
||||
- name: Run metal test
|
||||
run: JIT=2 METAL=1 python -m pytest -n=auto test/ --ignore=test/external --ignore=test/models --durations=20
|
||||
- name: Run real world test
|
||||
run: JIT=2 METAL=1 python -m pytest -n=auto test/models/test_real_world.py --durations=20
|
||||
- name: Run ONNX
|
||||
run: JIT=2 METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- name: Test tensor core ops (fake)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import unittest, time, gc
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.features.jit import TinyJit
|
||||
from tinygrad import Device, GlobalCounters, dtypes
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad import Tensor, Device, GlobalCounters, dtypes
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from extra.lr_scheduler import OneCycleLR
|
||||
from test.helpers import derandomize_model
|
||||
from test.test_dtype import is_dtype_supported
|
||||
|
||||
from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
|
||||
from examples.hlb_cifar10 import SpeedyResNet, hyp
|
||||
@@ -16,20 +16,20 @@ from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAM
|
||||
from examples.stable_diffusion import UNetModel, ResBlock
|
||||
|
||||
global_mem_used = 0
|
||||
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
|
||||
def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):
|
||||
tms = []
|
||||
for _ in range(4):
|
||||
early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
|
||||
GlobalCounters.reset()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
st = time.perf_counter_ns()
|
||||
train(*early_gen)
|
||||
model(*early_gen)
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
mem_used = GlobalCounters.mem_used - global_mem_used
|
||||
|
||||
# TODO: jit should expose this correctly with graph
|
||||
kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
|
||||
kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
|
||||
print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
|
||||
assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
|
||||
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
|
||||
@@ -66,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
helper_test("test_mini_sd", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_llama(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
|
||||
@@ -78,7 +78,8 @@ class TestRealWorld(unittest.TestCase):
|
||||
# TODO: test first token vs rest properly
|
||||
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 192 if CI else 719, all_jitted=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
@unittest.skipIf(getenv("JIT"), "failed if JIT is explicitly set") # TODO: fix this
|
||||
def test_gpt2(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
|
||||
@@ -107,13 +108,8 @@ class TestRealWorld(unittest.TestCase):
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 127)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_train_cifar(self):
|
||||
# TODO: with default device
|
||||
#old_default = Device.DEFAULT
|
||||
#Device.DEFAULT = "FAKE"
|
||||
#Device['fake'].codegen = Device[old_default].codegen
|
||||
|
||||
with Tensor.train():
|
||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
|
||||
@@ -130,11 +126,8 @@ class TestRealWorld(unittest.TestCase):
|
||||
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 142 if CI else 154) # it's 154 on metal
|
||||
|
||||
# reset device
|
||||
#Device.DEFAULT = old_default
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT in ["GPU"] and CI, "opencl on intel can't compile half")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_train_cifar_hyp(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
with Tensor.train():
|
||||
@@ -147,7 +140,5 @@ class TestRealWorld(unittest.TestCase):
|
||||
final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
|
||||
assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
|
||||
|
||||
dtypes.default_float = dtypes.float32
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user