mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
update get_transcendental_patterns [pr] (#7489)
i think ths is better than `(p[0], cast(Callable, p[1]))`
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||
from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||
import functools, itertools, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
@@ -121,15 +121,10 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
|
||||
|
||||
# ***** optional patterns *****
|
||||
|
||||
transcendental_patterns = [
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.EXP2), xexp2),
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.LOG2), xlog2),
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
|
||||
]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_transcendental_patterns(ops, force_transcendental=False):
|
||||
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
|
||||
def get_transcendental_patterns(already_supported_ops, force_transcendental=False):
|
||||
pat = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
|
||||
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in already_supported_ops or force_transcendental]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
powers_of_two = {2**i:i for i in range(64)}
|
||||
|
||||
Reference in New Issue
Block a user