mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix test_matmul_relu_cat (#12478)
This commit is contained in:
@@ -2,7 +2,6 @@ import unittest
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp
|
||||
from test.helpers import expect_rangeify_fails
|
||||
|
||||
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
|
||||
class TestRangeifyAssign(unittest.TestCase):
|
||||
@@ -302,7 +301,6 @@ class TestOuterworld(unittest.TestCase):
|
||||
self.assertTrue((t==o).all().item())
|
||||
|
||||
class TestRangeifyEdgeCase(unittest.TestCase):
|
||||
@expect_rangeify_fails # TODO: fix
|
||||
def test_matmul_relu_cat(self):
|
||||
a = Tensor.ones(100, 512).contiguous().realize()
|
||||
c = Tensor.ones(1, 512).contiguous().realize()
|
||||
|
||||
@@ -1196,6 +1196,8 @@ renderer = PatternMatcher([
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")),
|
||||
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x:
|
||||
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"[{','.join([y.arg for y in x.src])}]" if not all_same(x.src) else f"{len(x.src)}x[{x.src[0].arg}]")),
|
||||
])
|
||||
renderer_infer = PatternMatcher([
|
||||
(UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),
|
||||
|
||||
@@ -455,6 +455,7 @@ def _valid_priority(v: UOp, valids:list[UOp]):
|
||||
except ValueError: return 0
|
||||
|
||||
def simplify_valid(valid:UOp) -> UOp|None:
|
||||
if valid.op_in_parents(Ops.LOAD): return None # this should only be for indexing, skip if there's a LOAD
|
||||
ret:list[UOp] = []
|
||||
something_changed = False
|
||||
valids = list(valid.split_uop(Ops.AND))
|
||||
|
||||
Reference in New Issue
Block a user