mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
hotfix amd bf16 is supported case (#10039)
* hotfix amd and amd_llvm * bf16 not supported in ci * hotfix amd_llvm is not a device * remove default * dont gate on ci and amd_llvm * minor cleanup * skip bf16 tc test for amd_llvm
This commit is contained in:
@@ -1061,6 +1061,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if CI and getenv("AMD_LLVM") and tc.dtype_in is dtypes.bfloat16: continue # TODO: compilation error in CI
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
|
||||
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
|
||||
@@ -1099,6 +1100,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_tensor_cores_padded_amd(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if CI and getenv("AMD_LLVM") and tc.dtype_in is dtypes.bfloat16: continue # TODO: compilation error in CI
|
||||
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
|
||||
helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user