mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
uop_given_valid uses less simplify (#12612)
* uop_given_valid uses less simplify
* enable test
This commit is contained in:
@@ -390,7 +390,6 @@ class TestMultiTensor(unittest.TestCase):
|
||||
|
||||
# NOTE: this is failing on LLVM CI, no idea why. Works locally.
|
||||
@unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
|
||||
@unittest.skip("TODO: pm_rangeify hangs")
|
||||
def test_data_parallel_resnet(self):
|
||||
from extra.models.resnet import ResNet18
|
||||
|
||||
|
||||
@@ -249,11 +249,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop evaluation ***
|
||||
|
||||
def simplify(self, tracked=False):
|
||||
def simplify(self, tracked=False, full_symbolic=True):
|
||||
# late import!
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.uop.symbolic import symbolic, commutative
|
||||
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
|
||||
return graph_rewrite(self, symbolic, name="simplify")
|
||||
return graph_rewrite(self, symbolic if full_symbolic else commutative, name="simplify")
|
||||
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
||||
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
||||
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
||||
|
||||
@@ -397,7 +397,7 @@ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
|
||||
if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1
|
||||
raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> UOp:
|
||||
def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
||||
# return simplified uop (might be the same as input)
|
||||
|
||||
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
||||
@@ -414,10 +414,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp:
|
||||
for expr,v in bounds.items():
|
||||
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
|
||||
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
|
||||
# some expr has lower bound > upper bound -> valid is an empty set and we return None
|
||||
# every candidate is a set of constrained UOps based on valid, and if every item in a set simplifies the uop into the same output, we rewrite uop
|
||||
candidates = []
|
||||
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
|
||||
if try_simplex and expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
|
||||
# try checking the whole clause
|
||||
@@ -425,7 +424,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp:
|
||||
|
||||
for candidate in candidates:
|
||||
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
||||
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
|
||||
newuops = [uop.substitute({X:newX}) for X,newX in candidate]
|
||||
if any(u is uop for u in newuops): continue # if any branch doesn't appear in uop, skip
|
||||
newuops = [u.simplify().substitute({newX:X}).simplify(full_symbolic=False) for (X,newX),u in zip(candidate,newuops)]
|
||||
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
|
||||
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
|
||||
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
|
||||
@@ -469,8 +470,7 @@ REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda cond,x,i: cond.where(newx, i) if
|
||||
(newx:=uop_given_valid(cond, x)) is not x else None),
|
||||
(UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
|
||||
Reference in New Issue
Block a user