pass through name function args in track_rewrites (#11572)

This commit is contained in:
qazal
2025-08-08 07:28:52 +08:00
committed by GitHub
parent 1826004ef9
commit 960cc6533a
2 changed files with 4 additions and 4 deletions

View File

@@ -97,9 +97,9 @@ class TestViz(BaseTestViz):
# name can also come from a function that returns a string
def test_dyn_name_fxn(self):
@track_rewrites(name=lambda a,ret: a.render())
def name_from_fxn(s:UOp): return graph_rewrite(s, PatternMatcher([]))
name_from_fxn(UOp.variable("a", 1, 10)+1)
@track_rewrites(name=lambda *args,ret,**kwargs: ret.render())
def name_from_fxn(s:UOp, arg:list|None=None): return graph_rewrite(s, PatternMatcher([]))
name_from_fxn(UOp.variable("a", 1, 10)+1, arg=["test"])
lst = get_viz_list()
# name gets deduped by the function call counter
self.assertEqual(lst[0]["name"], "(a+1) n1")

View File

@@ -12,7 +12,7 @@ from tinygrad.codegen.opt.kernel import Opt
# **************** Program Creation ****************
@track_rewrites(name=lambda _ast,_renderer,ret: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret))
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret))
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
"""
Transform an AST into a ProgramSpec. May trigger BEAM search.