diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 1b020e2a57..5e016bbcef 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -452,5 +452,23 @@ class TestImageSimplification(unittest.TestCase): load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1)) self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)") +class TestUnfoldableImageChannelSelection(unittest.TestCase): + def _count_nans(self, load): + with Context(NOOPT=1, SPEC=0): + result = full_rewrite_to_sink(load.sink()).src[0] + return sum(1 for u in result.toposort() if u.op is Ops.CONST and u.arg != u.arg) + + def test_bounded_channel_no_nan(self): + # unfoldable image load with bounded idx % 4 range [0,1] -> no NAN fallback needed + lidx = Special("lidx", 2) + load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(lidx, ptr=True), UOp.const(dtypes.float, 0))) + self.assertEqual(self._count_nans(load), 0) + + def test_unbounded_channel_has_nan(self): + # variable with negative range -> x % 4 can be negative -> needs NAN fallback + x = Variable("x", -10, 10) + load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0))) + self.assertEqual(self._count_nans(load), 1) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 5e6895fa1e..7d7c8ebc7c 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -197,7 +197,12 @@ def image_fixup(ls:UOp): oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1])))) idx = idx.replace(src=(idx.src[0], oidx.valid(valid))) vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:]) - return functools.reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan'))) + # image pixels have 4 channels (.xyzw), select channel based on x % 4 + x_mod_4 = x % 4 + def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i)) + # if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback + if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin)) + return functools.reduce(sel, range(4), ls.const_like(float('nan'))) return None