mod can be and (#6810)

This commit is contained in:
George Hotz
2024-09-30 12:33:15 +08:00
committed by GitHub
parent c9d763d331
commit 00b3171902

View File

@@ -274,16 +274,21 @@ transcendental_patterns = [
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
]
# lookup table mapping each power of two 2**i (for i in range(64)) to its exponent i
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_extra_patterns(ops, force_transcendental=False):
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
shiftable_consts = set([2**i for i in range(64)])
pat += [
(UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda root, mul, const:
UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))), lambda root, div, const:
UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None)]
(UPat(UOps.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
UOp(UOps.ALU, mul.dtype, (mul, UOp.const(dtypes.int, powers_of_two[const.arg])), BinaryOps.SHL) if const.arg in powers_of_two else None),
(UPat(UOps.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
UOp(UOps.ALU, div.dtype, (div, UOp.const(dtypes.int, powers_of_two[const.arg])), BinaryOps.SHR) if const.arg in powers_of_two else None)]
if UnaryOps.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]