fix bug in assert message (#5787)

This commit is contained in:
P4ssenger
2024-07-29 09:46:23 -03:00
committed by GitHub
parent ab3839a80a
commit 9c80f9adf9

View File

@@ -157,8 +157,8 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
st = op.arg.st if op.op in BufferOps else sts[op.src[0]]
for x in op.src:
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {sts[x].shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {sts[x].shape} {prod(sts[x].shape)} != {prod(st.shape)}")
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op.op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
sts[op] = st
for i, out in enumerate(ast.src):
assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"