where fold try 2 (#3748)

* where fold try 2

* assign fold

* test_where_fold works

* add gated store support to ops_python

---------

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
George Hotz
2024-03-15 07:46:26 -07:00
committed by GitHub
parent 6b8c66e04f
commit ca19eb3e82
4 changed files with 39 additions and 7 deletions

View File

@@ -198,6 +198,25 @@ class TestLinearizer(unittest.TestCase):
lin = Linearizer(*sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(a+m)
a.realize()
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
def test_where_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(b.where(2, a))
sched = create_schedule([a.lazydata])
assert len(sched) == 1
lin = Linearizer(*sched[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
def test_simplify_uop(self):
def helper_test_simplify(uop, dtype, vin, arg=None):
ast = LazyOp(BufferOps.CONST, (),

View File

@@ -14,7 +14,7 @@ class UOps(Enum):
LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto(); NOOP = auto() # noqa: E702
@dataclass(eq=False)
class UOp:
@@ -116,6 +116,13 @@ constant_folder = PatternMatcher([
# ** zero folding **
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
# ** load/store folding **
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
{"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
# TODO: can do the invert of this (flip alt/load) when we fix double ops
({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
"vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
])
class UOpGraph:

View File

@@ -101,7 +101,11 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if assign_to is not None and buf is assign_to:
if not unbound_st.contiguous: raise RuntimeError(f"must be contiguous for assign {unbound_st}")
if not unbound_st.contiguous:
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
raise RuntimeError(f"must be contiguous for assign {unbound_st}")
return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
if buf not in inputs: inputs.append(buf)
return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))

View File

@@ -44,19 +44,21 @@ class PythonProgram:
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is UOps.STORE:
assert len(inp) <= 3, "gated stores not supported yet"
if len(inp) == 3: inp.append([True] * len(inp[0])) # set the gate to True
if isinstance(dtp[0], ImageDType):
# image store
assert dtp[2].count == 4
for j,val in enumerate(inp[2]):
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
elif dtp[2].count > 1:
for j,val in enumerate(inp[2]):
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
if g: _store(m, o+j, v)
else:
for m,o,v in zip(*inp): _store(m, o, v)
for m,o,v,g in zip(*inp):
if g: _store(m, o, v)
i += 1
continue
elif uop is UOps.ENDLOOP: