move is_dtype_supported to test.helpers (#3762)

* move is_dtype_supported to test.helpers

updated all places that check if float16 is supported

* fix tests
chenyu authored on 2024-03-15 14:33:26 -04:00, committed by GitHub
parent 8af87e20a0
commit a2d3cf64a5
11 changed files with 42 additions and 40 deletions
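For reference, the caller-side pattern this commit standardizes looks roughly like the sketch below; the test class and body are hypothetical, only the import path and the skipUnless guard mirror the diff.

import unittest
from tinygrad import Tensor, dtypes
from test.helpers import is_dtype_supported  # new home of the helper

# hypothetical example test, not part of this commit
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
class TestHalfExample(unittest.TestCase):
  def test_half_add(self):
    # runs only on backends where float16 buffers are usable
    out = (Tensor([1.0, 2.0], dtype=dtypes.float16) + 1).numpy()
    assert out.tolist() == [2.0, 3.0]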

@@ -47,7 +47,7 @@ jobs:
 DEBUG=2 EMULATE_HIP=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 - name: Test dtype with Python emulator
-  run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py
+  run: PYTHONPATH=. DEBUG=2 PYTHON=1 python3 test/test_dtype.py
 - name: Test ops with Python emulator
   run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)" --durations=20
 - name: Test symbolic with Python emulator
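Presumably the new PYTHONPATH=. is what lets `from test.helpers import is_dtype_supported` resolve inside test_dtype.py when it is invoked as a plain script rather than through pytest; the same command works locally:

PYTHONPATH=. DEBUG=2 PYTHON=1 python3 test/test_dtype.py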

@@ -3,9 +3,9 @@ from typing import Any, Tuple
 from onnx.backend.base import Backend, BackendRep
 import onnx.backend.test
 import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv, CI, OSX
-from tinygrad.device import Device
+from tinygrad import Tensor, Device, dtypes
+from tinygrad.helpers import getenv, OSX
+from test.helpers import is_dtype_supported
 # pip3 install tabulate
 pytest_plugins = 'onnx.backend.test.report',
@@ -49,7 +49,7 @@ backend_test.exclude('test_adam_multiple_cpu')
 backend_test.exclude('test_nesterov_momentum_cpu')
 # about different dtypes
-if Device.DEFAULT in ["METAL"] or (OSX and Device.DEFAULT == "GPU"):
+if not is_dtype_supported(dtypes.float64):
   backend_test.exclude('float64')
   backend_test.exclude('DOUBLE')
   # these have float64 inputs
@@ -59,8 +59,7 @@ if Device.DEFAULT in ["METAL"] or (OSX and Device.DEFAULT == "GPU"):
   backend_test.exclude('test_einsum_*')
   backend_test.exclude('test_cumsum_*')
-# no float16 in CI, LLVM segfaults, GPU requires cl_khr_fp16
-if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
+if not is_dtype_supported(dtypes.float16):
   backend_test.exclude('float16')
   backend_test.exclude('FLOAT16')

@@ -1,7 +1,9 @@
+import sys
+from tinygrad import Tensor, Device, dtypes
 from tinygrad.device import JITRunner
+from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
-from tinygrad import Tensor
-from tinygrad.helpers import Context
+from tinygrad.helpers import Context, CI, OSX
 def derandomize_model(model):
   with Context(GRAPH=0):
@@ -17,3 +19,18 @@ def assert_jit_cache_len(fxn, expected_len):
   else:
     assert len(fxn.jit_cache) == 1
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
+def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
+  if dtype == dtypes.bfloat16:
+    # NOTE: this requires bf16 buffer support
+    return device in ["HIP"]
+  if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
+  # for CI GPU, cl_khr_fp16 isn't supported
+  # for CI LLVM, it segfaults because it can't link to the casting function
+  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+  # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
+  if dtype == dtypes.half:
+    if device in ["GPU", "LLVM", "CUDA"]: return not CI
+    if device == "PYTHON": return sys.version_info >= (3, 12)
+  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
+  return True
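As a usage sketch (not part of the diff), the helper can also be queried for a device other than Device.DEFAULT through its second argument; the expected results below follow directly from the function above.

from tinygrad import dtypes
from test.helpers import is_dtype_supported

assert is_dtype_supported(dtypes.bfloat16, "HIP")       # bf16 buffers only on HIP
assert not is_dtype_supported(dtypes.float64, "METAL")  # METAL has no float64
assert is_dtype_supported(dtypes.int32, "WEBGPU")       # WEBGPU allows float/int32/uint32 only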

@@ -7,8 +7,7 @@ from tinygrad import Tensor, Device, GlobalCounters, dtypes
 from tinygrad.helpers import CI, getenv
 from tinygrad.shape.symbolic import Variable
 from extra.lr_scheduler import OneCycleLR
-from test.helpers import derandomize_model
-from test.test_dtype import is_dtype_supported
+from test.helpers import derandomize_model, is_dtype_supported
 from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
 from examples.hlb_cifar10 import SpeedyResNet, hyp

@@ -2,7 +2,8 @@ import unittest
 import pathlib
 from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
 from tinygrad.helpers import CI, fetch
-from tinygrad import Device
+from tinygrad import Device, dtypes
+from test.helpers import is_dtype_supported
 # Audio generated with the command on MacOS:
 # say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
@@ -15,7 +16,8 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc
 TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
 TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501
-@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. GPU reequires cl_khr_fp16")
+@unittest.skipIf(CI and Device.DEFAULT in ["CLANG"], "slow")
+@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
 class TestWhisper(unittest.TestCase):
   @classmethod
   def setUpClass(cls):

@@ -1,11 +1,12 @@
-import unittest, operator, sys
+import unittest, operator
 import numpy as np
 import torch
 from typing import Any, List
-from tinygrad.helpers import CI, getenv, DEBUG, OSX
+from tinygrad.helpers import getenv, DEBUG
 from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
 from tinygrad import Device, Tensor, dtypes
 from hypothesis import given, settings, strategies as strat
+from test.helpers import is_dtype_supported
 settings.register_profile("my_profile", max_examples=200, deadline=None)
 settings.load_profile("my_profile")
@@ -13,20 +14,6 @@ settings.load_profile("my_profile")
 core_dtypes = list(DTYPES_DICT.values())
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
 floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
-def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
-  if dtype == dtypes.bfloat16:
-    # NOTE: this requires bf16 buffer support
-    return device in ["HIP"]
-  if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
-  # for CI GPU, cl_khr_fp16 isn't supported
-  # for CI LLVM, it segfaults because it can't link to the casting function
-  # CUDA in CI uses CUDACPU that does not support half
-  # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
-  if dtype == dtypes.half:
-    if device in ["GPU", "LLVM", "CUDA"]: return not CI
-    if device == "PYTHON": return sys.version_info >= (3, 12)
-  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
-  return True
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []

@@ -8,7 +8,7 @@ from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.realize import create_schedule
 from tinygrad.ops import UnaryOps, get_lazyop_info
-from test.test_dtype import is_dtype_supported
+from test.helpers import is_dtype_supported
 settings.register_profile("my_profile", max_examples=200, deadline=None)
 settings.load_profile("my_profile")

@@ -7,7 +7,7 @@ from tinygrad.features.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI
 from test.external.fuzz_linearizer import run_linearizer, get_fuzz_rawbufs, get_fuzz_rawbuf_like
-from test.test_dtype import is_dtype_supported
+from test.helpers import is_dtype_supported
 from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, get_lazyop_info
 from tinygrad.shape.shapetracker import ShapeTracker

@@ -1,7 +1,7 @@
 import unittest
-from tinygrad.tensor import Tensor
 from tinygrad.helpers import CI
-from tinygrad import Device, dtypes
+from tinygrad import Tensor, Device, dtypes
+from test.helpers import is_dtype_supported
 # similar to test/external/external_test_gpu_ast.py, but universal
 @unittest.skipIf(Device.DEFAULT == "CUDA" and CI, "slow on CUDA CI")
@@ -20,7 +20,7 @@ class TestSpecific(unittest.TestCase):
     w = Tensor.randn(2048, 512)
     (x @ w).reshape(1, 128, 4).contiguous().realize()
-  @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "GPU", "CUDA"], "Broken on LLVM and webgpu, GPU requires cl_khr_fp16")
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
   def test_big_vec_mul(self):
     # from LLaMA
     # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]

@@ -8,7 +8,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.realize import create_schedule
 from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.codegen.uops import exec_alu, UOpGraph
-from test.test_dtype import is_dtype_supported
+from test.helpers import is_dtype_supported
 def _uops_to_prg(uops):
   src = Device[Device.DEFAULT].compiler.render("test", uops)

@@ -2,7 +2,8 @@ import pathlib, unittest
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import Timing, CI, fetch, temp, getenv
+from tinygrad.helpers import Timing, fetch, temp, getenv
+from test.helpers import is_dtype_supported
 def compare_weights_both(url):
   import torch
@@ -25,10 +26,7 @@ class TestTorchLoad(unittest.TestCase):
   # pytorch zip format
   def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
-  # for GPU, cl_khr_fp16 isn't supported
-  # for LLVM, it segfaults because it can't link to the casting function
-  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
-  @unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
   def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
   # pytorch tar format