mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-07 13:15:01 -05:00
From teeny (#2426)
* changes from teenygrad work * support not supporting ImageDType/PtrDType * fixups from teeny
This commit is contained in:
@@ -1,17 +1,16 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from typing import Any, List
|
||||
from extra.utils import OSX, temp
|
||||
|
||||
def is_dtype_supported(dtype: DType):
|
||||
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
|
||||
# for LLVM, it segfaults because it can't link to the casting function
|
||||
if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
|
||||
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
|
||||
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not OSX
|
||||
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU")
|
||||
if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
|
||||
if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
|
||||
if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
|
||||
@@ -113,6 +112,7 @@ class TestBFloat16DType(unittest.TestCase):
|
||||
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
|
||||
|
||||
def test_bf16_disk_write_read(self):
|
||||
from extra.utils import temp
|
||||
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
|
||||
t.to(f"disk:{temp('f32')}").realize()
|
||||
|
||||
@@ -173,17 +173,20 @@ class TestBoolDtype(TestDType): DTYPE = dtypes.bool
|
||||
|
||||
class TestEqStrDType(unittest.TestCase):
|
||||
def test_image_ne(self):
|
||||
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
|
||||
assert dtypes.float == dtypes.float32, "float doesn't match?"
|
||||
assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
|
||||
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
|
||||
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
|
||||
def test_ptr_ne(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
# TODO: is this the wrong behavior?
|
||||
assert PtrDType(dtypes.float32) == dtypes.float32
|
||||
#assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
|
||||
#assert PtrDType(dtypes.float32) != dtypes.float32
|
||||
def test_strs(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user