failed test case for unrolled half4 (#5552)

This commit is contained in:
chenyu
2024-07-18 13:05:52 -04:00
committed by GitHub
parent d1a7279605
commit 12e6771209

View File

@@ -945,6 +945,10 @@ class TestFloat4(unittest.TestCase):
def count_float4(k):
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(4)]))
@staticmethod
def count_half4(k):
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.half.vec(4)]),
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.half.vec(4)]))
# TODO: express opts below as auto opts
@@ -1081,6 +1085,24 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(k) == (1, 1)
@unittest.expectedFailure
def test_half4_load_unrolled(self):
# from llama 7B shard 4 gpus
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
# TODO: fix this, expected might change but should be positive
for expected, opts in [
((7, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
]:
k = Kernel(ast)
for opt in opts: k.apply_opt(opt)
k.linearize()
count = TestFloat4.count_half4(k)
assert count == expected, f"{count=}, {expected=}"
class TestHandCodedOpts(unittest.TestCase):
def test_masked_upcast(self):
layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])