diff --git a/test/test_dtype.py b/test/test_dtype.py
index 5d6ab91c44..a9b3c9417e 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -3,7 +3,7 @@ import numpy as np
 import torch
 from typing import Any, List
 from tinygrad.device import is_dtype_supported
-from tinygrad.helpers import getenv, DEBUG, CI, AMD_LLVM
+from tinygrad.helpers import getenv, DEBUG, CI
 from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype
 from tinygrad import Device, Tensor, dtypes
 from hypothesis import assume, given, settings, strategies as strat
@@ -428,7 +428,6 @@ class TestOpsBFloat16(unittest.TestCase):
     data = [60000.0, 70000.0, 80000.0]
     np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
 
-  @unittest.skipIf(Device.DEFAULT == "AMD" and AMD_LLVM, "AMD_LLVM failed on this")
   def test_no_approximation(self):
     data = [326.0, 339.0, 10603200512.0]
     expected = torch.tensor(data, dtype=torch.bfloat16).sqrt().float().numpy()
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index a75808719f..47ed182bf5 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -196,14 +196,20 @@ class LLVMRenderer(Renderer):
 barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
 code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
                      "l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
+# https://rocm.docs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPUUsage.html#llvm-ir-intrinsics
+# llvm.log2/llvm.exp2 don't support double
+llvm_intrinsics = {Ops.SQRT: "sqrt"}
 class AMDLLVMRenderer(LLVMRenderer):
   device = "AMD"
   has_local = True
   shared_max = AMDRenderer.shared_max
   global_max = AMDRenderer.global_max
   abi = "amdgpu_kernel"
+  code_for_op = {**LLVMRenderer.code_for_op, **{op: lambda: None for op in llvm_intrinsics}}
   string_rewrite = PatternMatcher([
     (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
+    (UPat(tuple(llvm_intrinsics), name="x"),
+     lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
     (UPat(Ops.BARRIER), lambda ctx: barrier),
   ]) + base_rewrite
   extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([
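
Note (not part of the patch): a minimal standalone sketch of the IR string the new llvm_intrinsics rewrite rule produces for a float32 Ops.SQRT. The "%0"/"%1" SSA names and the "float" type string are illustrative assumptions standing in for tinygrad's ctx[...] lookups and ldt(); this simply re-evaluates the f-string added in string_rewrite outside the renderer.

# Illustrative only: mirrors the f-string in the new string_rewrite rule above.
llvm_intrinsics = {"SQRT": "sqrt"}

def render_intrinsic(dst="%1", src="%0", dtype="float", op="SQRT"):
  # corresponds to: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}(...)"
  return f" {dst} = call {dtype} @llvm.{llvm_intrinsics[op]}.{dtype}({dtype} {src})"

print(render_intrinsic())  # " %1 = call float @llvm.sqrt.float(float %0)"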