failing test for const folding breaking indexing [pr] (#8103)

This commit is contained in:
qazal
2024-12-07 13:55:02 +02:00
committed by GitHub
parent 8b1fa9cb7d
commit 6be388be86

View File

@@ -108,6 +108,25 @@ class TestLinearizer(unittest.TestCase):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@unittest.expectedFailure
def test_const_alu_indexing(self):
st = ShapeTracker.from_shape((4,)).to_uop()
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)
op = load+UOp.const(dtypes.float, 1.0)*UOp.const(dtypes.float, -1)
store = UOp.store(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), st, op)
Tensor.manual_seed(0)
x = Tensor.randn(4,).realize()
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1*-1], opts=[])
def test_const_alu_indexing_one_const_fine(self):
st = ShapeTracker.from_shape((4,)).to_uop()
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)
op = load+UOp.const(dtypes.float, 1.0)
store = UOp.store(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), st, op)
Tensor.manual_seed(0)
x = Tensor.randn(4,).realize()
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1], opts=[])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")