diff --git a/test/test_multitensor.py b/test/test_multitensor.py index fdf07f3ca7..2fa3a614b8 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -390,7 +390,6 @@ class TestMultiTensor(unittest.TestCase): # NOTE: this is failing on LLVM CI, no idea why. Works locally. @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU") - @unittest.skip("TODO: pm_rangeify hangs") def test_data_parallel_resnet(self): from extra.models.resnet import ResNet18 diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index d555a550da..16900e97fc 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -249,11 +249,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop evaluation *** - def simplify(self, tracked=False): + def simplify(self, tracked=False, full_symbolic=True): # late import! - from tinygrad.uop.symbolic import symbolic + from tinygrad.uop.symbolic import symbolic, commutative with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value): - return graph_rewrite(self, symbolic, name="simplify") + return graph_rewrite(self, symbolic if full_symbolic else commutative, name="simplify") def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret def _eval(self, dtype, expected_type:Type[T]) -> T: assert self.dtype in dtype, f"eval with wrong dtype {self}" diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 84580039be..b2f9988b85 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -397,7 +397,7 @@ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]: if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1 raise ValueError(f"not able to parse {valid=}") -def uop_given_valid(valid:UOp, uop:UOp) -> UOp: +def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: # return simplified uop (might be the same as input) # first, parse valid into {expr: (lower_bound, upper_bound)} @@ -414,10 +414,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp: for expr,v in bounds.items(): v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop - # some expr has lower bound > upper bound -> valid is an empty set and we return None # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)): + if try_simplex and expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)]) # try checking the whole clause @@ -425,7 +424,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp: for candidate in candidates: # if every branch in candidate gives the same simplified uop, we can rewrite the uop - newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate] + newuops = [uop.substitute({X:newX}) for X,newX in candidate] + if any(u is uop for u in newuops): continue # if any branch doesnt appear in uop, skip + newuops = [u.simplify().substitute({newX:X}).simplify(full_symbolic=False) for (X,newX),u in zip(candidate,newuops)] if uop.op is Ops.VECTORIZE and len(uop.src) == 2: if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1])) if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) @@ -469,8 +470,7 @@ REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP} sym = symbolic_flat+PatternMatcher([ # simplify valid (UPat(Ops.AND, name="valid"), simplify_valid), - (UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if - (newx:=uop_given_valid(cond, x)) is not x else None), + (UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)), # LOAD/STORE -> NOOP (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]), (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),