From 82a61223f2067ca8fa9de745c1938a4244e5a710 Mon Sep 17 00:00:00 2001
From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>
Date: Fri, 14 Nov 2025 13:21:07 +0100
Subject: [PATCH] as_strided view vs copy

---
 extra/torch_backend/backend.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 67b8adcdee..d62ebcaff9 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -82,6 +82,7 @@ view_ops = {
 for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
 
+# TODO: do we want Ops.MULTI here?
 MOVEMENT_OPS = {
   Ops.RESHAPE: lambda t, u: t.reshape(u.shape),
   Ops.EXPAND: lambda t, u: t.expand(u.shape),
@@ -195,17 +196,30 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
 
+def as_strided_view(base:Tensor, size, stride, storage_offset):
+  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
+  if all(st != 0 for st in stride): return base.shrink(((storage_offset, storage_offset + prod(size)),)).reshape(size)  # no broadcast
+  return base.shrink(((storage_offset, storage_offset + prod(non_broadcast_size)),)) \
+    .reshape(tuple(s if st != 0 else 1 for s, st in zip(size, stride))).expand(size)
+
+def as_strided_gather(base:Tensor, size, stride, storage_offset):
+  indices = Tensor.full(size, storage_offset, dtype=dtypes.int32, device=base.device)
+  for dim, (sz, st) in enumerate(zip(size, stride)):
+    if st != 0: indices += (Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st).reshape((1,) * dim + (sz,) + (1,) * (len(size) - dim - 1))
+  return base[indices.flatten()].reshape(size)
+
+def is_contiguous_strides(base:Tensor, size, stride, storage_offset):
+  # check if the stride pattern matches a row-major layout (pure movement ops suffice, the result stays a view)
+  non_broadcast_stride = tuple(st for st in stride if st != 0)
+  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
+  return non_broadcast_stride == strides_for_shape(non_broadcast_size) and storage_offset + prod(size) <= base.shape[0]
+
 @wrap_view_op
 def _as_strided(tensor:Tensor, size, stride, storage_offset=0):
+  # use movement ops for simple cases (view), gather for complex strides (copy)
   base = getattr(tensor, "_as_strided_base", canonical_base(tensor)).flatten()
-  if prod(size) == 1: return base[storage_offset].reshape(size)
-  indices = Tensor.zeros(size, dtype=dtypes.int32, device=base.device) + storage_offset
-  for dim, (sz, st) in enumerate(zip(size, stride)):
-    if st != 0:
-      dim_range = Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st
-      shape_for_broadcast = [1] * dim + [sz] + [1] * (len(size) - dim - 1)
-      indices = indices + dim_range.reshape(shape_for_broadcast)
-  result = base[indices.flatten()].reshape(size)
+  if is_contiguous_strides(base, size, stride, storage_offset): result = as_strided_view(base, size, stride, storage_offset)
+  else: result = as_strided_gather(base, size, stride, storage_offset)
   result._as_strided_base = base
   return result
 
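
Below is a minimal, self-contained sketch of the view-vs-copy dispatch, assuming tinygrad's Tensor API (shrink/reshape/expand and integer-tensor indexing). The three helpers are re-declared from the patch so the sketch runs outside backend.py, and the local strides_for_shape is a simplified row-major stand-in for the helper backend.py imports (the exact import path is not shown in this diff):

  # sketch: view-vs-copy dispatch for as_strided on a flat 1-D base tensor
  from tinygrad import Tensor, dtypes
  from tinygrad.helpers import prod

  def strides_for_shape(shape):
    # simplified stand-in: row-major strides, each dim's stride is the product of all later dims
    out, acc = [], 1
    for s in reversed(shape): out.append(acc); acc *= s
    return tuple(reversed(out))

  def is_contiguous_strides(base, size, stride, storage_offset):
    # the non-zero strides must be exactly row-major and the window must stay in bounds
    non_broadcast_stride = tuple(st for st in stride if st != 0)
    non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
    return non_broadcast_stride == strides_for_shape(non_broadcast_size) and storage_offset + prod(size) <= base.shape[0]

  def as_strided_view(base, size, stride, storage_offset):
    # pure movement ops: shrink out the window, reshape, expand stride-0 dims
    non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
    if all(st != 0 for st in stride): return base.shrink(((storage_offset, storage_offset + prod(size)),)).reshape(size)
    return base.shrink(((storage_offset, storage_offset + prod(non_broadcast_size)),)) \
      .reshape(tuple(s if st != 0 else 1 for s, st in zip(size, stride))).expand(size)

  def as_strided_gather(base, size, stride, storage_offset):
    # materialize flat indices per dim and gather (this is a copy, not a view)
    indices = Tensor.full(size, storage_offset, dtype=dtypes.int32, device=base.device)
    for dim, (sz, st) in enumerate(zip(size, stride)):
      if st != 0: indices += (Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st).reshape((1,) * dim + (sz,) + (1,) * (len(size) - dim - 1))
    return base[indices.flatten()].reshape(size)

  def as_strided(base, size, stride, storage_offset=0):
    # the dispatch from the patch: stay a view when strides are row-major, else gather
    if is_contiguous_strides(base, size, stride, storage_offset): return as_strided_view(base, size, stride, storage_offset)
    return as_strided_gather(base, size, stride, storage_offset)

  base = Tensor.arange(12, dtype=dtypes.int32)
  print(as_strided(base, (3, 4), (4, 1)).tolist())  # row-major -> view path: [[0..3], [4..7], [8..11]]
  print(as_strided(base, (3, 4), (0, 1)).tolist())  # stride-0 dim -> broadcast view: row [0,1,2,3] three times
  print(as_strided(base, (3, 4), (2, 1)).tolist())  # overlapping windows -> gather path (copy)

With this split, contiguous reshapes and stride-0 broadcasts stay pure movement ops on the lazy graph, while layouts that movement ops cannot express (such as the overlapping windows above) pay the index-materialization cost of a gather.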