mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
test_multireduce const has a shape (#6218)
This commit is contained in:
@@ -111,7 +111,10 @@ class TestLinearizer(unittest.TestCase):
|
||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
|
||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (ReduceOps.SUM, (1,)))
|
||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
|
||||
diff = second_x - first_reduce
|
||||
first_shape = first_x.st_arg.reduce(first_reduce.arg[1])
|
||||
const_st = ShapeTracker.from_shape(()).reshape((1,)*len(first_shape)).expand(first_shape)
|
||||
neg_first_reduce = first_reduce * UOp(UOps.CONST, dtypes.float, (const_st.to_uop(),), -1.0)
|
||||
diff = second_x + neg_first_reduce
|
||||
second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (ReduceOps.SUM, (0,)))
|
||||
store = UOp(UOps.STORE, None, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
|
||||
Reference in New Issue
Block a user