fix BUFFER_UOPS sts in verify_ast [run_process_replay] (#6389)

This commit is contained in:
qazal
2024-09-06 16:59:22 +08:00
committed by GitHub
parent cc05016fa8
commit f1bd2a5519
3 changed files with 14 additions and 2 deletions

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import unittest
from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
@@ -72,5 +74,13 @@ class TestVerifyAST(unittest.TestCase):
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
def test_buffer_uops_st(self):
a = Tensor.randn(4, 4)+2
uop_sts = verify_ast(a.schedule()[-1].ast)
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
if __name__ == '__main__':
unittest.main()

View File

@@ -791,7 +791,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = arg if op is UOps.SHAPETRACKER else sts[src[-1]]
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]]
for x in (src[1:] if op in BUFFER_UOPS else src):
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")

View File

@@ -347,9 +347,11 @@ class UOp(MathTrait):
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
# *** uop syntactic sugar
@property
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
@property
def st_arg(self) -> ShapeTracker:
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
ret = self.src[0 if self.op is UOps.CONST else 1]
ret = self.src[self.st_loc]
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
return ret.arg
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)