Revert "push contract through cast to fix test_float2_acc (#5581)" (#5583)

This reverts commit ddda9420be.
This commit is contained in:
George Hotz
2024-07-19 09:44:30 -07:00
committed by GitHub
parent 6bade4d419
commit 51892c8fac
3 changed files with 6 additions and 7 deletions

View File

@@ -1114,6 +1114,7 @@ class TestFloat4(unittest.TestCase):
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
assert count == expected, f"{count=}, {expected=}"
@unittest.expectedFailure
def test_float2_acc(self):
# from resnet
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
@@ -1124,7 +1125,7 @@ class TestFloat4(unittest.TestCase):
k = Kernel(ast)
for opt in opts: k.apply_opt(opt)
k.linearize()
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
assert count == expected, f"{count=}, {expected=}"
class TestHandCodedOpts(unittest.TestCase):

View File

@@ -193,13 +193,11 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# CONTRACT before ALU/REDUCE/CAST
# CONTRACT before ALU/REDUCE
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.ALU, name="alu"),)),
lambda con, alu: UOp(alu.op, con.dtype, tuple(UOp(UOps.CONTRACT, x.dtype.vec(con.dtype.count), (x,), con.arg) for x in alu.src), alu.arg)),
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.REDUCE, name="red"),)),
lambda con, red: UOp(UOps.REDUCE, con.dtype, (UOp(UOps.CONTRACT, con.dtype, red.src[0:1], con.arg),)+red.src[1:], red.arg)),
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.CAST, src=(UPat(name="casted"),)),)),
lambda con, casted: UOp(UOps.CAST, con.dtype, (UOp(UOps.CONTRACT, casted.dtype.vec(con.dtype.count), (casted,), con.arg),))),
# bigint is rewritten to int32
(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.bigint, name="x"),
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
@@ -449,7 +447,7 @@ def do_contract(con:UOp):
def no_vectorized_alu(alu):
if alu.dtype.count == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(),
alus = tuple(UOp(UOps.ALU, alu.dtype.scalar(),
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), i) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
return UOp(UOps.VECTORIZE, alu.dtype, alus)
@@ -469,7 +467,7 @@ expander = PatternMatcher([
# empty EXPAND is NOOP
(UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
# no ALU on vectorized dtypes
(UPat({UOps.ALU, UOps.CAST}, name="alu"), no_vectorized_alu),
(UOp(UOps.ALU).name("alu"), no_vectorized_alu),
])
# *** uop graph ***

View File

@@ -105,7 +105,7 @@ def type_verify(uops):
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
if uop is UOps.CAST: assert len(src) == 1
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
if uop is UOps.VECTORIZE:
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"