move is_dtype_supported to test.helpers (#3762)

* move is_dtype_supported to test.helpers

updated all places that check if float16 is supported

* fix tests
Author: chenyu
Date: 2024-03-15 14:33:26 -04:00
Committed by: GitHub
Parent: 8af87e20a0
Commit: a2d3cf64a5
11 changed files with 42 additions and 40 deletions
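
For context, a rough sketch of what the relocated helper in test/helpers.py could look like. This is an illustration reconstructed from the backend checks the diff below removes, not the exact implementation; only is_dtype_supported, Device, dtypes, and CI appear in the diff, the rest is assumed:

from tinygrad import Device, dtypes
from tinygrad.dtype import DType
from tinygrad.helpers import CI

def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT) -> bool:
  # fp16 is broken on some backends in CI: GPU lacks cl_khr_fp16, LLVM
  # segfaults because it can't link to the casting function, and the
  # CUDACPU architecture is sm_35 but fp16 ALUs need at least sm_70
  if dtype == dtypes.float16: return not (CI and device in ["GPU", "LLVM", "CUDA"])
  return True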


@@ -2,7 +2,8 @@ import pathlib, unittest
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
-from tinygrad.helpers import Timing, CI, fetch, temp, getenv
+from tinygrad.helpers import Timing, fetch, temp, getenv
+from test.helpers import is_dtype_supported

 def compare_weights_both(url):
   import torch
@@ -25,10 +26,7 @@ class TestTorchLoad(unittest.TestCase):
   # pytorch zip format
   def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
-  # for GPU, cl_khr_fp16 isn't supported
-  # for LLVM, it segfaults because it can't link to the casting function
-  # CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
-  @unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
   def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
   # pytorch tar format