diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 6d7d93981d..d5f634c318 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -227,6 +227,7 @@ jobs:
     - name: Test IMAGE=2 support
       run: |
         IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
+        IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
         IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d
     - name: Test emulated METAL tensor cores
       run: |
diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py
index 08d2c04c32..6ca7d24568 100644
--- a/test/test_image_dtype.py
+++ b/test/test_image_dtype.py
@@ -141,6 +141,14 @@ class TestImageDType(unittest.TestCase):
     self.assertEqual(w1.grad.uop.base.buffer.dtype, dtypes.float32)
     self.assertEqual(len(sched), 10)
 
+  def test_gemm_fp16_image_path_dtype(self):
+    """A lazy (unrealized) half @ half matmul under IMAGE=2 must report dtypes.half."""
+    with Context(IMAGE=2):
+      x = Tensor.rand(64, 64)
+      y = Tensor.rand(64, 64)
+      z = x.half().matmul(y.half()) # don't realize
+      self.assertEqual(z.dtype, dtypes.half, f"Expected half, got {z.dtype}")
+
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageRealization(unittest.TestCase):
   def test_image_dtype_expand(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 9bff0816b6..12a1643e5a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -4352,6 +4352,8 @@ class Tensor(MathTrait):
 
     # NCHW output
     ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
+    # no explicit dtype: when the inputs promote to float16/bfloat16, cast the result to that promoted dtype
+    if dtype is None and (ret_dtype := least_upper_dtype(self.dtype, weight.dtype)) in (dtypes.float16, dtypes.bfloat16): ret = ret.cast(ret_dtype)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
 
 P = ParamSpec("P")