From 51892c8fac8d733dd6473e7ef0fad0ae57661b60 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 19 Jul 2024 09:44:30 -0700 Subject: [PATCH] Revert "push contract through cast to fix test_float2_acc (#5581)" (#5583) This reverts commit ddda9420bec953740b2d31eaee5c7040facf72a9. --- test/test_linearizer.py | 3 ++- tinygrad/codegen/uopgraph.py | 8 +++----- tinygrad/codegen/uops.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index dcaca1f4e1..984d23c4a8 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1114,6 +1114,7 @@ class TestFloat4(unittest.TestCase): count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)]) assert count == expected, f"{count=}, {expected=}" + @unittest.expectedFailure def test_float2_acc(self): # from resnet ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 @@ -1124,7 +1125,7 @@ class TestFloat4(unittest.TestCase): k = Kernel(ast) for opt in opts: k.apply_opt(opt) k.linearize() - count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)]) + count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)]) assert count == expected, f"{count=}, {expected=}" class TestHandCodedOpts(unittest.TestCase): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index c6d7ef0eb1..5d6751607d 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -193,13 +193,11 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng): # this is symbolic 2.0 constant_folder = PatternMatcher([ - # CONTRACT before ALU/REDUCE/CAST + # CONTRACT before ALU/REDUCE (UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.ALU, name="alu"),)), lambda con, alu: UOp(alu.op, con.dtype, tuple(UOp(UOps.CONTRACT, x.dtype.vec(con.dtype.count), (x,), con.arg) for x in alu.src), alu.arg)), (UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.REDUCE, name="red"),)), lambda con, red: UOp(UOps.REDUCE, con.dtype, (UOp(UOps.CONTRACT, con.dtype, red.src[0:1], con.arg),)+red.src[1:], red.arg)), - (UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.CAST, src=(UPat(name="casted"),)),)), - lambda con, casted: UOp(UOps.CAST, con.dtype, (UOp(UOps.CONTRACT, casted.dtype.vec(con.dtype.count), (casted,), con.arg),))), # bigint is rewritten to int32 (UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.bigint, name="x"), lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)), @@ -449,7 +447,7 @@ def do_contract(con:UOp): def no_vectorized_alu(alu): if alu.dtype.count == 1: return None - alus = tuple(UOp(alu.op, alu.dtype.scalar(), + alus = tuple(UOp(UOps.ALU, alu.dtype.scalar(), tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), i) for s in alu.src), alu.arg) for i in range(alu.dtype.count)) return UOp(UOps.VECTORIZE, alu.dtype, alus) @@ -469,7 +467,7 @@ expander = PatternMatcher([ # empty EXPAND is NOOP (UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x), # no ALU on vectorized dtypes - (UPat({UOps.ALU, UOps.CAST}, name="alu"), no_vectorized_alu), + (UOp(UOps.ALU).name("alu"), no_vectorized_alu), ]) # *** uop graph *** diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 511d83babb..f286471e64 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -105,7 +105,7 @@ def type_verify(uops): arg = src[0].arg assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg - if uop is UOps.CAST: assert len(src) == 1 + if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}" assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"