IMAGE=1 kernel count failing tests (#15083)

This commit is contained in:
Christopher Milan
2026-03-02 01:35:26 -08:00
committed by GitHub
parent 3539693555
commit 977c270774
3 changed files with 36 additions and 3 deletions

View File

@@ -795,6 +795,38 @@ class TestSchedule(unittest.TestCase):
self.assertIsNotNone(out.uop.base.realized)
self.assertIsInstance(out.uop.base.realized.dtype, ImageDType)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
@unittest.expectedFailure
def test_image_dot_f16_fusion(self):
with Context(FLOAT16=1):
def cnt():
x, y, z = Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float')
a = (x @ y).relu()
sched = ((a @ z).relu() + a).schedule()
for si in sched: si.lower()
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
with Context(IMAGE=1): cnt1 = cnt()
with Context(IMAGE=2): cnt2 = cnt()
self.assertEqual(cnt1, cnt2)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
@unittest.expectedFailure
def test_image_conv_fusion(self):
def cnt():
x, y, z = Tensor.empty((1, 4, 3, 3)), Tensor.empty((4, 1, 3, 3)), Tensor.empty((4, 1, 7, 7))
a = x.conv2d(y, Tensor.empty(4), groups=4, padding=1)
b = a.conv2d(z, groups=4, padding=3)
sched = (a + b).schedule()
for si in sched: si.lower()
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
with Context(IMAGE=1): cnt1 = cnt()
with Context(IMAGE=2): cnt2 = cnt()
self.assertEqual(cnt1, cnt2)
def _test_fusion(self, shapes, f, cnt):
with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]
run_schedule(check_schedule(compare:=f(*args), cnt))

View File

@@ -172,7 +172,8 @@ class ContextVar(Generic[T]):
assert isinstance(self.value, str)
return [getattr(obj, x) if obj else x for x in self.value.split(',') if x]
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
DEBUG, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
IMAGE, FLOAT16 = ContextVar("IMAGE", 0), ContextVar("FLOAT16", 0)
JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVar("JIT_BATCH_SIZE", 32)
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
@@ -3598,7 +3598,7 @@ class Tensor(OpMixin):
return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
base_image_type, dtsz = (dtypes.imageh, 2) if (FLOAT16:=getenv("FLOAT16", 0)) else (dtypes.imagef, 4)
base_image_type, dtsz = (dtypes.imageh, 2) if FLOAT16 else (dtypes.imagef, 4)
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)