mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
* `global_load` and `global_store` using buffer dtype * `UOps.PHI` in all dtypes * `UOps.ALU` in all dtypes * `UOps.CONST` & `UOps.DEFINE_ACC` in all dtypes * -- end of implementation -- +tiny lint changes * these tests require the fp16 extension you can run them locally to confirm they're green: (GPT2 test is broken in master for mac, see [this](https://discord.com/channels/1068976834382925865/1069001075828469790/1177993277958533261)) `GPU=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_max_float16_cpu test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_min_float16_cpu test/models/test_real_world.py::TestRealWorld::test_llama test/models/test_real_world.py::TestRealWorld::test_gpt2 test/models/test_whisper.py test/test_specific_conv.py::TestSpecific::test_big_vec_mul` skip the new test_linearizer_failures in CI GPU because of the fp16 extension This passes on a real GPU since the extension is available: `GPU=1 python3 -m pytest test/test_linearizer_failures.py::TestLinearizerFailures::test_failure_8` see CI logs [here](https://github.com/tinygrad/tinygrad/actions/runs/6996590597/job/19032641427#step:14:644) * these tests fail in CI due to segfaults and CPU crashes To confirm they're green locally, you can run the following commands: 1. For the tests skipped in test_ops.py (note: CLANG is very slow) `for var in GPU CUDA CLANG; do export $var=1; for test in test/test_ops.py::TestOps::test_slice_fancy_indexing_no_dim_collapse test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_collapse_int test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_none test/test_ops.py::TestOps::test_slice_fancy_indexing_dim_inject_and_collapse; do python3 -m pytest $test; done; unset $var; done` 2. 
For the ONNX tests skipped in CLANG: ``` CLANG=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_ai_onnx_ml_array_feature_extractor_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_0_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_1_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_3d_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_gather_elements_negative_indices_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_NCd1d2d3d4d5_mean_weight_cpu \ 
test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_sce_mean_weight_ii_4d_log_prob_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1_weight_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_sum_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_reduction_mean_expanded_cpu \ test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_nllloss_NCd1d2_with_weight_expanded_cpu ``` 3. 
The LLVM test I skipped here is already [skipped in master for all backends](https://github.com/tinygrad/tinygrad/blob/master/test/external/external_test_onnx_backend.py#L186), I just made it more specific `LLVM=1 python3 -m pytest test/external/external_test_onnx_backend.py::OnnxBackendNodeModelTest::test_dequantizelinear_e4m3fn_float16_cpu` * Revert "these tests fail in CI due to segfaults and CPU crashes" This reverts commit15db570143. * merge with cleanup-vectorized-hip-renders * barely working HIP P1, ALU ops need a refactor? * manage the fact that in HIP [half2 is actually an unsigned int vec](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L59)) and half is a totally different __half that [has an unsigned int element in it](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L50)) but can't be accessed [because it's private](f921880387/hip/include/hip/amd_detail/amd_hip_fp16.h (L86)). If you just do this: ``` half2 val0 = // ... half val1 = // ... ``` then you can't do: ``` val0.x + val1 // error: use of overloaded operator '+' is ambiguous (with operand types 'unsigned short' and 'half' (aka '__half')) ``` * update the sign definition to avoid division by zero in all dtypes * diff cleanup p1: why were these in the diff anyways * less hacky HIP, enable CIFAR fp16 benchmark, test ops for HIP in CI! add ALU ops overloads for HIP this will make HIP max work handle mod Revert "handle mod" This reverts commit 370fd4b3fbe99b6ae8cc293d005b106628205933. update max to use hmax add HIP GEP render logic enable CIFAR fp16 benchmark test ops for HIP back to store as float because this only works for float4 grouping right now test_ops for hip!! 
always sign * back to the sign we had before because we can't do a backward pass on a Less node * remove old hacks HIP compiling test_ops in CI takes ~9 mins, not doing it for now new HIP ALUs * reduce accs done right * refactor to function * no device hacks hacks p2 the other way * LLVM ALU ops half, float and double are all float update max * update test_uops, cmplt is always a bool in the real linearizer. assertAlmostEqual is wrong when ret is bool * cleanup LLVM wrong code * dummy change for the CUDA install glitch --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
105 lines
4.5 KiB
Python
105 lines
4.5 KiB
Python
import unittest, time
|
|
import numpy as np
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.nn import optim
|
|
from tinygrad.nn.state import get_parameters
|
|
from tinygrad.jit import TinyJit
|
|
from tinygrad import Device, GlobalCounters
|
|
from tinygrad.helpers import CI, dtypes
|
|
from tinygrad.shape.symbolic import Variable
|
|
from test.helpers import derandomize_model
|
|
|
|
from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
|
|
from examples.hlb_cifar10 import SpeedyResNet
|
|
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
|
|
from examples.stable_diffusion import UNetModel
|
|
|
|
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
  """Run `train(*gen())` four times and check its memory and kernel budgets.

  Each run resets `GlobalCounters`, synchronizes the default device before and
  after the call, and records the elapsed wall-clock time in nanoseconds.
  Because the counters are reset at the start of every run, the memory figure
  that is printed and checked afterwards is the one left by the final run.

  Args:
    nm: label used in the printed summary and in the assertion messages.
    gen: zero-argument callable returning the positional arguments for `train`.
    train: callable under test. If it has a `jit_cache` attribute, the length of
      that cache is used as the kernel count; otherwise the count is `None`.
    max_memory_allowed: exclusive upper bound on `GlobalCounters.mem_used / 1e9`.
    max_kernels_allowed: inclusive upper bound on the kernel count; skipped when
      the count is `None` or 0.
    all_jitted: when true, also require that either the kernel count is positive
      and equals `GlobalCounters.kernel_count`, or the count is 1 and the default
      device has a truthy `graph` attribute.

  Raises:
    AssertionError: if any of the checks above fails.
  """
  tms = []
  for _ in range(4):
    GlobalCounters.reset()
    GlobalCounters.mem_used = 0
    Device[Device.DEFAULT].synchronize()
    st = time.perf_counter_ns()
    train(*gen())
    Device[Device.DEFAULT].synchronize()
    tms.append(time.perf_counter_ns() - st)

  # TODO: jit should expose this correctly with graph
  kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
  print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
  assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
  assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
  if all_jitted:
    # without a jit_cache the comparisons below would raise TypeError on None; fail with a clear message instead
    assert kernels_used is not None, f"{nm} has no jit_cache, so nothing was jitted"
    assert (kernels_used > 0 and kernels_used == GlobalCounters.kernel_count) or \
      (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), \
      f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"
|
|
|
|
class TestRealWorld(unittest.TestCase):
  """Memory and kernel-count budget checks for several example models, driven by `helper_test`."""

  def setUp(self):
    # remember the default dtype so tests that switch it to float16 can be undone in tearDown
    self.old_type = Tensor.default_type
    np.random.seed(2002)

  def tearDown(self):
    Tensor.default_type = self.old_type

  @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
  @unittest.skipIf(CI, "too big for CI")
  def test_stable_diffusion(self):
    model = UNetModel()
    derandomize_model(model)
    @TinyJit
    def test(t, t2): return model(t, 801, t2).realize()
    # budget: under 18.0 GB and at most 953 kernels
    helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953)

  @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
  def test_llama(self):
    Tensor.default_type = dtypes.float16

    # CI uses this reduced configuration instead of the 7B parameters
    args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]))
    derandomize_model(model)
    @TinyJit
    def test(t): return model(t, 0).realize()
    # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)

  @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
  def test_gpt2(self):
    Tensor.default_type = dtypes.float16

    # CI uses this reduced configuration instead of the gpt2-medium parameters
    args_tiny = {"dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-5, "vocab_size": 1000}
    model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]))
    derandomize_model(model)
    @TinyJit
    def test(t, v): return model(t, v).realize()
    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 180 if CI else 516, all_jitted=True)

  @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
  @unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")
  def test_train_cifar(self):
    # TODO: with default device
    #old_default = Device.DEFAULT
    #Device.DEFAULT = "FAKE"
    #Device['fake'].codegen = Device[old_default].codegen

    with Tensor.train():
      model = SpeedyResNet(Tensor.ones((12,3,2,2)))
      optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)

      BS = 32 if CI else 512

      @TinyJit
      def train(X):
        out = model(X)
        loss = out.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      # memory budget scales with the batch size: (1.0/48)*BS GB
      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 142 if CI else 154)   # it's 154 on metal

      # reset device
      #Device.DEFAULT = old_default
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
  unittest.main()