mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
start using Invalid in image_conv2d (#15208)
This commit is contained in:
committed by
GitHub
parent
ecbddfcffe
commit
25d86ec9e1
@@ -811,6 +811,26 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(cnt1, 5)
|
||||
self.assertEqual(cnt2, 5)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
def test_image_f16_residual_fusion(self):
  """Check the compiled-kernel count for a residual block feeding two heads.

  With FLOAT16=1 and OPENPILOT_HACKS=1 set, the same graph is scheduled once
  under IMAGE=1 and once under IMAGE=2. Each run is expected to produce
  exactly 9 schedule items whose program is a CompiledRunner.
  """
  def count_compiled_kernels():
    # two-layer block (matmul + bias + relu, twice) with a skip connection back to the input
    inp = Tensor.empty((512,), dtype='float')
    b1, b2 = Tensor.empty((512, 1024), dtype='float'), Tensor.empty((1024, 512), dtype='float')
    c1, c2 = Tensor.empty((1024,), dtype='float'), Tensor.empty((512,), dtype='float')
    hidden = ((inp @ b1) + c1).relu()
    rb = (((hidden @ b2) + c2).relu() + inp).relu()

    # two output heads of different width, both reading the same block output
    b16, c16 = Tensor.empty((512, 16), dtype='float'), Tensor.empty((16,), dtype='float')
    b32, c32 = Tensor.empty((512, 32), dtype='float'), Tensor.empty((32,), dtype='float')
    head16 = (rb @ b16 + c16).relu()
    head32 = (rb @ b32 + c32).relu()

    # schedule both heads together, lower every item, then count only the compiled programs
    sched = Tensor.schedule(head16, head32)
    for item in sched: item.lower()
    return sum(1 for item in sched if isinstance(item.prg, CompiledRunner))

  with Context(FLOAT16=1, OPENPILOT_HACKS=1):
    with Context(IMAGE=1): cnt1 = count_compiled_kernels()
    with Context(IMAGE=2): cnt2 = count_compiled_kernels()

  self.assertEqual(cnt1, 9)
  self.assertEqual(cnt2, 9)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
|
||||
@unittest.expectedFailure
|
||||
def test_image_conv_fusion(self):
|
||||
|
||||
@@ -3655,11 +3655,15 @@ class Tensor(OpMixin):
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
||||
if IMAGE == 1:
|
||||
# pad with Invalid
|
||||
def _invalid_pad_to(t, shape):
  """Pad `t` to `shape`, marking the newly added region as Invalid.

  `shape` has one entry per dimension; an entry of None, or one equal to the
  current size, leaves that dimension alone. When no dimension needs to change,
  `t` is returned as-is.
  """
  unchanged = all(want is None or want == have for want, have in zip(shape, t.shape))
  if unchanged: return t
  # the mask is True over the original extent of t; after padding, positions where
  # the mask is not True select Invalid instead of the padded value
  # NOTE(review): relies on pad_to filling the bool mask with False -- confirm
  valid = Tensor(True, device=t.device).expand(t.shape).pad_to(shape)
  return valid.where(t.pad_to(shape), Invalid)
|
||||
# hacks for pitch alignment
|
||||
assert isinstance(ix, int) and isinstance(H, int)
|
||||
ALIGN = 64 // dtsz
|
||||
x = x.pad_to(None, None, round_up(ix, ALIGN // math.gcd(groups * cin, ALIGN)), None)
|
||||
w = w.pad_to((None, round_up(H, ALIGN // math.gcd(W * cin * 4, ALIGN))) + (None,) * (w.ndim - 2))
|
||||
x = _invalid_pad_to(x, (None, None, round_up(ix, ALIGN // math.gcd(groups * cin, ALIGN)), None))
|
||||
w = _invalid_pad_to(w, (None, round_up(H, ALIGN // math.gcd(W * cin * 4, ALIGN))) + (None,) * (w.ndim - 2))
|
||||
|
||||
if FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
|
||||
else: x, w = x.contiguous(), w.contiguous()
|
||||
|
||||
Reference in New Issue
Block a user