diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index a721940700..cf218301b2 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable +from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, Callable import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType @@ -121,15 +121,10 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: # ***** optional patterns ***** -transcendental_patterns = [ - (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.EXP2), xexp2), - (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.LOG2), xlog2), - (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin), -] - @functools.lru_cache(None) -def get_transcendental_patterns(ops, force_transcendental=False): - pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental] +def get_transcendental_patterns(already_supported_ops, force_transcendental=False): + pat = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \ + ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in already_supported_ops or force_transcendental] return PatternMatcher(pat) powers_of_two = {2**i:i for i in range(64)}