mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
failed test case that DEFINE_ACC no longer uses float4 (#5555)
* failed test case that DEFINE_ACC no longer uses float4 * line
This commit is contained in:
@@ -1095,13 +1095,26 @@ class TestFloat4(unittest.TestCase):
|
||||
((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
|
||||
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
|
||||
]:
|
||||
|
||||
k = Kernel(ast)
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
k.linearize()
|
||||
count = TestFloat4.count_half4(k)
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.expectedFailure
def test_float4_acc(self):
  # Known-failing case (see the decorator): for each opt list below, count the
  # DEFINE_ACC uops whose dtype is dtypes.float.vec(4) and compare against the
  # expected number.
  # from float32 stable diffusion red tinybox
  ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(5, 6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
  # Each case pairs the expected float4 accumulator count with the opts to apply.
  cases = [
    (1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]),
    (4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]),
  ]
  for expected, applied_opts in cases:
    # A fresh Kernel per case so opts from one case never leak into the next.
    kernel = Kernel(ast)
    for single_opt in applied_opts:
      kernel.apply_opt(single_opt)
    kernel.linearize()
    # Names `count` and `expected` are kept: the f-string below prints them verbatim.
    count = sum(1 for u in kernel.uops if u.op is UOps.DEFINE_ACC and u.dtype == dtypes.float.vec(4))
    assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
class TestHandCodedOpts(unittest.TestCase):
|
||||
def test_masked_upcast(self):
|
||||
layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
|
||||
|
||||
Reference in New Issue
Block a user