remove assign contiguous hack (#12659)

* remove assign contiguous hack

* remove bad contiguous usage in torch backend

* assign
This commit is contained in:
George Hotz
2025-10-14 16:42:14 +08:00
committed by GitHub
parent 30ee7c4c26
commit fb61f3519f
4 changed files with 4 additions and 7 deletions

View File

@@ -155,16 +155,14 @@ def index_tensor(x, y):
def zero_(x):
if TORCH_DEBUG: print(f"zero_ {x.shape}")
tt = unwrap(x)
# NOTE: unconditional contiguous covers if x is contiguous (match it) or if x is view (realize for inplace)
# TODO: consolidate
tt.assign(tt.zeros_like().contiguous())
tt.assign(tt.zeros_like())
@torch.library.impl("aten::fill_.Scalar", "privateuseone")
@inplace_fn("x")
def fill_scalar(x, y):
if TORCH_DEBUG: print(f"fill_.Scalar {x.shape} {y}")
tt = unwrap(x)
tt.assign(tt.full_like(y).contiguous())
tt.assign(tt.full_like(y))
@torch.library.impl("aten::_local_scalar_dense", "privateuseone")
def _local_scalar_dense(tensor): return unwrap(tensor).item()

View File

@@ -129,6 +129,7 @@ class TestAssign(unittest.TestCase):
@unittest.expectedFailure
def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)
@unittest.skip("assign to contiguous shouldn't change the base buffer")
def test_assign_changes_buffer_alt(self):
a, b = [Tensor(Tensor(0).contiguous().realize().uop.as_buf()) for _ in range(2)]
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))

View File

@@ -3177,6 +3177,7 @@ class TestOps(unittest.TestCase):
def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
@unittest.skip("we have test_linalg, no need to test here. TODO: should be in torch backend tests")
def test_svd(self):
# test for tiny backend. real svd tests are in test_linalg
A = torch.randn(5, 5)

View File

@@ -92,9 +92,6 @@ earliest_rewrites = PatternMatcher([
# realize before assign if input permutes the target buffer
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
# contiguous buffer is buffer, this is for *correctness* of assign, not just speed
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.BUFFER),)), lambda root: root.src[0].forced_reshape(root.shape).rtag(root.tag)),
])
# *****************