mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix BUFFER_UOPS sts in verify_ast [run_process_replay] (#6389)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import unittest
|
||||
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
@@ -72,5 +74,13 @@ class TestVerifyAST(unittest.TestCase):
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
def test_buffer_uops_st(self):
|
||||
a = Tensor.randn(4, 4)+2
|
||||
uop_sts = verify_ast(a.schedule()[-1].ast)
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
|
||||
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -791,7 +791,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
|
||||
assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
# movementops are pushed to the edges with SHAPETRACKER
|
||||
# elementwise inherits shape
|
||||
st = arg if op is UOps.SHAPETRACKER else sts[src[-1]]
|
||||
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]]
|
||||
for x in (src[1:] if op in BUFFER_UOPS else src):
|
||||
if sts[x].shape != st.shape:
|
||||
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
|
||||
|
||||
@@ -347,9 +347,11 @@ class UOp(MathTrait):
|
||||
self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
|
||||
# *** uop syntactic sugar
|
||||
@property
|
||||
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
|
||||
ret = self.src[0 if self.op is UOps.CONST else 1]
|
||||
ret = self.src[self.st_loc]
|
||||
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
||||
return ret.arg
|
||||
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
|
||||
|
||||
Reference in New Issue
Block a user