TIP 3 - Tensor realization spec tests (#8288)

This commit is contained in:
qazal
2024-12-17 18:04:50 +02:00
committed by GitHub
parent f16188b8c0
commit c0d4346b5a

View File

@@ -26,5 +26,43 @@ class TestTensorUopRepresentation(unittest.TestCase):
# NOTE: this is wrong, COPY has an extra buffer for some reason
is_pattern(c, UPat(Ops.COPY, src=(realized_pattern,)))
@unittest.expectedFailure
def test_const(self):
a = Tensor(1).realize()
is_pattern(a, realized_pattern)
# NOTE: for VIEW of CONST we have two options:
# a. realize the base, expand
# b. realize the view
# depending on which one we pick you can comment out the other assert
def _assert_realized_const(self, a:Tensor):
# a.
# NOTE: this needs to rewrite a VIEW(BUFFER, <op>) that folded to VIEW(BUFFER, CONST) to a STORE(BUFFER, ShapeTracker.from_shape(()), CONST)
# while keeping the BUFFER around (to mark the tensor_uop as realized)
realized_pattern.match(a.lazydata.base, realized_pattern)
self.assertEqual(a.lazydata.base.realized.size, 1) # NOTE: the BUFFER may resize (eg. if it's a Tensor(4,4)*0, we push the movement op to a VIEW)
self.assertEqual(a.lazydata.op, Ops.EXPAND)
# b.
# NOTE: this option is like calling .contiguous() on all the Tensors passed into realize
realized_pattern.match(a.lazydata, realized_pattern)
@unittest.expectedFailure
def test_const_view(self):
a = Tensor(1).expand(4, 4).realize()
self._assert_realized_const(a)
@unittest.expectedFailure
def test_late_const_fold_simple(self):
a = ((Tensor([1, 2, 3])+1) * (1-1)).realize()
self._assert_realized_const(a)
# NOTE: this behaves like calling .contiguous() on all the Tensors passed into realize
@unittest.expectedFailure
def test_late_const_fold_complex(self):
a = Tensor.uniform(16, 3, 3, 3).realize()
is_pattern(a, realized_pattern)
self.assertEqual(a.lazydata.realized.size, 432)
if __name__ == '__main__':
unittest.main()