mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.empty is RESHAPE(BUFFER) (#8987)
* empty is RESHAPE(BUFFER)
* eh
* add test_empty_buf
* can we unsupport this
* linter
* Revert "can we unsupport this"
This reverts commit 0f71e1aadb.
This commit is contained in:
@@ -2151,9 +2151,9 @@ class TestBigGraph(unittest.TestCase):
|
||||
a = Tensor.empty(4, 4, dtype=dtypes.int)
|
||||
sink = tensor_rewrite(a*0)
|
||||
assert UPat(Ops.CONST, arg=0).match(sink, {})
|
||||
self.assertIs(tensor_rewrite(a*1), a.lazydata)
|
||||
self.assertIs(tensor_rewrite(a+0), a.lazydata)
|
||||
self.assertIs(tensor_rewrite(a//1), a.lazydata)
|
||||
self.assertIs(tensor_rewrite(a*1).base, a.lazydata.base)
|
||||
self.assertIs(tensor_rewrite(a+0).base, a.lazydata.base)
|
||||
self.assertIs(tensor_rewrite(a//1).base, a.lazydata.base)
|
||||
|
||||
def test_cast_folding(self):
|
||||
a = Tensor(1.0).cast(dtypes.int)
|
||||
@@ -2310,7 +2310,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
self.assertIs(b.base, a.base)
|
||||
|
||||
def test_clone(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
|
||||
@@ -118,5 +118,13 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
||||
#is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
|
||||
is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
|
||||
|
||||
def test_empty_buf(self):
|
||||
a = Tensor.empty(3, 3)
|
||||
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
|
||||
vi = UOp.variable("i", 1, 3).bind(1)
|
||||
a = Tensor.empty(3, vi)
|
||||
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
|
||||
self.assertEqual(a.lazydata.base.realized.size, 9)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -490,8 +490,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if op is Ops.BIND:
|
||||
var, val = arg.unbind()
|
||||
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
|
||||
# otherwise it's just a VIEW(BUFFER)
|
||||
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
|
||||
# otherwise it's just a RESHAPE(BUFFER)
|
||||
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
||||
return UOp.new_buffer(device, size, dtype).reshape(shape)
|
||||
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
|
||||
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
|
||||
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
|
||||
@@ -547,7 +548,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return self.src[0].buf_uop
|
||||
@property
|
||||
def buffer(self) -> Buffer:
|
||||
if self.op is Ops.VIEW:
|
||||
if self is not self.base:
|
||||
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
||||
return self.src[0].buffer
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
|
||||
Reference in New Issue
Block a user