where fold try 2 (#3748)

* where fold try 2
* assign fold
* test_where_fold works
* add gated store support to ops_python

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
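At its core, the fold rewrites a store of a WHERE whose false branch reloads the destination into a conditional (gated) store. A minimal sketch of the two shapes of code in plain Python; the names buf, idx, gate, alt mirror the pattern variables in the diff below, and this is an illustration, not tinygrad's generated code:

def store_where(buf, idx, gate, alt):
    # before the fold: load the old value, select with WHERE, store unconditionally
    buf[idx] = alt if gate else buf[idx]

def store_gated(buf, idx, gate, alt):
    # after the fold: the LOAD and the WHERE disappear; the store itself is gated
    if gate: buf[idx] = alt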
@@ -198,6 +198,25 @@ class TestLinearizer(unittest.TestCase):
     lin = Linearizer(*sched[0].ast)
     assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
 
+  def test_assign_fold(self):
+    a = Tensor.ones(4, 4).contiguous().realize()
+    m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
+    a.assign(a+m)
+    a.realize()
+    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
+
+  def test_where_fold(self):
+    a = Tensor.ones(4, 4).contiguous().realize()
+    b = a.shrink(((1, 2), None)).pad(((1, 2), None))
+    a.assign(b.where(2, a))
+    sched = create_schedule([a.lazydata])
+    assert len(sched) == 1
+    lin = Linearizer(*sched[-1].ast)
+    lin.hand_coded_optimizations()
+    lin.linearize()
+    assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
+    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
+
   def test_simplify_uop(self):
     def helper_test_simplify(uop, dtype, vin, arg=None):
       ast = LazyOp(BufferOps.CONST, (),
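As a sanity check on the expected values in both tests (plain numpy, independent of tinygrad): shrink(((1, 2), None)) keeps only row 1 of the 4x4 tensor and pad(((1, 2), None)) re-embeds it at the same position, so the masked view is valid exactly on row 1:

import numpy as np
a = np.ones((4, 4), dtype=np.float32)
mask = np.zeros((4, 4), dtype=bool)
mask[1, :] = True   # the shrink+pad mask covers only row 1
a[mask] = 2.0       # the assign writes 2 where the mask is valid
np.testing.assert_equal(a.flatten(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])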
@@ -14,7 +14,7 @@ class UOps(Enum):
   LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
   DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
   LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
-  ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto() # noqa: E702
+  ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto(); NOOP = auto() # noqa: E702
 
 @dataclass(eq=False)
 class UOp:
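The new NOOP uop gives the pattern matcher a harmless replacement target: the store-of-load fold below rewrites the entire STORE to UOp(UOps.NOOP) rather than having to delete a node outright.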
@@ -116,6 +116,13 @@ constant_folder = PatternMatcher([
   # ** zero folding **
   ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
   ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
+  # ** load/store folding **
+  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
+                               {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
+  # TODO: can do the invert of this (flip alt/load) when we fix double ops
+  ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
+    "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
+   lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
 ])
 
 class UOpGraph:
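For readers unfamiliar with this pattern style: each key in a pattern dict must match the corresponding UOp field, "vin" recurses into the sources, and "__name__" captures a node under that name; two captures with the same name must bind the same node, which is how the buf/idx pair in the STORE and the inner LOAD above is forced to match. A rough standalone sketch of the matching rule, assuming a simplified UOp and ignoring details such as list-vs-tuple vin patterns; this is not tinygrad's actual implementation:

from dataclasses import dataclass

@dataclass(eq=False)
class UOp:
    uop: str
    arg: object = None
    vin: tuple = ()

def match(pat, uop, store):
    for k, v in pat.items():
        if k == "__name__":
            # capture; a name already bound must refer to the same node
            if store.setdefault(v, uop) is not uop: return False
        elif k == "vin":
            if len(v) != len(uop.vin): return False
            if not all(match(p, u, store) for p, u in zip(v, uop.vin)): return False
        elif getattr(uop, k) != v:
            return False
    return True

# store(buf, idx, load(buf, idx)) matches the no-op pattern above
buf, idx = UOp("DEFINE_GLOBAL"), UOp("CONST", 0)
st = UOp("STORE", vin=(buf, idx, UOp("LOAD", vin=(buf, idx))))
pat = {"uop": "STORE", "vin": ({"__name__": "buf"}, {"__name__": "idx"},
       {"uop": "LOAD", "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}
caps = {}
assert match(pat, st, caps) and caps["buf"] is buf and caps["idx"] is idx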
@@ -101,7 +101,11 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Var
   unbound_st, st_var_vals = st.simplify().unbind()
   var_vals.update(st_var_vals)
   if assign_to is not None and buf is assign_to:
-    if not unbound_st.contiguous: raise RuntimeError(f"must be contiguous for assign {unbound_st}")
+    if not unbound_st.contiguous:
+      # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+      if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
+              ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+        raise RuntimeError(f"must be contiguous for assign {unbound_st}")
     return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
   if buf not in inputs: inputs.append(buf)
   return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
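Pulled out as a standalone predicate, the relaxed assign check reads as follows; this is a paraphrase of the condition above (assuming the import path tinygrad.shape.shapetracker for ShapeTracker), not new behavior:

from tinygrad.shape.shapetracker import ShapeTracker

def assign_target_ok(st: ShapeTracker) -> bool:
    if st.contiguous: return True                                # plain contiguous assign
    if len(st.views) != 1 or st.views[0].mask is None: return False
    mask = st.views[0].mask
    # the masked region must look exactly like the same region of a
    # fresh contiguous buffer of this shape
    return ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask)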
@@ -44,19 +44,21 @@ class PythonProgram:
       dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
       if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
       if uop is UOps.STORE:
-        assert len(inp) <= 3, "gated stores not supported yet"
+        if len(inp) == 3: inp.append([True] * len(inp[0])) # set the gate to True
         if isinstance(dtp[0], ImageDType):
           # image store
           assert dtp[2].count == 4
           for j,val in enumerate(inp[2]):
-            for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
+            for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
               assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
-              _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
+              if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
         elif dtp[2].count > 1:
           for j,val in enumerate(inp[2]):
-            for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
+            for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
+              if g: _store(m, o+j, v)
         else:
-          for m,o,v in zip(*inp): _store(m, o, v)
+          for m,o,v,g in zip(*inp):
+            if g: _store(m, o, v)
         i += 1
         continue
       elif uop is UOps.ENDLOOP:
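Distilled, the interpreter change is: stores now carry an optional fourth per-lane gate list, a missing gate defaults to all-True, and a lane only writes when its gate is set. A self-contained sketch of that behavior; the _store stub here is illustrative, while ops_python's real _store handles memoryviews and dtypes:

def _store(buf, off, val): buf[off] = val  # stand-in for ops_python's _store

def do_store(inp):
    # inp = [bufs, offsets, values] or [bufs, offsets, values, gates], one entry per lane
    if len(inp) == 3: inp.append([True] * len(inp[0]))  # ungated store: every lane writes
    for m, o, v, g in zip(*inp):
        if g: _store(m, o, v)

buf = [0, 0, 0]
do_store([[buf, buf, buf], [0, 1, 2], [10, 20, 30], [True, False, True]])
assert buf == [10, 0, 30]  # lane 1 was gated off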