mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use precise::sin in metal (#5307)
This commit is contained in:
@@ -160,5 +160,11 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
|
||||
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
|
||||
|
||||
class TestFromFuzzer(unittest.TestCase):
|
||||
@given(strat.sampled_from(dtypes_float))
|
||||
def test_sin(self, dtype):
|
||||
if not is_dtype_supported(dtype): return
|
||||
np.testing.assert_allclose(Tensor([25]).cast(dtype=dtype).sin().numpy(), np.sin(np.array([25])), rtol=3e-4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -232,7 +232,7 @@ class MetalRenderer(CStyleLanguage):
|
||||
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
|
||||
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
|
||||
|
||||
def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
|
||||
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
Reference in New Issue
Block a user