From ad155f5454741444cd877aa21d432f201690880d Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 2 Jul 2025 20:20:01 +0300 Subject: [PATCH] print inputs to get_program in process replay [pr] (#11051) * print inputs to get_program in process replay [pr] * colors * keep dataclass default escapes * Revert "keep dataclass default escapes" This reverts commit c6db7e8a7a23a8da249505a347ea4ff46dd026f3. * note for ast_repr * add that back --- test/external/process_replay/process_replay.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 57c5326d34..3ec542d716 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64 +import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64, codecs from typing import Callable, Any from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.kernelize.kernelize import get_kernelize_map @@ -20,7 +20,8 @@ early_stop = multiprocessing.Event() logging.basicConfig(level=logging.INFO, format="%(message)s") MAX_LINES = 500 def trunc_log(x): - if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"] + if len(lines:=(x if isinstance(x, str) else repr(x)).splitlines()) > MAX_LINES: + lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"] logging.info("\n".join(lines)) # user config @@ -41,12 +42,15 @@ def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[ return to_str(new_sink), to_str(ret[big_sink]), (big_sink,) def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer) -> tuple[str, str, tuple[Any, ...]]: - p2 = get_program(ast.replace(arg=KernelInfo(opts_to_apply=p.applied_opts, name=p.name)) if ast.arg is None else ast, renderer) + input_ast = ast.replace(arg=KernelInfo(opts_to_apply=p.applied_opts, name=p.name)) if ast.arg is None else ast + p2 = get_program(input_ast, renderer) def to_str(ret:ProgramSpec) -> str: # PYTHON renderer pickles UOps, first unpickle and decode here if p.device.startswith("PYTHON"): return "\n".join([str(x) for x in pickle.loads(base64.b64decode(ret.src))]) return ret.src - return to_str(p2), to_str(p), (p.ast, renderer, p.applied_opts) + # properly color the name arg + ast_repr = codecs.decode(str(input_ast), "unicode_escape") + return to_str(p2), to_str(p), (ast_repr, renderer) replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "get_program":replay_get_program}