mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove assign contiguous hack (#12659)
* remove assign contiguous hack * remove bad contiguous usage in torch backend * assign
This commit is contained in:
@@ -155,16 +155,14 @@ def index_tensor(x, y):
|
||||
def zero_(x):
|
||||
if TORCH_DEBUG: print(f"zero_ {x.shape}")
|
||||
tt = unwrap(x)
|
||||
# NOTE: unconditional contiguous covers if x is contiguous (match it) or if x is view (realize for inplace)
|
||||
# TODO: consolidate
|
||||
tt.assign(tt.zeros_like().contiguous())
|
||||
tt.assign(tt.zeros_like())
|
||||
|
||||
@torch.library.impl("aten::fill_.Scalar", "privateuseone")
|
||||
@inplace_fn("x")
|
||||
def fill_scalar(x, y):
|
||||
if TORCH_DEBUG: print(f"fill_.Scalar {x.shape} {y}")
|
||||
tt = unwrap(x)
|
||||
tt.assign(tt.full_like(y).contiguous())
|
||||
tt.assign(tt.full_like(y))
|
||||
|
||||
@torch.library.impl("aten::_local_scalar_dense", "privateuseone")
|
||||
def _local_scalar_dense(tensor): return unwrap(tensor).item()
|
||||
|
||||
@@ -129,6 +129,7 @@ class TestAssign(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)
|
||||
|
||||
@unittest.skip("assign to contiguous shouldn't change the base buffer")
|
||||
def test_assign_changes_buffer_alt(self):
|
||||
a, b = [Tensor(Tensor(0).contiguous().realize().uop.as_buf()) for _ in range(2)]
|
||||
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
|
||||
|
||||
@@ -3177,6 +3177,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_bitcast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
||||
|
||||
@unittest.skip("we have test_linalg, no need to test here. TODO: should be in torch backend tests")
|
||||
def test_svd(self):
|
||||
# test for tiny backend. real svd tests are in test_linalg
|
||||
A = torch.randn(5, 5)
|
||||
|
||||
@@ -92,9 +92,6 @@ earliest_rewrites = PatternMatcher([
|
||||
|
||||
# realize before assign if input permutes the target buffer
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
|
||||
|
||||
# contiguous buffer is buffer, this is for *correctness* of assign, not just speed
|
||||
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.BUFFER),)), lambda root: root.src[0].forced_reshape(root.shape).rtag(root.tag)),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
||||
Reference in New Issue
Block a user