diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py
index 83dfec777d..f246642dde 100644
--- a/test/backend/test_schedule.py
+++ b/test/backend/test_schedule.py
@@ -811,6 +811,26 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(cnt1, 5)
     self.assertEqual(cnt2, 5)
 
+  @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
+  def test_image_f16_residual_fusion(self):
+    with Context(FLOAT16=1, OPENPILOT_HACKS=1):
+      def cnt():
+        inp = Tensor.empty((512,), dtype='float')
+        b1, b2 = Tensor.empty((512, 1024), dtype='float'), Tensor.empty((1024, 512), dtype='float')
+        c1, c2 = Tensor.empty((1024,), dtype='float'), Tensor.empty((512,), dtype='float')
+        rb = (((((inp @ b1) + c1).relu() @ b2) + c2).relu() + inp).relu()
+        b16, c16 = Tensor.empty((512, 16), dtype='float'), Tensor.empty((16,), dtype='float')
+        b32, c32 = Tensor.empty((512, 32), dtype='float'), Tensor.empty((32,), dtype='float')
+        sched = Tensor.schedule((rb @ b16 + c16).relu(), (rb @ b32 + c32).relu())
+        for si in sched: si.lower()
+        return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
+
+      with Context(IMAGE=1): cnt1 = cnt()
+      with Context(IMAGE=2): cnt2 = cnt()
+
+      self.assertEqual(cnt1, 9)
+      self.assertEqual(cnt2, 9)
+
   @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
   @unittest.expectedFailure
   def test_image_conv_fusion(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 91e7ded7be..65d71269f1 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3655,11 +3655,15 @@ class Tensor(OpMixin):
 
     # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
     if IMAGE == 1:
+      # pad with Invalid
+      def _invalid_pad_to(t, shape):
+        if all(p is None or p == s for p,s in zip(shape, t.shape)): return t
+        return Tensor(True, device=t.device).expand(t.shape).pad_to(shape).where(t.pad_to(shape), Invalid)
       # hacks for pitch alignment
       assert isinstance(ix, int) and isinstance(H, int)
       ALIGN = 64 // dtsz
-      x = x.pad_to(None, None, round_up(ix, ALIGN // math.gcd(groups * cin, ALIGN)), None)
-      w = w.pad_to((None, round_up(H, ALIGN // math.gcd(W * cin * 4, ALIGN))) + (None,) * (w.ndim - 2))
+      x = _invalid_pad_to(x, (None, None, round_up(ix, ALIGN // math.gcd(groups * cin, ALIGN)), None))
+      w = _invalid_pad_to(w, (None, round_up(H, ALIGN // math.gcd(W * cin * 4, ALIGN))) + (None,) * (w.ndim - 2))
       if FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
       else: x, w = x.contiguous(), w.contiguous()
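
For context on the tensor.py change: `_invalid_pad_to` swaps the plain `pad_to` for padding marked as `Invalid`, presumably so the pitch-alignment pad is treated as don't-care rather than as real zero data (which is what the new residual-fusion test exercises). Below is a minimal numpy sketch of that masked-padding idea, for illustration only: `np.nan` stands in for the `Invalid` sentinel (reading `Invalid` as a scheduler-level don't-care marker is an assumption, not something this diff states), and the helper name is borrowed from the diff.

    import numpy as np

    def invalid_pad_to(t, shape):
        # None in `shape` means "leave this dimension alone", mirroring pad_to's convention
        target = tuple(s if p is None else p for p, s in zip(shape, t.shape))
        if target == t.shape: return t  # mirrors the early return in _invalid_pad_to
        out = np.full(target, np.nan, dtype=t.dtype)   # padded lanes start as "Invalid" (nan here)
        out[tuple(slice(0, s) for s in t.shape)] = t   # the valid region keeps the original data
        return out

    x = np.ones((2, 3), dtype=np.float32)
    print(invalid_pad_to(x, (None, 8)).shape)  # (2, 8); columns 3..7 are nan ("Invalid")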
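
The `FLOAT16` branch (`x.cast(dtypes.half).contiguous().cast(dtypes.float)`) materializes the buffer at half precision and immediately promotes it back, so downstream math runs in float32 over f16-quantized values. A hedged numpy analogue of that round trip, purely to show the numeric effect:

    import numpy as np

    w = np.random.randn(4, 4).astype(np.float32)
    w_f16 = w.astype(np.float16)       # what .cast(dtypes.half).contiguous() stores
    w_back = w_f16.astype(np.float32)  # what the trailing .cast(dtypes.float) computes on
    print(np.abs(w - w_back).max())    # the precision lost to the f16 round trip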