move print_diff to process replay [pr] (#8566)

* move print_diff to process replay [pr]

* ruff rightfully complians
This commit is contained in:
qazal
2025-01-11 09:28:45 -05:00
committed by GitHub
parent 2f0856c1e2
commit a70d1bf439
2 changed files with 3 additions and 12 deletions

View File

@@ -1,4 +1,4 @@
import time, logging, difflib
import time
from typing import Callable, Optional, Tuple
import numpy as np
from tinygrad import Tensor, dtypes
@@ -8,7 +8,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import T, getenv, colored
from tinygrad.helpers import T
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
@@ -40,16 +40,6 @@ def rand_for_dtype(dt:DType, size:int):
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
if unified:
lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
else:
import ocdiff
diff = ocdiff.console_diff(str(s0), str(s1))
logging.info(diff)
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)