diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 91664d253e..739bdd128f 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -56,7 +56,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): new_srcs = [] for s in x.src: new_src = s - if s.op in {Ops.BUFFER, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): + if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag) @@ -182,7 +182,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: if x.op is Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)] if x.op is Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)] if x.op is Ops.EXPAND: - rngs = [a if resolve(x==y, False) else a.const_like(0) for a,x,y in zip(rngs, x.src[0].shape, x.shape)] + rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)] ending_ranges[x] = True if x.op is Ops.PAD: rngs = rngs[:]