diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 26140d7ec7..4f1f914e0f 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -325,12 +325,12 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental): pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f), (UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"), lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))] + # rewrite SQRT to xpow 0.5 + if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5)))) # no real hardware supports THREEFRY, but NullRenderer does if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) - # rewrite SQRT to xpow 0.5 - if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5)))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),