mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
faster image load (#13977)
sometimes image load does not need to init with NAN
This commit is contained in:
@@ -452,5 +452,23 @@ class TestImageSimplification(unittest.TestCase):
|
||||
load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
|
||||
self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)")
|
||||
|
||||
class TestUnfoldableImageChannelSelection(unittest.TestCase):
|
||||
def _count_nans(self, load):
|
||||
with Context(NOOPT=1, SPEC=0):
|
||||
result = full_rewrite_to_sink(load.sink()).src[0]
|
||||
return sum(1 for u in result.toposort() if u.op is Ops.CONST and u.arg != u.arg)
|
||||
|
||||
def test_bounded_channel_no_nan(self):
|
||||
# unfoldable image load with bounded idx % 4 range [0,1] -> no NAN fallback needed
|
||||
lidx = Special("lidx", 2)
|
||||
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(lidx, ptr=True), UOp.const(dtypes.float, 0)))
|
||||
self.assertEqual(self._count_nans(load), 0)
|
||||
|
||||
def test_unbounded_channel_has_nan(self):
|
||||
# variable with negative range -> x % 4 can be negative -> needs NAN fallback
|
||||
x = Variable("x", -10, 10)
|
||||
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0)))
|
||||
self.assertEqual(self._count_nans(load), 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -197,7 +197,12 @@ def image_fixup(ls:UOp):
|
||||
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % image_dtype.shape[1], (x // (4*image_dtype.shape[1]))))
|
||||
idx = idx.replace(src=(idx.src[0], oidx.valid(valid)))
|
||||
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
|
||||
return functools.reduce(lambda ret, i: (x % 4).ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
|
||||
# image pixels have 4 channels (.xyzw), select channel based on x % 4
|
||||
x_mod_4 = x % 4
|
||||
def sel(ret, i): return x_mod_4.ne(i).where(ret, vec_load.gep(i))
|
||||
# if x is non-negative, x % 4 is in [0, 3] and we can skip NAN fallback
|
||||
if x_mod_4.vmin >= 0: return functools.reduce(sel, range(x_mod_4.vmin+1, x_mod_4.vmax+1), vec_load.gep(x_mod_4.vmin))
|
||||
return functools.reduce(sel, range(4), ls.const_like(float('nan')))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user