diff --git a/test/unit/test_uop_repr.py b/test/unit/test_uop_repr.py new file mode 100644 index 0000000000..f1fbd9bed4 --- /dev/null +++ b/test/unit/test_uop_repr.py @@ -0,0 +1,38 @@ +import unittest +from tinygrad import UOp, dtypes + +class TestUOpRepr(unittest.TestCase): + def test_simple_const(self): + a = UOp.const(dtypes.int, 42) + self.assertEqual(repr(a), "UOp(Ops.CONST, dtypes.int, arg=42, src=())") + def test_different_consts(self): + a, b = UOp.const(dtypes.int, 42), UOp.const(dtypes.int, 3) + expected = ( + "UOp(Ops.ADD, dtypes.int, arg=None, src=(\n" + + " UOp(Ops.CONST, dtypes.int, arg=42, src=()),\n" + + " UOp(Ops.CONST, dtypes.int, arg=3, src=()),))" + ) + self.assertEqual(repr(a+b), expected) + def test_walrus_operator_indentation(self): + # The reference should have the same indentation as the definition + a = UOp.const(dtypes.int, 42) + expected = ( + "UOp(Ops.ADD, dtypes.int, arg=None, src=(\n" + + " x0:=UOp(Ops.CONST, dtypes.int, arg=42, src=()),\n" + + " x0,))" + ) + self.assertEqual(repr(a+a), expected) + def test_nested_walrus_indentation(self): + # Ensure indentation is consistent at multiple levels + b = (a:=UOp.const(dtypes.int, 1)) + a + expected = ( + "UOp(Ops.MUL, dtypes.int, arg=None, src=(\n" + + " x0:=UOp(Ops.ADD, dtypes.int, arg=None, src=(\n" + + " x1:=UOp(Ops.CONST, dtypes.int, arg=1, src=()),\n" + + " x1,)),\n" + + " x0,))" + ) + self.assertEqual(repr(b*b), expected) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 71a8003f64..5e10812348 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -70,7 +70,7 @@ def pretty_print(x:UOp, cache=None, d=0)->str: cache.setdefault(s, [len(cache), 0, False])[1] += 1 if cache[s][1] == 1: dfs(s, cache) if cache is None: dfs(x, cache:={}) - if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}" + if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d}x{cx[0]}" cx[2], srcs = True, (''.join(f'\n{pretty_print(s, cache, d+2)},' for s in x.src)) return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{type(x).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=({srcs}))"