fully local tensor const representation: CONST(VIEW(DEVICE)) [pr] (#8389)

This commit is contained in:
qazal
2024-12-24 10:15:56 +02:00
committed by GitHub
parent b589dec06e
commit 3a556a7e8b
5 changed files with 24 additions and 30 deletions

View File

@@ -96,7 +96,7 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
def test_literal_one_pow(self):
_check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
# this fails because of DETACH, it shouldn't
@unittest.expectedFailure
# update: passes after CONST(VIEW(DEVICE)) in tensor
def test_tensor_one_pow(self):
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))

View File

@@ -1987,7 +1987,7 @@ class TestBigGraph(unittest.TestCase):
check_schedule(x, 1)
tensor_const_pm = PatternMatcher([
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST, src=()))), lambda: True),
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),
])
class TestConst(unittest.TestCase):

View File

@@ -3,7 +3,7 @@ from tinygrad import Tensor
from tinygrad.ops import UPat, Ops
realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
const_pattern = UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.CONST)))
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {})
class TestTensorUopRepresentation(unittest.TestCase):
@@ -22,7 +22,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_const_pattern(self):
a = Tensor(1)
print(a.lazydata)
is_pattern(a, const_pattern)
is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
is_pattern(a, UPat.cvar("x")) # even cvar works!
def test_consts_do_not_realize(self):
a = Tensor(1)