fix half4 on qcom and gpu (#8573)

* add test_setitem_half

* this fixes comma benchmark
This commit is contained in:
qazal
2025-01-12 06:23:05 -05:00
committed by GitHub
parent cff1ee9038
commit ae241e96db
2 changed files with 10 additions and 1 deletions

View File

@@ -2,6 +2,7 @@
import unittest
import numpy as np
from tinygrad import dtypes, Tensor, TinyJit, GlobalCounters, Variable
from tinygrad.device import is_dtype_supported
N = 200 # has to be bigger than the cache to fail
@@ -365,6 +366,14 @@ class TestAssign(unittest.TestCase):
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_setitem_half(self):
a = Tensor.full((8,), 1.0, dtype=dtypes.half).contiguous().realize()
b = Tensor.full((4,), 2.0, dtype=dtypes.half).contiguous().realize()
assign = a[:4].assign(b)
assign.realize()
np.testing.assert_allclose(a.numpy(), [2., 2., 2., 2., 1., 1., 1., 1.])
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)

View File

@@ -225,7 +225,7 @@ class OpenCLRenderer(CStyleLanguage):
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
class IntelRenderer(OpenCLRenderer):