fix float4 acc by moving contracts (#5559)

This commit is contained in:
George Hotz
2024-07-18 11:30:16 -07:00
committed by GitHub
parent c41cd55556
commit 223d9283ee
2 changed files with 4 additions and 3 deletions

View File

@@ -1101,7 +1101,6 @@ class TestFloat4(unittest.TestCase):
count = TestFloat4.count_half4(k)
assert count == expected, f"{count=}, {expected=}"
@unittest.expectedFailure
def test_float4_acc(self):
# from float32 stable diffusion red tinybox
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(5, 6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501

View File

@@ -193,8 +193,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# CONTRACT before REDUCE
(UPat(UOps.CONTRACT, name="con", src=UPat(UOps.REDUCE, name="red")),
# CONTRACT before ALU/REDUCE
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.ALU, name="alu"),)),
lambda con, alu: UOp(alu.op, con.dtype, tuple(UOp(UOps.CONTRACT, x.dtype.vec(con.dtype.count), (x,), con.arg) for x in alu.src), alu.arg)),
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.REDUCE, name="red"),)),
lambda con, red: UOp(UOps.REDUCE, con.dtype, (UOp(UOps.CONTRACT, con.dtype, red.src[0:1], con.arg),)+red.src[1:], red.arg)),
# bigint is rewritten to int32
(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.bigint, name="x"),