diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 0c19e289f5..87ead365f0 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -198,6 +198,25 @@ class TestLinearizer(unittest.TestCase):
     lin = Linearizer(*sched[0].ast)
     assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
 
+  def test_assign_fold(self):
+    a = Tensor.ones(4, 4).contiguous().realize()
+    m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
+    a.assign(a+m)
+    a.realize()
+    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
+
+  def test_where_fold(self):
+    a = Tensor.ones(4, 4).contiguous().realize()
+    b = a.shrink(((1, 2), None)).pad(((1, 2), None))
+    a.assign(b.where(2, a))
+    sched = create_schedule([a.lazydata])
+    assert len(sched) == 1
+    lin = Linearizer(*sched[-1].ast)
+    lin.hand_coded_optimizations()
+    lin.linearize()
+    assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found a WHERE that should have been folded"
+    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
+
   def test_simplify_uop(self):
     def helper_test_simplify(uop, dtype, vin, arg=None):
       ast = LazyOp(BufferOps.CONST, (),
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index b1d6395a7c..9eefbc3e79 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -14,7 +14,7 @@ class UOps(Enum):
   LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
   DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
-  ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto() # noqa: E702
+  ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto(); NOOP = auto() # noqa: E702
 
 @dataclass(eq=False)
 class UOp:
@@ -116,6 +116,13 @@ constant_folder = PatternMatcher([
   # ** zero folding **
   ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
   ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
+  # ** load/store folding **
+  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
+    {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
+  # TODO: can do the invert of this (flip alt/load) when we fix double ops
+  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
+    "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
+   lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
 ])
 
 class UOpGraph:
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index 8f169e09b6..95a7644f1f 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -101,7 +101,11 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
   unbound_st, st_var_vals = st.simplify().unbind()
   var_vals.update(st_var_vals)
   if assign_to is not None and buf is assign_to:
-    if not unbound_st.contiguous: raise RuntimeError(f"must be contiguous for assign {unbound_st}")
+    if not unbound_st.contiguous:
+      # we also allow masked views: a single masked view is fine if, shrunk to its mask, it equals a contiguous ShapeTracker shrunk to the same mask
+      if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
+              ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+        raise RuntimeError(f"must be contiguous for assign {unbound_st}")
     return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
   if buf not in inputs: inputs.append(buf)
   return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 53b3d2d935..aa868cf5e2 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -44,19 +44,21 @@ class PythonProgram:
         dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
         if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
         if uop is UOps.STORE:
-          assert len(inp) <= 3, "gated stores not supported yet"
+          if len(inp) == 3: inp.append([True] * len(inp[0]))  # default the gate to all-True for ungated stores
           if isinstance(dtp[0], ImageDType):
             # image store
             assert dtp[2].count == 4
             for j,val in enumerate(inp[2]):
-              for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
+              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
                 assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
-                _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
+                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
           elif dtp[2].count > 1:
             for j,val in enumerate(inp[2]):
-              for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
+              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
+                if g: _store(m, o+j, v)
           else:
-            for m,o,v in zip(*inp): _store(m, o, v)
+            for m,o,v,g in zip(*inp):
+              if g: _store(m, o, v)
           i += 1
           continue
         elif uop is UOps.ENDLOOP:
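Why the WHERE fold in constant_folder is sound: store(buf, idx, where(gate, alt, load(buf, idx))) writes the old value back wherever the gate is false, so it stores the same bytes as a store of alt gated on gate, which is what the PythonProgram interpreter now executes. A minimal sketch of the equivalence in plain Python (store_where and store_gated are hypothetical helpers for illustration, not tinygrad APIs):

from typing import List

def store_where(buf: List[float], idx: List[int], gate: List[bool], alt: List[float]) -> None:
  # before the fold: unconditionally store where(gate, alt, load(buf, idx))
  vals = [a if g else buf[i] for g, a, i in zip(gate, alt, idx)]
  for i, v in zip(idx, vals): buf[i] = v

def store_gated(buf: List[float], idx: List[int], gate: List[bool], alt: List[float]) -> None:
  # after the fold: store alt only where the gate is true, skip the read entirely
  for g, a, i in zip(gate, alt, idx):
    if g: buf[i] = a

a, b = [1.0]*4, [1.0]*4
gate, alt = [False, True, True, False], [2.0]*4
store_where(a, [0, 1, 2, 3], gate, alt)
store_gated(b, [0, 1, 2, 3], gate, alt)
assert a == b == [1.0, 2.0, 2.0, 1.0]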
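And a sketch of the kind of ShapeTracker the relaxed assign check in realize.py now accepts, mirroring the shrink-then-pad pattern in the new tests (assuming ShapeTracker's movement ops from_shape, shrink, and pad behave as this revision uses them elsewhere; not a definitive reproduction of the check):

from tinygrad.shape.shapetracker import ShapeTracker

# shrink to row 1 of a (4,4) buffer, then pad back to (4,4): one view with mask ((1,2),(0,4))
st = ShapeTracker.from_shape((4, 4)).shrink(((1, 2), (0, 4))).pad(((1, 2), (0, 0)))
mask = st.views[0].mask
assert len(st.views) == 1 and mask is not None
# inside the mask, the view must index like a contiguous buffer of the same shape
assert ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask)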