mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
disable simplify_phi_loops (#3812)
* disble simplify_phi_loops this breaks BEAM search GPT2. * skip that
This commit is contained in:
@@ -251,6 +251,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
|
||||
arg=TernaryOps.WHERE).arg == c0.arg
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_phi_simplification(self):
|
||||
def helper(t, max_ops=0):
|
||||
sched = create_schedule([t.lazydata])
|
||||
|
||||
@@ -190,7 +190,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
def test_failure_25(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1024, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
|
||||
opts = [Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=4)]
|
||||
helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "HSA", "CUDA"])
|
||||
helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -131,6 +131,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
|
||||
helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT=="RHIP", "broken in HIP CI")
|
||||
def test_split(self):
|
||||
def tensor(s): return torch.arange(math.prod(s), dtype=torch.int32).reshape(s), Tensor.arange(math.prod(s)).reshape(s)
|
||||
test_cases = [
|
||||
|
||||
Reference in New Issue
Block a user