Patch notes (free text ahead of the first "diff --git" header):
  * test/null/test_uop_symbolic.py: adds test_div_mod_recombine_3level, which passes
    gidx//3%224*3 + gidx%3 + gidx//672*672 with 0, 150527 and "gidx" to helper_test_variable.
  * tinygrad/uop/ops.py: adds the UOp.multibase property; the "if self is not self.base:" branch
    now calls graph_rewrite with pm_mops+symbolic (the trailing .simplify() is dropped) and its
    RuntimeError message formats out.pyrender() in place of out.
  * tinygrad/uop/symbolic.py: adds the rule (x//a%b)*a + x%a -> x%(a*b) to symbolic_simple.
  NOTE(review): multibase matches the visible tail of base line for line and recurses through
    .base; confirm whether .multibase was intended.
  NOTE(review): the removed/added comment lines in the axis hunk read identically here; if they
    differed only in whitespace, that difference is not recoverable from this copy.
  NOTE(review): leading indentation was reconstructed from Python nesting (two-space unit assumed);
    spacing inside each line is kept as found. Confirm against the target files before applying.

diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py
index 3ee979a784..0c8848f2a7 100644
--- a/test/null/test_uop_symbolic.py
+++ b/test/null/test_uop_symbolic.py
@@ -655,6 +655,10 @@ class TestSymbolic(unittest.TestCase):
     with self.assertRaises(AssertionError):
       self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
 
+  def test_div_mod_recombine_3level(self):
+    gidx = Variable("gidx", 0, 150527)
+    self.helper_test_variable(gidx//3%224*3 + gidx%3 + gidx//672*672, 0, 150527, "gidx")
+
   def test_div_mod_recombine_with_gcd(self):
     b = Variable("b", 0, 100)
     exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index b5e5809c96..b8047e8fe8 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -506,7 +506,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
 
   @functools.cached_property
   def axis(self) -> int|None:
-    # COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
+    # COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
     if self.op is Ops.COPY: return None
     if self.op is Ops.MULTI: return self.arg
     # NOTE: they all have to share an axis, we always choose [-1]
@@ -557,6 +557,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
     return self
 
+  @property
+  def multibase(self) -> UOp:
+    if self.op in GroupOp.Movement: return self.src[0].base
+    if self.op is Ops.DETACH: return self.src[0].base # DETACH can't change base
+    return self
+
   # like gep, but might return an integer
   def sgep(self, i:int) -> sint:
     match self.op:
@@ -649,8 +655,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op in {Ops.CONTIGUOUS, Ops.RESHAPE}: return self.src[0].buffer
     # this buffer can process disk tensors and simple movement ops
     if self is not self.base:
-      from tinygrad.schedule.rangeify import pm_mops
-      out = graph_rewrite(self.flatten().index(UOp.range(self.size, 0)), pm_mops).simplify()
+      from tinygrad.schedule.rangeify import pm_mops, symbolic
+      out = graph_rewrite(self.flatten().index(UOp.range(self.size, 0)), pm_mops+symbolic)
       buf = out.src[0].buffer
       assert isinstance(buf, Buffer), "must be a Buffer for movement ops"
       assert out.op is Ops.INDEX, "couldn't collapse to a single INDEX"
@@ -660,7 +666,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
         return buf.view(self.size, out.dtype, 0)
       if out.src[1].op is Ops.ADD and out.src[1].src[0].op is Ops.RANGE and out.src[1].src[1].op is Ops.CONST:
         return buf.view(self.size, out.dtype, out.src[1].src[1].arg*out.dtype.itemsize)
-      raise RuntimeError(f"cannot collapse INDEX {out} to a single size/offset")
+      raise RuntimeError(f"cannot collapse INDEX {out.pyrender()} to a single size/offset")
     if self.op is Ops.BITCAST:
       buf = self.src[0].buffer
       assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py
index df4ae3699c..476c3ccf39 100644
--- a/tinygrad/uop/symbolic.py
+++ b/tinygrad/uop/symbolic.py
@@ -52,6 +52,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
     lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above
   ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
     lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
+  ((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("b")*UPat.cvar("a")+UPat.var("x")%UPat.cvar("a"),
+    lambda x,a,b: x%(a*b)), # (x//a%b)*a + x%a = x%(a*b)
   ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
   ((UPat.var("y")+UPat.var("x")%UPat.cvar("c"))+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda y,x,c: y+x),
   ((UPat.var("y")+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"))+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),