mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
maskless const can lower without valid, p1 [pr] (#8094)
This commit is contained in:
@@ -88,5 +88,11 @@ class TestVerifyAST(unittest.TestCase):
|
||||
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
|
||||
|
||||
def test_flat_const_always_valid(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
|
||||
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
|
||||
helper_test_verify_ast(st)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -733,7 +733,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
|
||||
st = uop.arg
|
||||
# everything else inherits shape
|
||||
else:
|
||||
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
|
||||
if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
|
||||
st = src_sts[0]
|
||||
if not all_same(shapes:=[x.shape for x in src_sts]):
|
||||
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
|
||||
raise AssertionError(f"found implicit expand {sizes} {shapes}")
|
||||
|
||||
Reference in New Issue
Block a user