uop_given_valid uses less simplify (#12612)

* uop_given_valid uses less simplify

* enable test
This commit is contained in:
Sieds Lykles
2025-10-11 10:57:39 +02:00
committed by GitHub
parent 9205527db0
commit dccdd190aa
3 changed files with 9 additions and 10 deletions

View File

@@ -390,7 +390,6 @@ class TestMultiTensor(unittest.TestCase):
# NOTE: this is failing on LLVM CI, no idea why. Works locally.
@unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
@unittest.skip("TODO: pm_rangeify hangs")
def test_data_parallel_resnet(self):
from extra.models.resnet import ResNet18

View File

@@ -249,11 +249,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop evaluation ***
def simplify(self, tracked=False):
def simplify(self, tracked=False, full_symbolic=True):
# late import!
from tinygrad.uop.symbolic import symbolic
from tinygrad.uop.symbolic import symbolic, commutative
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
return graph_rewrite(self, symbolic, name="simplify")
return graph_rewrite(self, symbolic if full_symbolic else commutative, name="simplify")
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"

View File

@@ -397,7 +397,7 @@ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> UOp:
def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
# return simplified uop (might be the same as input)
# first, parse valid into {expr: (lower_bound, upper_bound)}
@@ -414,10 +414,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp:
for expr,v in bounds.items():
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
# some expr has lower bound > upper bound -> valid is an empty set and we return None
# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
if try_simplex and expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
# try checking the whole clause
@@ -425,7 +424,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp:
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
newuops = [uop.substitute({X:newX}) for X,newX in candidate]
if any(u is uop for u in newuops): continue # if any branch doesnt appear in uop, skip
newuops = [u.simplify().substitute({newX:X}).simplify(full_symbolic=False) for (X,newX),u in zip(candidate,newuops)]
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
@@ -469,8 +470,7 @@ REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if
(newx:=uop_given_valid(cond, x)) is not x else None),
(UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),