mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix indentation in UOp pretty_print for repeated references (#13857)
* fix correct indentation in UOp pretty_print for repeated references When a UOp was referenced multiple times, the walrus operator notation (e.g., x0:=) was correctly used for the first occurrence, but subsequent references had misaligned indentation due to an extra space character. Fix indentation misalignment in pretty_print() when UOps are referenced multiple times. * add simple unit tests for UOp repr --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
38
test/unit/test_uop_repr.py
Normal file
38
test/unit/test_uop_repr.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import unittest
|
||||
from tinygrad import UOp, dtypes
|
||||
|
||||
class TestUOpRepr(unittest.TestCase):
|
||||
def test_simple_const(self):
|
||||
a = UOp.const(dtypes.int, 42)
|
||||
self.assertEqual(repr(a), "UOp(Ops.CONST, dtypes.int, arg=42, src=())")
|
||||
def test_different_consts(self):
|
||||
a, b = UOp.const(dtypes.int, 42), UOp.const(dtypes.int, 3)
|
||||
expected = (
|
||||
"UOp(Ops.ADD, dtypes.int, arg=None, src=(\n" +
|
||||
" UOp(Ops.CONST, dtypes.int, arg=42, src=()),\n" +
|
||||
" UOp(Ops.CONST, dtypes.int, arg=3, src=()),))"
|
||||
)
|
||||
self.assertEqual(repr(a+b), expected)
|
||||
def test_walrus_operator_indentation(self):
|
||||
# The reference should have the same indentation as the definition
|
||||
a = UOp.const(dtypes.int, 42)
|
||||
expected = (
|
||||
"UOp(Ops.ADD, dtypes.int, arg=None, src=(\n" +
|
||||
" x0:=UOp(Ops.CONST, dtypes.int, arg=42, src=()),\n" +
|
||||
" x0,))"
|
||||
)
|
||||
self.assertEqual(repr(a+a), expected)
|
||||
def test_nested_walrus_indentation(self):
|
||||
# Ensure indentation is consistent at multiple levels
|
||||
b = (a:=UOp.const(dtypes.int, 1)) + a
|
||||
expected = (
|
||||
"UOp(Ops.MUL, dtypes.int, arg=None, src=(\n" +
|
||||
" x0:=UOp(Ops.ADD, dtypes.int, arg=None, src=(\n" +
|
||||
" x1:=UOp(Ops.CONST, dtypes.int, arg=1, src=()),\n" +
|
||||
" x1,)),\n" +
|
||||
" x0,))"
|
||||
)
|
||||
self.assertEqual(repr(b*b), expected)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -70,7 +70,7 @@ def pretty_print(x:UOp, cache=None, d=0)->str:
|
||||
cache.setdefault(s, [len(cache), 0, False])[1] += 1
|
||||
if cache[s][1] == 1: dfs(s, cache)
|
||||
if cache is None: dfs(x, cache:={})
|
||||
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
||||
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d}x{cx[0]}"
|
||||
cx[2], srcs = True, (''.join(f'\n{pretty_print(s, cache, d+2)},' for s in x.src))
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{type(x).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=({srcs}))"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user