mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00

as_strided view vs copy
@@ -82,6 +82,7 @@ view_ops = {
 for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
 
+# TODO do we want Ops.MULTI here?
 MOVEMENT_OPS = {
   Ops.RESHAPE: lambda t, u: t.reshape(u.shape),
   Ops.EXPAND: lambda t, u: t.expand(u.shape),
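
In MOVEMENT_OPS, each tinygrad movement Op is paired with the Tensor method that re-applies it to a torch-side tensor. As a rough sketch of how such a table could be used, a chain of recorded movement UOps might be replayed like this (replay_movements and the uops argument are hypothetical illustrations, not part of this diff):

# hypothetical sketch: re-apply a chain of recorded movement UOps to tensor t,
# assuming each UOp u carries its target shape (RESHAPE/EXPAND) on u.shape
def replay_movements(t, uops):
  for u in uops:
    t = MOVEMENT_OPS[u.op](t, u)  # e.g. Ops.RESHAPE -> t.reshape(u.shape)
  return t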
@@ -195,17 +196,30 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
 
+def as_strided_view(base:Tensor, size, stride, storage_offset):
+  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
+  if all(st != 0 for st in stride): return base.shrink(((storage_offset, storage_offset + prod(size)),)).reshape(size) # no broadcast
+  return base.shrink(((storage_offset, storage_offset + prod(non_broadcast_size)),)) \
+    .reshape(tuple(s if st != 0 else 1 for s, st in zip(size, stride))).expand(size)
+
+def as_strided_gather(base:Tensor, size, stride, storage_offset):
+  indices = Tensor.full(size, storage_offset, dtype=dtypes.int32, device=base.device)
+  for dim, (sz, st) in enumerate(zip(size, stride)):
+    if st != 0: indices += (Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st).reshape((1,) * dim + (sz,) + (1,) * (len(size) - dim - 1))
+  return base[indices.flatten()].reshape(size)
+
+def is_contiguous_strides(base:Tensor, size, stride, storage_offset):
+  # check if stride pattern matches row-major layout (can use pure movement ops, stay view)
+  non_broadcast_stride = tuple(st for st in stride if st != 0)
+  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
+  return non_broadcast_stride == strides_for_shape(non_broadcast_size) and storage_offset + prod(size) <= base.shape[0]
+
 @wrap_view_op
 def _as_strided(tensor:Tensor, size, stride, storage_offset=0):
+  # use movement ops for simple cases (view), gather for complex strides (copy)
   base = getattr(tensor, "_as_strided_base", canonical_base(tensor)).flatten()
   if prod(size) == 1: return base[storage_offset].reshape(size)
-  indices = Tensor.zeros(size, dtype=dtypes.int32, device=base.device) + storage_offset
-  for dim, (sz, st) in enumerate(zip(size, stride)):
-    if st != 0:
-      dim_range = Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st
-      shape_for_broadcast = [1] * dim + [sz] + [1] * (len(size) - dim - 1)
-      indices = indices + dim_range.reshape(shape_for_broadcast)
-  result = base[indices.flatten()].reshape(size)
+  if is_contiguous_strides(base, size, stride, storage_offset): result = as_strided_view(base, size, stride, storage_offset)
+  else: result = as_strided_gather(base, size, stride, storage_offset)
   result._as_strided_base = base
   return result
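
Note on the view path: when is_contiguous_strides holds, as_strided_view needs only movement ops. shrink cuts the contiguous window starting at storage_offset, reshape restores the requested shape, and expand re-creates any stride-0 (broadcast) dimensions. A plain-torch sketch of the same decomposition, checked against torch.as_strided on CPU (example values are illustrative only):

import torch

base = torch.arange(12.)
# size (2, 3), stride (3, 1), offset 2: row-major window -> pure view
v = torch.as_strided(base, (2, 3), (3, 1), 2)
assert torch.equal(v, base[2:2+6].reshape(2, 3))
# size (2, 3), stride (1, 0), offset 4: dim 1 is broadcast (stride 0)
b = torch.as_strided(base, (2, 3), (1, 0), 4)
assert torch.equal(b, base[4:4+2].reshape(2, 1).expand(2, 3))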
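
Note on the gather path: as_strided_gather builds an int32 index tensor where index[i0,...,ik] = storage_offset + sum(i_d * stride_d), then fancy-indexes the flat base, which materializes a copy. Overlapping strides are the canonical case no movement-op chain can express. A plain-torch sketch of the same index arithmetic (values illustrative):

import torch

base = torch.arange(6.)
size, stride, offset = (4, 3), (1, 1), 0  # overlapping sliding windows
idx = torch.full(size, offset, dtype=torch.int32)
for dim, (sz, st) in enumerate(zip(size, stride)):
  if st != 0:
    shape = (1,) * dim + (sz,) + (1,) * (len(size) - dim - 1)
    idx = idx + (torch.arange(sz, dtype=torch.int32) * st).reshape(shape)
out = base[idx.flatten().long()].reshape(size)
assert torch.equal(out, torch.as_strided(base, size, stride, offset))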
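
The dispatch test itself is small: drop the stride-0 dims, then the remaining strides must equal the row-major strides of the remaining shape, and the window must fit inside the flat base. A sketch of the row-major rule (tinygrad's real strides_for_shape also zeroes the strides of size-1 dims):

# row-major strides: last dim has stride 1, each earlier dim the product of later sizes
def row_major_strides(shape):
  strides, acc = [], 1
  for s in reversed(shape):
    strides.append(acc)
    acc *= s
  return tuple(reversed(strides))

assert row_major_strides((2, 3)) == (3, 1)  # size (2,3)/stride (3,1) stays a view
assert row_major_strides((3, 2)) != (1, 3)  # a transpose falls to the gather path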
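
End-to-end, _as_strided always indexes off a cached flat base (_as_strided_base), so chained as_strided calls re-derive from the same storage instead of stacking gathers. A hedged usage sketch, assuming the backend is loaded via tinygrad.frontend.torch and exposed as the "tiny" device (as in extra/torch_backend):

import torch
import tinygrad.frontend.torch  # assumed entry point that registers the "tiny" device

a = torch.arange(12., device="tiny")
v = torch.as_strided(a, (2, 3), (3, 1), 2)  # row-major window -> view path
w = torch.as_strided(a, (4, 3), (1, 1), 0)  # overlapping strides -> gather (copy) path
print(v.cpu(), w.cpu())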