mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
match any statement in valid for simplification (#6554)
This commit is contained in:
@@ -32,6 +32,15 @@ class TestValidSimplification(unittest.TestCase):
|
|||||||
self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(1), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
|
self.assertEqual(render((10, 10, 4), (gidx1*(-1)).lt(1), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
|
||||||
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
|
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
|
||||||
|
|
||||||
|
# should match any one of the AND clause and drop the matched statement from valid
|
||||||
|
valid = (gidx1*(-1)).lt(0) and (gidx0*(-1)).lt(0)
|
||||||
|
self.assertEqual(render((10, 10, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1))),
|
||||||
|
"(((gidx0*(-1))<0)?read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1)))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||||
|
|
||||||
|
valid = (gidx1*(-1)).lt(0) and (gidx1*(-1)).lt(0)
|
||||||
|
self.assertEqual(render((10, 10, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1))),
|
||||||
|
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))")
|
||||||
|
|
||||||
def test_idx_lt_bound(self):
|
def test_idx_lt_bound(self):
|
||||||
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
|
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
|
||||||
gidx0 = Variable("gidx0", 32)
|
gidx0 = Variable("gidx0", 32)
|
||||||
|
|||||||
@@ -159,15 +159,18 @@ def fold_unrolled_divs(divs:UOp, c:UOp):
|
|||||||
|
|
||||||
def simplify_valid_image_load(load:UOp, buf:UOp):
|
def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||||
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
|
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
|
||||||
buf, idx, _, valid = load.src
|
buf, idx, invalid_val, valid = load.src
|
||||||
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT:
|
drop = False
|
||||||
if graph_rewrite(valid.src[0]*(-1)-1+valid.src[1].arg, constant_folder).key == idx.src[1].key:
|
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||||
|
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT:
|
||||||
# valid: A*(-1) < c, idx: (..., A-1+c) -> okay to drop valid because A*(-1) >= c -> A <= -c -> A-1+c <= -1 is out of bound
|
# valid: A*(-1) < c, idx: (..., A-1+c) -> okay to drop valid because A*(-1) >= c -> A <= -c -> A-1+c <= -1 is out of bound
|
||||||
return UOp(UOps.LOAD, dtype=load.dtype, src=(buf, idx))
|
if graph_rewrite(stmt.src[0]*(-1)-1+stmt.src[1].arg, constant_folder).key == idx.src[1].key: drop = True
|
||||||
|
|
||||||
if valid.src[1].arg == buf_dtype.shape[0] and idx.src[1].key == valid.src[0].key:
|
|
||||||
# valid: A < image bound, idx: (..., A) -> okay to drop valid
|
# valid: A < image bound, idx: (..., A) -> okay to drop valid
|
||||||
return UOp(UOps.LOAD, dtype=load.dtype, src=(buf, idx))
|
elif stmt.src[1].arg == buf_dtype.shape[0] and idx.src[1].key == stmt.src[0].key: drop = True
|
||||||
|
|
||||||
|
if drop:
|
||||||
|
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s is not stmt]) else None
|
||||||
|
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))
|
||||||
|
|
||||||
# ***** transcendental *****
|
# ***** transcendental *****
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user