use precise::sin in metal (#5307)

This commit is contained in:
chenyu
2024-07-06 12:47:27 -04:00
committed by GitHub
parent 5c2ca7bad4
commit ca0ef1700b
2 changed files with 7 additions and 1 deletions

View File

@@ -160,5 +160,11 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
class TestFromFuzzer(unittest.TestCase):
@given(strat.sampled_from(dtypes_float))
def test_sin(self, dtype):
if not is_dtype_supported(dtype): return
np.testing.assert_allclose(Tensor([25]).cast(dtype=dtype).sin().numpy(), np.sin(np.array([25])), rtol=3e-4)
if __name__ == '__main__':
unittest.main()

View File

@@ -232,7 +232,7 @@ class MetalRenderer(CStyleLanguage):
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)