diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 55262f35fb..dd25ca89e1 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -274,16 +274,21 @@ transcendental_patterns = [ (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin), ] +powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) def get_extra_patterns(ops, force_transcendental=False): pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental] + # rewrite MOD to AND (which should always be supported, but not for generic in tests) + if BinaryOps.AND in ops: + pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))), + lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)] + # rewrite MUL/IDIV to SHL+SHR if BinaryOps.SHL in ops and BinaryOps.SHR in ops: - shiftable_consts = set([2**i for i in range(64)]) pat += [ - (UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda root, mul, const: - UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None), - (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))), lambda root, div, const: - UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None)] + (UPat(UOps.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const: + UOp(UOps.ALU, mul.dtype, (mul, UOp.const(dtypes.int, powers_of_two[const.arg])), BinaryOps.SHL) if const.arg in powers_of_two else None), + (UPat(UOps.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const: + UOp(UOps.ALU, div.dtype, (div, UOp.const(dtypes.int, powers_of_two[const.arg])), BinaryOps.SHR) if const.arg in powers_of_two else None)] if UnaryOps.NEG in ops: pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]