mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
less rewrite stages in matcher (#7445)
* less rewrite stages in matcher
* better name
This commit is contained in:
@@ -127,10 +127,15 @@ transcendental_patterns = [
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
|
||||
]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_transcendental_patterns(ops, force_transcendental=False):
|
||||
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
powers_of_two = {2**i:i for i in range(64)}
|
||||
@functools.lru_cache(None)
|
||||
def get_extra_patterns(ops, force_transcendental=False):
|
||||
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
|
||||
def get_late_rewrite_patterns(ops):
|
||||
pat: List[Tuple[UPat, Callable]] = []
|
||||
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
|
||||
if BinaryOps.AND in ops:
|
||||
pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
|
||||
@@ -457,7 +462,7 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optio
|
||||
# remove the gate from the index
|
||||
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
|
||||
|
||||
reducer = PatternMatcher([
|
||||
load_store_indexing = PatternMatcher([
|
||||
# late fixup of unfoldable image loads
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
# simplify valid
|
||||
@@ -505,9 +510,11 @@ pm_render = PatternMatcher([
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# temp for indexing migration
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing)
|
||||
# initial symbolic + migrate indexing (remove this) + transcendental
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
@@ -515,15 +522,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
|
||||
|
||||
# devectorize
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
# devectorize + load_store_indexing
|
||||
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
|
||||
|
||||
# cleanups
|
||||
sink = graph_rewrite(sink, sym+reducer)
|
||||
|
||||
# add extra patterns
|
||||
sink = graph_rewrite(sink, sym+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
# for rendering without sym (including the rules from the renderer)
|
||||
sink = graph_rewrite(sink, symbolic_simple+(pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render))
|
||||
# final rules for the renderer (without sym)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops)+pm_render+extra_matcher)
|
||||
return sink
|
||||
|
||||
@@ -513,7 +513,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
def get_location() -> Tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
||||
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py",
|
||||
"lowerer.py", "cstyle.py"}:
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@functools.lru_cache(None)
|
||||
|
||||
Reference in New Issue
Block a user