diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index a8e2f44c94..31aa4e1134 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -97,9 +97,9 @@ class TestViz(BaseTestViz): # name can also come from a function that returns a string def test_dyn_name_fxn(self): - @track_rewrites(name=lambda a,ret: a.render()) - def name_from_fxn(s:UOp): return graph_rewrite(s, PatternMatcher([])) - name_from_fxn(UOp.variable("a", 1, 10)+1) + @track_rewrites(name=lambda *args,ret,**kwargs: ret.render()) + def name_from_fxn(s:UOp, arg:list|None=None): return graph_rewrite(s, PatternMatcher([])) + name_from_fxn(UOp.variable("a", 1, 10)+1, arg=["test"]) lst = get_viz_list() # name gets deduped by the function call counter self.assertEqual(lst[0]["name"], "(a+1) n1") diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 0f1b7613a6..7af3b0b550 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -12,7 +12,7 @@ from tinygrad.codegen.opt.kernel import Opt # **************** Program Creation **************** -@track_rewrites(name=lambda _ast,_renderer,ret: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret)) +@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret)) def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec: """ Transform an AST into a ProgramSpec. May trigger BEAM search.