more vconst stuff + gep tuple [run_process_replay] (#6494)

* more vconst stuff [run_process_replay]

* revert that

* fix inf loop
This commit is contained in:
George Hotz
2024-09-12 14:58:14 +08:00
committed by GitHub
parent 4507ab8016
commit 6dfa63cb21
3 changed files with 16 additions and 3 deletions

View File

@@ -86,9 +86,18 @@ class TestGraphRewriteConst(unittest.TestCase):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
ret = graph_rewrite(v1+v2, constant_folder)
self.assertEqual(ret.op, UOps.VCONST)
self.assertEqual(ret.dtype, dtypes.int.vec(3))
self.assertEqual(ret.arg, (5,7,9))
def test_add_const_lose_v(self):
  """Adding two vector constants whose lane-wise sums are all equal must fold to a plain CONST."""
  vec3 = dtypes.int.vec(3)
  lhs = UOp.const(vec3, (0,1,2))
  rhs = UOp.const(vec3, (2,1,0))
  folded = graph_rewrite(lhs+rhs, constant_folder)
  # (0,1,2) + (2,1,0) is 2 in every lane, so the result is expected to be CONST rather than VCONST
  self.assertEqual(folded.op, UOps.CONST)
  # the dtype is still the 3-wide vector even though arg is a single value
  self.assertEqual(folded.dtype, vec3)
  self.assertEqual(folded.arg, 2)
class TestGraphRewrite(unittest.TestCase):
def test_dedup(self):
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
@@ -421,7 +430,9 @@ class TestUOpGraph(unittest.TestCase):
# ranges are closed in the right order
self.assertEqual(endranges[-1].src[0], ranges[0])
def expander_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
def expander_rewrite(sink):
  """Rewrite `sink` in two passes: first constant_folder + expander, then constant_folder + reducer."""
  expanded = graph_rewrite(sink, constant_folder + expander)
  # the reducer rules run on the already-expanded graph, in a separate rewrite pass
  reduced = graph_rewrite(expanded, constant_folder + reducer)
  return reduced
def float4_rewrite(sink):
  """Rewrite `sink` in a single pass using constant_folder + expander + float4_folding."""
  rules = constant_folder + expander + float4_folding
  return graph_rewrite(sink, rules)
class TestExpander(unittest.TestCase):
@@ -593,7 +604,7 @@ class TestLoadStoreFolder(unittest.TestCase):
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertListEqual([src.arg for src in single_load.src[2].src], [0.0, 1.0, 2.0, 3.0])
self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0])
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))

View File

@@ -449,6 +449,7 @@ def create_gate(root:UOp) -> Optional[UOp]:
return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret
expander = PatternMatcher([
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
# create gate MUST BE BEFORE expander
(UPat(UOps.STORE, name="root"), create_gate),
# do expansion
@@ -490,7 +491,7 @@ reducer = PatternMatcher([
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
])
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
# no_pyint: single-rule matcher over CONST, VCONST, ALU, SPECIAL, RANGE, EXPAND and VECTORIZE UOps.
# When the matched UOp's scalar dtype is dtypes.pyint, it is rebuilt with the same op, src and arg but with
# dtype int32 (an int32 vector of the same count when x.dtype.count > 1).
# For any other dtype the lambda returns None (presumably meaning "no rewrite" to PatternMatcher -- confirm).
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \
if x.dtype.scalar() == dtypes.pyint else None)])

View File

@@ -398,6 +398,7 @@ class UOp(MathTrait):
def _const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
  """Build the UOp that represents the constant (or Variable) `b` with the given `dtype`.

  - A Variable becomes a DEFINE_VAR whose arg is (expr, const(min), const(max)), the bounds being int consts.
  - A tuple whose elements are all the same (per all_same) is collapsed to its first element.
  - A remaining tuple yields a VCONST, anything else a CONST; the value goes through dtypes.as_const
    unless dtype is None.
  """
  # TODO: fix dtype of b.max after Variable is just an UOp
  if isinstance(b, Variable):
    var_expr = b.expr
    var_min = cls.const(dtypes.int, b.min)
    var_max = cls.const(dtypes.int, cast(int,b.max))
    return cls(UOps.DEFINE_VAR, dtype, arg=(var_expr, var_min, var_max)) # type: ignore
  # doesn't have to be a VCONST if they are all the same
  if isinstance(b, tuple) and all_same(b):
    b = b[0]
  if isinstance(b, tuple):
    op = UOps.VCONST
  else:
    op = UOps.CONST
  if dtype is not None:
    value = dtypes.as_const(b, dtype) # type: ignore
  else:
    value = b
  return cls(op, dtype, arg=value) # type: ignore
@functools.cached_property
def parents(self) -> Dict[UOp, None]:
  """All transitive sources of this UOp, as an insertion-ordered dict used as a set (every value is None).

  Direct sources come first in `src` order, followed by the not-yet-seen parents of each source.
  """
  collected = {}
  # direct sources first, so they keep the leading positions
  for direct in self.src:
    collected[direct] = None
  # then every ancestor reachable through each source; re-assigning an existing key keeps its position
  for direct in self.src:
    for ancestor in direct.parents.keys():
      collected[ancestor] = None
  return collected