tinygrad (https://github.com/tinygrad/tinygrad.git)
as_strided view vs copy
@@ -82,6 +82,7 @@ view_ops = {
 for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
 
+# TODO do we want Ops.MULTI here?
 MOVEMENT_OPS = {
   Ops.RESHAPE: lambda t, u: t.reshape(u.shape),
   Ops.EXPAND: lambda t, u: t.expand(u.shape),
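The MOVEMENT_OPS table can map straight onto Tensor methods because reshape and expand are lazy movement ops on tinygrad Tensors. A minimal sketch of the two calls the table points at, assuming a local tinygrad install (this snippet is illustration, not part of the diff):

from tinygrad import Tensor

t = Tensor.arange(6)
print(t.reshape(2, 3).shape)               # (2, 3): what Ops.RESHAPE maps to
print(t.reshape(6, 1).expand(6, 4).shape)  # (6, 4): what Ops.EXPAND maps to (stride-0 broadcast, no copy)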
@@ -195,17 +196,30 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
 
+def as_strided_view(base:Tensor, size, stride, storage_offset):
+  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
+  if all(st != 0 for st in stride): return base.shrink(((storage_offset, storage_offset + prod(size)),)).reshape(size) # no broadcast
+  return base.shrink(((storage_offset, storage_offset + prod(non_broadcast_size)),)) \
+    .reshape(tuple(s if st != 0 else 1 for s, st in zip(size, stride))).expand(size)
+
+def as_strided_gather(base:Tensor, size, stride, storage_offset):
+  indices = Tensor.full(size, storage_offset, dtype=dtypes.int32, device=base.device)
+  for dim, (sz, st) in enumerate(zip(size, stride)):
+    if st != 0: indices += (Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st).reshape((1,) * dim + (sz,) + (1,) * (len(size) - dim - 1))
+  return base[indices.flatten()].reshape(size)
+
+def is_contiguous_strides(base:Tensor, size, stride, storage_offset):
+  # check if stride pattern matches row-major layout (can use pure movement ops, stay view)
+  non_broadcast_stride = tuple(st for st in stride if st != 0)
+  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
+  return non_broadcast_stride == strides_for_shape(non_broadcast_size) and storage_offset + prod(size) <= base.shape[0]
+
 @wrap_view_op
 def _as_strided(tensor:Tensor, size, stride, storage_offset=0):
+  # use movement ops for simple cases (view), gather for complex strides (copy)
   base = getattr(tensor, "_as_strided_base", canonical_base(tensor)).flatten()
-  if prod(size) == 1: return base[storage_offset].reshape(size)
-  indices = Tensor.zeros(size, dtype=dtypes.int32, device=base.device) + storage_offset
-  for dim, (sz, st) in enumerate(zip(size, stride)):
-    if st != 0:
-      dim_range = Tensor.arange(sz, device=base.device, dtype=dtypes.int32) * st
-      shape_for_broadcast = [1] * dim + [sz] + [1] * (len(size) - dim - 1)
-      indices = indices + dim_range.reshape(shape_for_broadcast)
-  result = base[indices.flatten()].reshape(size)
+  if is_contiguous_strides(base, size, stride, storage_offset): result = as_strided_view(base, size, stride, storage_offset)
+  else: result = as_strided_gather(base, size, stride, storage_offset)
   result._as_strided_base = base
   return result
 
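A standalone sketch of the new view-vs-copy decision: is_contiguous_strides drops stride-0 (broadcast) dims, then asks whether the remaining dims are laid out row-major and whether the requested window fits inside the flat base buffer. Here strides_for_shape is a simplified local stand-in for tinygrad's helper (size-1 dims are not special-cased), and takes_view_path is a hypothetical name used only for illustration:

from math import prod

def strides_for_shape(shape):
  # row-major (C-contiguous) strides, e.g. (2, 3) -> (3, 1)
  strides, acc = [], 1
  for s in reversed(shape):
    strides.append(acc)
    acc *= s
  return tuple(reversed(strides))

def takes_view_path(size, stride, storage_offset, base_len):
  # mirrors is_contiguous_strides: ignore stride-0 (broadcast) dims, then the rest
  # must match row-major strides and stay within the flat base buffer
  non_broadcast_stride = tuple(st for st in stride if st != 0)
  non_broadcast_size = tuple(s for s, st in zip(size, stride) if st != 0)
  return non_broadcast_stride == strides_for_shape(non_broadcast_size) and storage_offset + prod(size) <= base_len

print(takes_view_path((2, 3), (3, 1), 0, 6))  # True: row-major slice -> shrink/reshape (view)
print(takes_view_path((2, 3), (0, 1), 0, 6))  # True: broadcast along dim 0 -> reshape + expand (view)
print(takes_view_path((3, 2), (1, 3), 0, 6))  # False: transposed strides -> gather (copy)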
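And a small check, written in plain PyTorch on CPU rather than against the tinygrad backend, that the gather formula behind as_strided_gather (flat index = storage_offset + sum of index * stride per dim) reproduces torch.as_strided for a stride pattern the view path rejects:

import itertools
import torch

base = torch.arange(12)                            # flat base buffer
size, stride, storage_offset = (3, 2), (1, 3), 2   # transposed layout: would take the gather/copy path

gathered = torch.empty(size, dtype=base.dtype)
for idx in itertools.product(*(range(s) for s in size)):
  flat = storage_offset + sum(i * st for i, st in zip(idx, stride))
  gathered[idx] = base[flat]

assert torch.equal(gathered, torch.as_strided(base, size, stride, storage_offset))
print(gathered)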