From 977c2707747c67096521989004ce89989e1dc734 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Mon, 2 Mar 2026 01:35:26 -0800 Subject: [PATCH] IMAGE=1 kernel count failing tests (#15083) --- test/backend/test_schedule.py | 32 ++++++++++++++++++++++++++++++++ tinygrad/helpers.py | 3 ++- tinygrad/tensor.py | 4 ++-- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py index 260a2b33d2..6ad348fe66 100644 --- a/test/backend/test_schedule.py +++ b/test/backend/test_schedule.py @@ -795,6 +795,38 @@ class TestSchedule(unittest.TestCase): self.assertIsNotNone(out.uop.base.realized) self.assertIsInstance(out.uop.base.realized.dtype, ImageDType) + @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL") + @unittest.expectedFailure + def test_image_dot_f16_fusion(self): + with Context(FLOAT16=1): + def cnt(): + x, y, z = Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float') + a = (x @ y).relu() + sched = ((a @ z).relu() + a).schedule() + for si in sched: si.lower() + return len([si for si in sched if isinstance(si.prg, CompiledRunner)]) + + with Context(IMAGE=1): cnt1 = cnt() + with Context(IMAGE=2): cnt2 = cnt() + + self.assertEqual(cnt1, cnt2) + + @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL") + @unittest.expectedFailure + def test_image_conv_fusion(self): + def cnt(): + x, y, z = Tensor.empty((1, 4, 3, 3)), Tensor.empty((4, 1, 3, 3)), Tensor.empty((4, 1, 7, 7)) + a = x.conv2d(y, Tensor.empty(4), groups=4, padding=1) + b = a.conv2d(z, groups=4, padding=3) + sched = (a + b).schedule() + for si in sched: si.lower() + return len([si for si in sched if isinstance(si.prg, CompiledRunner)]) + + with Context(IMAGE=1): cnt1 = cnt() + with Context(IMAGE=2): cnt2 = cnt() + + self.assertEqual(cnt1, cnt2) + def _test_fusion(self, shapes, f, cnt): with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes] run_schedule(check_schedule(compare:=f(*args), cnt)) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index abaf6128bd..f5eef5103d 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -172,7 +172,8 @@ class ContextVar(Generic[T]): assert isinstance(self.value, str) return [getattr(obj, x) if obj else x for x in self.value.split(',') if x] -DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) +DEBUG, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) +IMAGE, FLOAT16 = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0) JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32) WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1) USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 21c7351a01..244de0c731 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten -from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile +from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import suppress_finalizing, disable_gc from tinygrad.gradient import compute_gradient from tinygrad.mixin import OpMixin @@ -3598,7 +3598,7 @@ class Tensor(OpMixin): return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: - base_image_type, dtsz = (dtypes.imageh, 2) if (FLOAT16:=getenv("FLOAT16", 0)) else (dtypes.imagef, 4) + base_image_type, dtsz = (dtypes.imageh, 2) if FLOAT16 else (dtypes.imagef, 4) (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)