update torch 2.8 (#12172)

support _reshape_alias. something is wrong with one case of unfold
This commit is contained in:
chenyu
2025-09-14 15:19:03 -04:00
committed by GitHub
parent 98ecab7563
commit 12a910f1d2
4 changed files with 21 additions and 14 deletions

View File

@@ -177,22 +177,28 @@ def cached_to_movement_ops(shape, st) -> list:
from tinygrad.shape.shapetracker import ShapeTracker, View
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
@wrap_view_op
def _as_strided(tensor:Tensor, size, stride, storage_offset=None):
# multiple as_strided do not compound
base = canonical_base(tensor)
# TODO: this is heavyweight
st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
ret = base
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo)
return ret
@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
storage_offset = storage_offset or tensor.storage_offset()
@wrap_view_op
def _as_strided(tensor:Tensor, size, stride, storage_offset=None):
# multiple as_strided do not compound
base = canonical_base(tensor)
# TODO: this is heavyweight
st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
ret = base
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo)
return ret
return _as_strided(tensor, size, stride, storage_offset)
@torch.library.impl("aten::_reshape_alias", "privateuseone")
def _reshape_alias(tensor:torch.Tensor, size, stride):
return _as_strided(tensor, size, stride)
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")