mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 00:55:11 -05:00
where fold try 2 (#3748)
* where fold try 2
* assign fold
* test_where_fold works
* add gated store support to ops_python

---------

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
@@ -44,19 +44,21 @@ class PythonProgram:
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
||||
if uop is UOps.STORE:
|
||||
assert len(inp) <= 3, "gated stores not supported yet"
|
||||
if len(inp) == 3: inp.append([True] * len(inp[0])) # set the gate to True
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
# image store
|
||||
assert dtp[2].count == 4
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
|
||||
for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
|
||||
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
|
||||
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
|
||||
if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
|
||||
elif dtp[2].count > 1:
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
|
||||
for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
|
||||
if g: _store(m, o+j, v)
|
||||
else:
|
||||
for m,o,v in zip(*inp): _store(m, o, v)
|
||||
for m,o,v,g in zip(*inp):
|
||||
if g: _store(m, o, v)
|
||||
i += 1
|
||||
continue
|
||||
elif uop is UOps.ENDLOOP:
|
||||
|
||||
Reference in New Issue
Block a user