update get_transcendental_patterns [pr] (#7489)

i think ths is better than `(p[0], cast(Callable, p[1]))`
This commit is contained in:
chenyu
2024-11-02 14:25:31 -04:00
committed by GitHub
parent 55bd136746
commit baaec39ffc

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable
from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
@@ -121,15 +121,10 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
# ***** optional patterns *****
transcendental_patterns = [
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.EXP2), xexp2),
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.LOG2), xlog2),
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
]
@functools.lru_cache(None)
def get_transcendental_patterns(ops, force_transcendental=False):
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
def get_transcendental_patterns(already_supported_ops, force_transcendental=False):
pat = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in already_supported_ops or force_transcendental]
return PatternMatcher(pat)
powers_of_two = {2**i:i for i in range(64)}