mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
pass through name function args in track_rewrites (#11572)
This commit is contained in:
@@ -97,9 +97,9 @@ class TestViz(BaseTestViz):
|
||||
|
||||
# name can also come from a function that returns a string
|
||||
def test_dyn_name_fxn(self):
|
||||
@track_rewrites(name=lambda a,ret: a.render())
|
||||
def name_from_fxn(s:UOp): return graph_rewrite(s, PatternMatcher([]))
|
||||
name_from_fxn(UOp.variable("a", 1, 10)+1)
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: ret.render())
|
||||
def name_from_fxn(s:UOp, arg:list|None=None): return graph_rewrite(s, PatternMatcher([]))
|
||||
name_from_fxn(UOp.variable("a", 1, 10)+1, arg=["test"])
|
||||
lst = get_viz_list()
|
||||
# name gets deduped by the function call counter
|
||||
self.assertEqual(lst[0]["name"], "(a+1) n1")
|
||||
|
||||
@@ -12,7 +12,7 @@ from tinygrad.codegen.opt.kernel import Opt
|
||||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
@track_rewrites(name=lambda _ast,_renderer,ret: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret))
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret))
|
||||
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
|
||||
"""
|
||||
Transform an AST into a ProgramSpec. May trigger BEAM search.
|
||||
|
||||
Reference in New Issue
Block a user