mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
WebGPU f16 support (f16 bounty part 2) (#8653)
* WebGPU f16 support * Don't enable f16 yet * dtype tests passing after bitcast fix * Maybe all WebGPU green? * Require shader-f16 in examples * Minor wgsl touchup * 1 line shorter * Simpler * Add transcendetal support * log2 nan location mismatch on Vulkan * Nan skips
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.helpers import Context, getenv
|
||||
from tinygrad.helpers import Context, getenv, CI
|
||||
from test.test_schedule import check_schedule
|
||||
from test.test_dtype_alu import ht, dtypes_float
|
||||
from tinygrad.device import is_dtype_supported
|
||||
import numpy as np
|
||||
import math
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
@@ -25,22 +26,29 @@ class TestTranscendentalMath(unittest.TestCase):
|
||||
atol=3e-2, rtol=1e-5) # sin can have bigger atol for very big x
|
||||
|
||||
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
|
||||
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
|
||||
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
|
||||
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
|
||||
def test_float32(self, x, op):
|
||||
# wrong nan behavior on Vulkan
|
||||
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU": return
|
||||
with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
|
||||
np.testing.assert_allclose(op[0](Tensor([x], dtype=dtypes.float32)).numpy(),
|
||||
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float32))),
|
||||
atol=2e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
|
||||
@given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
|
||||
@given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
|
||||
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
|
||||
def test_float16(self, x, op):
|
||||
# wrong nan behavior on Vulkan
|
||||
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU": return
|
||||
with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
|
||||
np.testing.assert_allclose(op[0](Tensor([x], dtype=dtypes.float16)).numpy(),
|
||||
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float16))),
|
||||
atol=1e-2, rtol=5e-3) # exp can have bigger rtol
|
||||
|
||||
@given(strat.sampled_from([(dtypes.float64, 709.5), (dtypes.float32, 88.7), (dtypes.float16, 11)]))
|
||||
@given(strat.sampled_from([(dtypes.float64, 709.5), (dtypes.float32, 88.7), (dtypes.float16, 11)] if Device.DEFAULT != "WEBGPU"
|
||||
else [(dtypes.float64, 709.5), (dtypes.float32, 88.3), (dtypes.float16, 10.7)]))
|
||||
def test_exp_near_inf(self, dtype_x):
|
||||
# reordering compute might return inf
|
||||
dtype, x = dtype_x
|
||||
@@ -52,6 +60,7 @@ class TestTranscendentalMath(unittest.TestCase):
|
||||
|
||||
class TestFromFuzzer(unittest.TestCase):
|
||||
@given(strat.sampled_from(dtypes_float))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_sin(self, dtype):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype == dtypes.float64:
|
||||
@@ -74,6 +83,7 @@ class TestFromFuzzer(unittest.TestCase):
|
||||
_test_value(np.pi * 2, unit=1.5)
|
||||
|
||||
@given(strat.sampled_from(dtypes_float))
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan location mismatch on Vulkan, Metal works")
|
||||
def test_log2(self, dtype):
|
||||
if not is_dtype_supported(dtype): return
|
||||
if dtype == dtypes.float64:
|
||||
@@ -93,6 +103,7 @@ class TestFromFuzzer(unittest.TestCase):
|
||||
_test_value(0.0000009)
|
||||
|
||||
class TestTranscendentalSchedule(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_transcendental_sin_fusion(self):
|
||||
with Context(TRANSCENDENTAL=2):
|
||||
a = Tensor.empty(10)
|
||||
|
||||
Reference in New Issue
Block a user