diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 92d63146db..a0f1dfb875 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -111,7 +111,10 @@ class TestLinearizer(unittest.TestCase): first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (1,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop())) - diff = second_x - first_reduce + first_shape = first_x.st_arg.reduce(first_reduce.arg[1]) + const_st = ShapeTracker.from_shape(()).reshape((1,)*len(first_shape)).expand(first_shape) + neg_first_reduce = first_reduce * UOp(UOps.CONST, dtypes.float, (const_st.to_uop(),), -1.0) + diff = second_x + neg_first_reduce second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (0,))) store = UOp(UOps.STORE, None, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce)) sink = UOp(UOps.SINK, src=(store,))