Tensor.empty is RESHAPE(BUFFER) (#8987)

* empty is RESHAPE(BUFFER)

* eh

* add test_empty_buf

* can we unsupport this

* linter

* Revert "can we unsupport this"

This reverts commit 0f71e1aadb.
This commit is contained in:
qazal
2025-02-09 18:42:51 +01:00
committed by GitHub
parent 44479f8ad6
commit 7eba5fb413
3 changed files with 16 additions and 7 deletions

View File

@@ -2151,9 +2151,9 @@ class TestBigGraph(unittest.TestCase):
a = Tensor.empty(4, 4, dtype=dtypes.int)
sink = tensor_rewrite(a*0)
assert UPat(Ops.CONST, arg=0).match(sink, {})
self.assertIs(tensor_rewrite(a*1), a.lazydata)
self.assertIs(tensor_rewrite(a+0), a.lazydata)
self.assertIs(tensor_rewrite(a//1), a.lazydata)
self.assertIs(tensor_rewrite(a*1).base, a.lazydata.base)
self.assertIs(tensor_rewrite(a+0).base, a.lazydata.base)
self.assertIs(tensor_rewrite(a//1).base, a.lazydata.base)
def test_cast_folding(self):
a = Tensor(1.0).cast(dtypes.int)
@@ -2310,7 +2310,7 @@ class TestCopyFolding(unittest.TestCase):
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
self.assertIs(b.base, a.base)
def test_clone(self):
a = Tensor.empty(4).lazydata

View File

@@ -118,5 +118,13 @@ class TestTensorUopRepresentation(unittest.TestCase):
#is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
def test_empty_buf(self):
a = Tensor.empty(3, 3)
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
vi = UOp.variable("i", 1, 3).bind(1)
a = Tensor.empty(3, vi)
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
self.assertEqual(a.lazydata.base.realized.size, 9)
if __name__ == '__main__':
unittest.main()

View File

@@ -490,8 +490,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if op is Ops.BIND:
var, val = arg.unbind()
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
# otherwise it's just a VIEW(BUFFER)
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
# otherwise it's just a RESHAPE(BUFFER)
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
return UOp.new_buffer(device, size, dtype).reshape(shape)
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
@@ -547,7 +548,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return self.src[0].buf_uop
@property
def buffer(self) -> Buffer:
if self.op is Ops.VIEW:
if self is not self.base:
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
return self.src[0].buffer
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"