test_multireduce const has a shape (#6218)

This commit is contained in:
qazal
2024-08-21 16:02:45 +08:00
committed by GitHub
parent 911bf7216c
commit f03e5a4b3b

View File

@@ -111,7 +111,10 @@ class TestLinearizer(unittest.TestCase):
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (1,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
diff = second_x - first_reduce
first_shape = first_x.st_arg.reduce(first_reduce.arg[1])
const_st = ShapeTracker.from_shape(()).reshape((1,)*len(first_shape)).expand(first_shape)
neg_first_reduce = first_reduce * UOp(UOps.CONST, dtypes.float, (const_st.to_uop(),), -1.0)
diff = second_x + neg_first_reduce
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (0,)))
store = UOp(UOps.STORE, None, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
sink = UOp(UOps.SINK, src=(store,))