fix test_matmul_relu_cat (#12478)

This commit is contained in:
George Hotz
2025-10-07 14:32:23 +08:00
committed by GitHub
parent 514d2a0774
commit ea7672931f
3 changed files with 3 additions and 2 deletions

View File

@@ -2,7 +2,6 @@ import unittest
from tinygrad import Tensor, nn
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
from tinygrad.uop.ops import UOp
from test.helpers import expect_rangeify_fails
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
class TestRangeifyAssign(unittest.TestCase):
@@ -302,7 +301,6 @@ class TestOuterworld(unittest.TestCase):
self.assertTrue((t==o).all().item())
class TestRangeifyEdgeCase(unittest.TestCase):
@expect_rangeify_fails # TODO: fix
def test_matmul_relu_cat(self):
a = Tensor.ones(100, 512).contiguous().realize()
c = Tensor.ones(1, 512).contiguous().realize()

View File

@@ -1196,6 +1196,8 @@ renderer = PatternMatcher([
(UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")),
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x:
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
(UPat(Ops.VECTORIZE, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"[{','.join([y.arg for y in x.src])}]" if not all_same(x.src) else f"{len(x.src)}x[{x.src[0].arg}]")),
])
renderer_infer = PatternMatcher([
(UPat(Ops.MOD, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"cmod({x.src[0].arg}, {x.src[1].arg})")),

View File

@@ -455,6 +455,7 @@ def _valid_priority(v: UOp, valids:list[UOp]):
except ValueError: return 0
def simplify_valid(valid:UOp) -> UOp|None:
if valid.op_in_parents(Ops.LOAD): return None # this should only be for indexing, skip if there's a LOAD
ret:list[UOp] = []
something_changed = False
valids = list(valid.split_uop(Ops.AND))