mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 05:18:01 -05:00
This reverts commit ddda9420be.
This commit is contained in:
@@ -1114,6 +1114,7 @@ class TestFloat4(unittest.TestCase):
|
||||
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_float2_acc(self):
|
||||
# from resnet
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
@@ -1124,7 +1125,7 @@ class TestFloat4(unittest.TestCase):
|
||||
k = Kernel(ast)
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
k.linearize()
|
||||
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
|
||||
count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
|
||||
assert count == expected, f"{count=}, {expected=}"
|
||||
|
||||
class TestHandCodedOpts(unittest.TestCase):
|
||||
|
||||
@@ -193,13 +193,11 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
|
||||
|
||||
# this is symbolic 2.0
|
||||
constant_folder = PatternMatcher([
|
||||
# CONTRACT before ALU/REDUCE/CAST
|
||||
# CONTRACT before ALU/REDUCE
|
||||
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.ALU, name="alu"),)),
|
||||
lambda con, alu: UOp(alu.op, con.dtype, tuple(UOp(UOps.CONTRACT, x.dtype.vec(con.dtype.count), (x,), con.arg) for x in alu.src), alu.arg)),
|
||||
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.REDUCE, name="red"),)),
|
||||
lambda con, red: UOp(UOps.REDUCE, con.dtype, (UOp(UOps.CONTRACT, con.dtype, red.src[0:1], con.arg),)+red.src[1:], red.arg)),
|
||||
(UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.CAST, src=(UPat(name="casted"),)),)),
|
||||
lambda con, casted: UOp(UOps.CAST, con.dtype, (UOp(UOps.CONTRACT, casted.dtype.vec(con.dtype.count), (casted,), con.arg),))),
|
||||
# bigint is rewritten to int32
|
||||
(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.bigint, name="x"),
|
||||
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
|
||||
@@ -449,7 +447,7 @@ def do_contract(con:UOp):
|
||||
|
||||
def no_vectorized_alu(alu):
|
||||
if alu.dtype.count == 1: return None
|
||||
alus = tuple(UOp(alu.op, alu.dtype.scalar(),
|
||||
alus = tuple(UOp(UOps.ALU, alu.dtype.scalar(),
|
||||
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), i) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
|
||||
return UOp(UOps.VECTORIZE, alu.dtype, alus)
|
||||
|
||||
@@ -469,7 +467,7 @@ expander = PatternMatcher([
|
||||
# empty EXPAND is NOOP
|
||||
(UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat({UOps.ALU, UOps.CAST}, name="alu"), no_vectorized_alu),
|
||||
(UOp(UOps.ALU).name("alu"), no_vectorized_alu),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
@@ -105,7 +105,7 @@ def type_verify(uops):
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
|
||||
if uop is UOps.CAST: assert len(src) == 1
|
||||
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == dtype.count
|
||||
if uop is UOps.VECTORIZE:
|
||||
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
|
||||
assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
||||
|
||||
Reference in New Issue
Block a user