mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 23:08:06 -05:00)
add tests to the type spec
@@ -1,3 +1,4 @@
import operator
import unittest
import numpy as np
import torch
@@ -7,6 +8,9 @@ from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from hypothesis import given, settings, strategies as st

core_types = list(DTYPES_DICT.values())
floats = [dt for dt in core_types if dtypes.is_float(dt)]

def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
  # for GPU, cl_khr_fp16 isn't supported
  # for LLVM, it segfaults because it can't link to the casting function
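The hunk above cuts off after the first two comments of is_dtype_supported. As a reading aid, here is a minimal sketch of how such a guard could continue, assuming (as the comments imply) that half precision is the problem case; the device names, the half check, and the return logic are illustrative assumptions, not the commit's actual body.

from tinygrad.tensor import dtypes

# Illustrative sketch only: the real body is truncated in the diff above.
# The comments suggest float16 is the unsupported case on these backends:
# GPU lacks cl_khr_fp16, and LLVM segfaults linking the casting function.
def is_dtype_supported(dtype, device="CPU"):  # default device name is an assumption
  if dtype == dtypes.half:
    return device not in ["GPU", "LLVM"]
  return True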
@@ -305,8 +309,10 @@ class TestTypeSpec(unittest.TestCase):
    assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
    assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float

  core_types = list(DTYPES_DICT.values())
  floats = [dt for dt in core_types if dtypes.is_float(dt)]
  @given(st.sampled_from(core_types), st.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
  def test_bool_ops(self, dtype, op):
    assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool

class TestTypePromotion(unittest.TestCase):
  @given(st.sampled_from(core_types))
  def test_self_promo_to_self(self, dtype):
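The last hunk introduces hypothesis-driven property tests, but the page cuts off before the body of test_self_promo_to_self. Below is a self-contained sketch of the pattern, assuming the test's name means "a dtype combined with itself promotes to itself"; the hard-coded dtype list and the add-based assertion are assumptions standing in for DTYPES_DICT and the commit's real body.

import unittest
from hypothesis import given, strategies as st
from tinygrad.tensor import Tensor, dtypes

# Stand-in for core_types = list(DTYPES_DICT.values()) so the sketch runs alone.
core_types = [dtypes.float32, dtypes.int32]

class TestTypePromotion(unittest.TestCase):
  @given(st.sampled_from(core_types))
  def test_self_promo_to_self(self, dtype):
    # assumed body: an elementwise op on two same-dtype tensors keeps that dtype
    a = Tensor.rand(4, 4, dtype=dtype)
    b = Tensor.rand(4, 4, dtype=dtype)
    assert (a + b).dtype == dtype

if __name__ == "__main__":
  unittest.main()

Run it with plain unittest or pytest; hypothesis draws each dtype from the sampled_from strategy and replays any failing example deterministically.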