disable simplify_phi_loops (#3812)

* disable simplify_phi_loops

this breaks BEAM search on GPT2.

* skip that
This commit is contained in:
chenyu
2024-03-18 19:25:26 -04:00
committed by GitHub
parent 4c4d3cb3e3
commit ac866eaf5a
4 changed files with 6 additions and 2 deletions

View File

@@ -251,6 +251,7 @@ class TestLinearizer(unittest.TestCase):
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
arg=TernaryOps.WHERE).arg == c0.arg
@unittest.expectedFailure
def test_phi_simplification(self):
def helper(t, max_ops=0):
sched = create_schedule([t.lazydata])

View File

@@ -190,7 +190,7 @@ class TestLinearizerFailures(unittest.TestCase):
def test_failure_25(self):
# Regression case: a STORE of (SUM reduce of a masked int CONST view) + CONST(-1),
# linearized with GROUP(axis=0, amt=16) and UNROLL(axis=0, amt=4) opts.
# NOTE(review): this is a scraped unified-diff view with the +/- markers and
# indentation stripped. Per the hunk header (@@ -190,7 +190,7 @@) exactly one
# line changed: the first helper_test_lin call below (failed_platforms=
# ["METAL", "GPU", "HSA", "CUDA"]) is the REMOVED pre-change line, and the
# second call (failed_platforms=[]) is the ADDED replacement — the real method
# contains only one of them. Do not read this as two sequential calls.
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1024, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
opts = [Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=4)]
# Removed line (pre-change): expected failure on METAL/GPU/HSA/CUDA.
helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "HSA", "CUDA"])
# Added line (post-change): with simplify_phi_loops disabled, no platform fails.
helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
if __name__ == '__main__':
unittest.main()

View File

@@ -131,6 +131,7 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True)
@unittest.skipIf(Device.DEFAULT=="RHIP", "broken in HIP CI")
def test_split(self):
def tensor(s): return torch.arange(math.prod(s), dtype=torch.int32).reshape(s), Tensor.arange(math.prod(s)).reshape(s)
test_cases = [