start using Invalid in image_conv2d (#15208)

This commit is contained in:
Christopher Milan
2026-03-10 04:11:06 -07:00
committed by GitHub
parent ecbddfcffe
commit 25d86ec9e1
2 changed files with 26 additions and 2 deletions

View File

@@ -811,6 +811,26 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(cnt1, 5)
self.assertEqual(cnt2, 5)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
def test_image_f16_residual_fusion(self):
  """Schedule a residual block feeding two heads and check the compiled-kernel count.

  Runs with FLOAT16=1 and OPENPILOT_HACKS=1, once under IMAGE=1 and once under IMAGE=2,
  and expects 9 schedule items backed by a CompiledRunner in both cases.
  """
  def count_compiled_kernels():
    # residual block: 512 -> 1024 -> 512, with the input added back before the final relu
    x = Tensor.empty((512,), dtype='float')
    w_up, w_down = Tensor.empty((512, 1024), dtype='float'), Tensor.empty((1024, 512), dtype='float')
    bias_up, bias_down = Tensor.empty((1024,), dtype='float'), Tensor.empty((512,), dtype='float')
    hidden = ((x @ w_up) + bias_up).relu()
    residual = (((hidden @ w_down) + bias_down).relu() + x).relu()
    # two heads (16 and 32 outputs) consume the same residual output and are scheduled together
    w16, bias16 = Tensor.empty((512, 16), dtype='float'), Tensor.empty((16,), dtype='float')
    w32, bias32 = Tensor.empty((512, 32), dtype='float'), Tensor.empty((32,), dtype='float')
    head16 = (residual @ w16 + bias16).relu()
    head32 = (residual @ w32 + bias32).relu()
    schedule = Tensor.schedule(head16, head32)
    # every item is lowered first so that .prg can be inspected
    for item in schedule:
      item.lower()
    return sum(1 for item in schedule if isinstance(item.prg, CompiledRunner))

  with Context(FLOAT16=1, OPENPILOT_HACKS=1):
    with Context(IMAGE=1):
      cnt1 = count_compiled_kernels()
    with Context(IMAGE=2):
      cnt2 = count_compiled_kernels()
  self.assertEqual(cnt1, 9)
  self.assertEqual(cnt2, 9)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
@unittest.expectedFailure
def test_image_conv_fusion(self):

View File

@@ -3655,11 +3655,15 @@ class Tensor(OpMixin):
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
if IMAGE == 1:
# pad with Invalid
def _invalid_pad_to(t, shape):
  """Pad `t` to `shape`, selecting `Invalid` wherever the padded mask is not set.

  A `None` entry in `shape` means that dimension is not constrained. If every entry is
  `None` or already equals the matching dimension of `t`, `t` is returned unchanged.
  """
  already_sized = all(p is None or p == s for p,s in zip(shape, t.shape))
  if already_sized:
    return t
  # mask that is True over the original extent of t, then padded out to the target shape
  in_bounds = Tensor(True, device=t.device).expand(t.shape).pad_to(shape)
  padded = t.pad_to(shape)
  # keep the padded data where the mask holds, Invalid elsewhere
  return in_bounds.where(padded, Invalid)
# hacks for pitch alignment
assert isinstance(ix, int) and isinstance(H, int)
ALIGN = 64 // dtsz
x = x.pad_to(None, None, round_up(ix, ALIGN // math.gcd(groups * cin, ALIGN)), None)
w = w.pad_to((None, round_up(H, ALIGN // math.gcd(W * cin * 4, ALIGN))) + (None,) * (w.ndim - 2))
x = _invalid_pad_to(x, (None, None, round_up(ix, ALIGN // math.gcd(groups * cin, ALIGN)), None))
w = _invalid_pad_to(w, (None, round_up(H, ALIGN // math.gcd(W * cin * 4, ALIGN))) + (None,) * (w.ndim - 2))
if FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
else: x, w = x.contiguous(), w.contiguous()