as_strided view vs copy

This commit is contained in:
Roelof van Dijk
2025-11-14 13:21:07 +01:00
parent e278692139
commit 82a61223f2

View File

@@ -82,6 +82,7 @@ view_ops = {
for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
# TODO do we want Ops.MULTI here?
MOVEMENT_OPS = {
Ops.RESHAPE: lambda t, u: t.reshape(u.shape),
Ops.EXPAND: lambda t, u: t.expand(u.shape),
@@ -195,17 +196,30 @@ def fill_scalar(x, y):
@torch.library.impl("aten::_local_scalar_dense", "privateuseone")
def _local_scalar_dense(tensor): return unwrap(tensor).item()
def as_strided_view(base:Tensor, size, stride, storage_offset):
non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
if all(st != 0 for st in stride): return base.shrink(((storage_offset, storage_offset + prod(size)),)).reshape(size) # no broadcast
return base.shrink(((storage_offset, storage_offset + prod(non_broadcast_size)),)) \
.reshape(tuple(s if st != 0 else 1 for s, st in zip(size, stride))).expand(size)
def as_strided_gather(base:Tensor, size, stride, storage_offset):
indices = Tensor.full(size, storage_offset, dtype=dtypes.int32, device=base.device)
for dim, (sz, st) in enumerate(zip(size, stride)):
if st != 0: indices += (Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st).reshape((1,) * dim + (sz,) + (1,) * (len(size) - dim - 1))
return base[indices.flatten()].reshape(size)
def is_contiguous_strides(base:Tensor, size, stride, storage_offset):
# check if stride pattern matches row-major layout (can use pure movement ops, stay view)
non_broadcast_stride = tuple(st for st in stride if st != 0)
non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
return non_broadcast_stride == strides_for_shape(non_broadcast_size) and storage_offset + prod(size) <= base.shape[0]
@wrap_view_op
def _as_strided(tensor:Tensor, size, stride, storage_offset=0):
# use movement ops for simple cases (view), gather for complex strides (copy)
base = getattr(tensor, "_as_strided_base", canonical_base(tensor)).flatten()
if prod(size) == 1: return base[storage_offset].reshape(size)
indices = Tensor.zeros(size, dtype=dtypes.int32, device=base.device) + storage_offset
for dim, (sz, st) in enumerate(zip(size, stride)):
if st != 0:
dim_range = Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st
shape_for_broadcast = [1] * dim + [sz] + [1] * (len(size) - dim - 1)
indices = indices + dim_range.reshape(shape_for_broadcast)
result = base[indices.flatten()].reshape(size)
if is_contiguous_strides(base, size, stride, storage_offset): result = as_strided_view(base, size, stride, storage_offset)
else: result = as_strided_gather(base, size, stride, storage_offset)
result._as_strided_base = base
return result