diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index cf218301b2..afd899794f 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -121,16 +121,11 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: # ***** optional patterns ***** -@functools.lru_cache(None) -def get_transcendental_patterns(already_supported_ops, force_transcendental=False): - pat = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \ - ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in already_supported_ops or force_transcendental] - return PatternMatcher(pat) - powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) -def get_late_rewrite_patterns(ops): - pat: List[Tuple[UPat, Callable]] = [] +def get_late_rewrite_patterns(ops, force_transcendental=False): + pat: List[Tuple[UPat, Callable]] = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \ + ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental] # rewrite MOD to AND (which should always be supported, but not for generic in tests) if BinaryOps.AND in ops: pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))), @@ -500,8 +495,8 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else () extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([]) - # initial symbolic + migrate indexing (remove this) + transcendental - sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)) + # initial symbolic + migrate indexing (remove this) + sink = graph_rewrite(sink, sym+migrate_indexing) # expand sink = graph_rewrite(sink, sym+expander) @@ -510,5 +505,5 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing) # final rules for the renderer (without sym) - sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops)+pm_render+extra_matcher) + sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher) return sink