diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index d5993f64b0..615b2a8074 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -155,16 +155,14 @@ def index_tensor(x, y): def zero_(x): if TORCH_DEBUG: print(f"zero_ {x.shape}") tt = unwrap(x) - # NOTE: unconditional contiguous covers if x is contiguous (match it) or if x is view (realize for inplace) - # TODO: consolidate - tt.assign(tt.zeros_like().contiguous()) + tt.replace(tt.zeros_like()) @torch.library.impl("aten::fill_.Scalar", "privateuseone") @inplace_fn("x") def fill_scalar(x, y): if TORCH_DEBUG: print(f"fill_.Scalar {x.shape} {y}") tt = unwrap(x) - tt.assign(tt.full_like(y).contiguous()) + tt.replace(tt.full_like(y)) @torch.library.impl("aten::_local_scalar_dense", "privateuseone") def _local_scalar_dense(tensor): return unwrap(tensor).item()