mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
update torch 2.8 (#12172)
support _reshape_alias. something is wrong with one case of unfold
This commit is contained in:
@@ -177,22 +177,28 @@ def cached_to_movement_ops(shape, st) -> list:
|
||||
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
|
||||
|
||||
@wrap_view_op
|
||||
def _as_strided(tensor:Tensor, size, stride, storage_offset=None):
|
||||
# multiple as_strided do not compound
|
||||
base = canonical_base(tensor)
|
||||
# TODO: this is heavyweight
|
||||
st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
|
||||
ret = base
|
||||
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
|
||||
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
|
||||
for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo)
|
||||
return ret
|
||||
|
||||
@torch.library.impl("aten::as_strided", "privateuseone")
|
||||
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
|
||||
storage_offset = storage_offset or tensor.storage_offset()
|
||||
@wrap_view_op
|
||||
def _as_strided(tensor:Tensor, size, stride, storage_offset=None):
|
||||
# multiple as_strided do not compound
|
||||
base = canonical_base(tensor)
|
||||
# TODO: this is heavyweight
|
||||
st = ShapeTracker(base.uop.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
|
||||
ret = base
|
||||
if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
|
||||
if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
|
||||
for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo)
|
||||
return ret
|
||||
return _as_strided(tensor, size, stride, storage_offset)
|
||||
|
||||
@torch.library.impl("aten::_reshape_alias", "privateuseone")
|
||||
def _reshape_alias(tensor:torch.Tensor, size, stride):
|
||||
return _as_strided(tensor, size, stride)
|
||||
|
||||
@torch.library.impl("aten::empty_strided", "privateuseone")
|
||||
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
|
||||
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
|
||||
|
||||
Reference in New Issue
Block a user