mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
more vconst stuff + gep tuple [run_process_replay] (#6494)
* more vconst stuff [run_process_replay] * revert that * fix inf loop
This commit is contained in:
@@ -86,9 +86,18 @@ class TestGraphRewriteConst(unittest.TestCase):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
|
||||
ret = graph_rewrite(v1+v2, constant_folder)
|
||||
self.assertEqual(ret.op, UOps.VCONST)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, (5,7,9))
|
||||
|
||||
def test_add_const_lose_v(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (2,1,0))
|
||||
ret = graph_rewrite(v1+v2, constant_folder)
|
||||
self.assertEqual(ret.op, UOps.CONST)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, 2)
|
||||
|
||||
class TestGraphRewrite(unittest.TestCase):
|
||||
def test_dedup(self):
|
||||
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
@@ -421,7 +430,9 @@ class TestUOpGraph(unittest.TestCase):
|
||||
# ranges are closed in the right order
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
def expander_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
|
||||
def expander_rewrite(sink):
|
||||
sink = graph_rewrite(sink, constant_folder + expander)
|
||||
return graph_rewrite(sink, constant_folder + reducer)
|
||||
def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + float4_folding)
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
@@ -593,7 +604,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
|
||||
self.assertListEqual([src.arg for src in single_load.src[2].src], [0.0, 1.0, 2.0, 3.0])
|
||||
self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0])
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
|
||||
@@ -449,6 +449,7 @@ def create_gate(root:UOp) -> Optional[UOp]:
|
||||
return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret
|
||||
|
||||
expander = PatternMatcher([
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||
# create gate MUST BE BEFORE expander
|
||||
(UPat(UOps.STORE, name="root"), create_gate),
|
||||
# do expansion
|
||||
@@ -490,7 +491,7 @@ reducer = PatternMatcher([
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
|
||||
])
|
||||
|
||||
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
|
||||
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
|
||||
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int32, x.src, x.arg) \
|
||||
if x.dtype.scalar() == dtypes.pyint else None)])
|
||||
|
||||
|
||||
@@ -398,6 +398,7 @@ class UOp(MathTrait):
|
||||
def _const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
|
||||
# TODO: fix dtype of b.max after Variable is just an UOp
|
||||
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) # type: ignore
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
return cls(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
|
||||
|
||||
Reference in New Issue
Block a user