don't bufferize 0s

This commit is contained in:
George Hotz
2025-08-20 21:02:09 -07:00
parent a044648111
commit 392a21b82b

View File

@@ -121,7 +121,7 @@ def map_reshape(idx:UOp, r:UOp):
mish //= s
else:
ret.append(UOp.const(dtypes.int, 0))
tret = ret[0].sink(*ret[1:]).simplify(tracked=True).src[::-1] if len(ret) else ()
tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
def map_pad(idx:UOp, r:UOp):
@@ -184,7 +184,7 @@ def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
passthrough_idx.append(idx.src[1+i])
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
new_ranges.append(ranges[-1])
ret = x.src[0].index(*ranges).bufferize(*new_ranges, arg=x.device)
ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=x.device)
return ret.index(*passthrough_idx)
def map_contiguous(ctx:RangeifyContext, x:UOp):
@@ -192,7 +192,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp):
ranges = []
for s in x.shape:
ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0))
ret = x.src[0].index(*ranges).bufferize(*ranges, arg=x.device)
ret = x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device)
return ret.forced_reshape(x.shape)
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):