where fold try 2 (#3748)

* where fold try 2

* assign fold

* test_where_fold works

* add gated store support to ops_python

---------

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
George Hotz
2024-03-15 07:46:26 -07:00
committed by GitHub
parent 6b8c66e04f
commit ca19eb3e82
4 changed files with 39 additions and 7 deletions

View File

@@ -44,19 +44,21 @@ class PythonProgram:
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is UOps.STORE:
assert len(inp) <= 3, "gated stores not supported yet"
if len(inp) == 3: inp.append([True] * len(inp[0])) # set the gate to True
if isinstance(dtp[0], ImageDType):
# image store
assert dtp[2].count == 4
for j,val in enumerate(inp[2]):
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
elif dtp[2].count > 1:
for j,val in enumerate(inp[2]):
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
if g: _store(m, o+j, v)
else:
for m,o,v in zip(*inp): _store(m, o, v)
for m,o,v,g in zip(*inp):
if g: _store(m, o, v)
i += 1
continue
elif uop is UOps.ENDLOOP: