maskless const can lower without valid, p1 [pr] (#8094)

This commit is contained in:
qazal
2024-12-06 23:21:19 +02:00
committed by GitHub
parent aaf2379f97
commit a97b8fa3c5
2 changed files with 8 additions and 1 deletions

View File

@@ -88,5 +88,11 @@ class TestVerifyAST(unittest.TestCase):
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
def test_flat_const_always_valid(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
helper_test_verify_ast(st)
if __name__ == '__main__':
unittest.main()

View File

@@ -733,7 +733,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
st = uop.arg
# everything else inherits shape
else:
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
st = src_sts[0]
if not all_same(shapes:=[x.shape for x in src_sts]):
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
raise AssertionError(f"found implicit expand {sizes} {shapes}")