mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move is_dtype_supported to test.helpers (#3762)
* move is_dtype_supported to test.helpers updated all places that check if float16 is supports * fix tests
This commit is contained in:
@@ -2,7 +2,8 @@ import pathlib, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import Timing, CI, fetch, temp, getenv
|
||||
from tinygrad.helpers import Timing, fetch, temp, getenv
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
def compare_weights_both(url):
|
||||
import torch
|
||||
@@ -25,10 +26,7 @@ class TestTorchLoad(unittest.TestCase):
|
||||
# pytorch zip format
|
||||
def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
|
||||
|
||||
# for GPU, cl_khr_fp16 isn't supported
|
||||
# for LLVM, it segfaults because it can't link to the casting function
|
||||
# CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
|
||||
@unittest.skipIf(Device.DEFAULT in ["GPU", "LLVM", "CUDA"] and CI, "fp16 broken in some backends")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
|
||||
|
||||
# pytorch tar format
|
||||
|
||||
Reference in New Issue
Block a user