mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
use CONTRACT before REDUCE (#5903)
* use CONTRACT before REDUCE [run_process_replay] * support half expand * EXPAND GEP
This commit is contained in:
@@ -352,6 +352,22 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
|
||||
|
||||
def test_contract_no_expand(self):
|
||||
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 2
|
||||
assert sink.src[0] == sink.src[1]
|
||||
|
||||
def test_contract_half_expand(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
|
||||
assert sink.src[0] == sink.src[1]
|
||||
assert sink.src[0] != sink.src[2]
|
||||
assert sink.src[6] == sink.src[7]
|
||||
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
|
||||
|
||||
Reference in New Issue
Block a user