From teeny (#2426)

* changes from teenygrad work

* support not supporting ImageDType/PtrDType

* fixups from teeny
This commit is contained in:
George Hotz
2023-11-24 12:50:56 -08:00
committed by GitHub
parent 9ae83fba04
commit 8ff2e13550
11 changed files with 24 additions and 22 deletions

View File

@@ -1,17 +1,16 @@
import unittest
import numpy as np
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX
from tinygrad.ops import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from extra.utils import OSX, temp
def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not OSX
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU")
if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
@@ -113,6 +112,7 @@ class TestBFloat16DType(unittest.TestCase):
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
def test_bf16_disk_write_read(self):
from extra.utils import temp
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
@@ -173,17 +173,20 @@ class TestBoolDtype(TestDType): DTYPE = dtypes.bool
class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
assert dtypes.float == dtypes.float32, "float doesn't match?"
assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
def test_ptr_ne(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
# TODO: is this the wrong behavior?
assert PtrDType(dtypes.float32) == dtypes.float32
#assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
#assert PtrDType(dtypes.float32) != dtypes.float32
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")