mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -56,7 +56,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
if s.op in {Ops.BUFFER, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL):
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL):
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
elif s in ctx.realize_map:
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
|
||||
@@ -182,7 +182,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
if x.op is Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)]
|
||||
if x.op is Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]
|
||||
if x.op is Ops.EXPAND:
|
||||
rngs = [a if resolve(x==y, False) else a.const_like(0) for a,x,y in zip(rngs, x.src[0].shape, x.shape)]
|
||||
rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]
|
||||
ending_ranges[x] = True
|
||||
if x.op is Ops.PAD:
|
||||
rngs = rngs[:]
|
||||
|
||||
Reference in New Issue
Block a user