faster image load (#13977)

sometimes image load does not need to init with NAN
This commit is contained in:
chenyu
2026-01-04 13:09:59 -05:00
committed by GitHub
parent 7ebda28692
commit cfb8bf5814
2 changed files with 24 additions and 1 deletions

View File

@@ -452,5 +452,23 @@ class TestImageSimplification(unittest.TestCase):
load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)")
class TestUnfoldableImageChannelSelection(unittest.TestCase):
def _count_nans(self, load):
with Context(NOOPT=1, SPEC=0):
result = full_rewrite_to_sink(load.sink()).src[0]
return sum(1 for u in result.toposort() if u.op is Ops.CONST and u.arg != u.arg)
def test_bounded_channel_no_nan(self):
# unfoldable image load with bounded idx % 4 range [0,1] -> no NAN fallback needed
lidx = Special("lidx", 2)
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(lidx, ptr=True), UOp.const(dtypes.float, 0)))
self.assertEqual(self._count_nans(load), 0)
def test_unbounded_channel_has_nan(self):
# variable with negative range -> x % 4 can be negative -> needs NAN fallback
x = Variable("x", -10, 10)
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0)))
self.assertEqual(self._count_nans(load), 1)
if __name__ == '__main__':
unittest.main()

View File

@@ -197,7 +197,12 @@ def image_fixup(ls:UOp):
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1]))))
idx = idx.replace(src=(idx.src[0], oidx.valid(valid)))
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
return functools.reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
# image pixels have 4 channels (.xyzw), select channel based on x % 4
x_mod_4 = x % 4
def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i))
# if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback
if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin))
return functools.reduce(sel, range(4), ls.const_like(float('nan')))
return None