test the behavior, not the implementation (#3419)

This commit is contained in:
qazal
2024-02-15 18:23:42 +02:00
committed by GitHub
parent b1c0d8c99d
commit e1a57fe58a

View File

@@ -4,7 +4,7 @@ import unittest
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node
from tinygrad.device import Compiled, Device, Buffer
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
@@ -42,16 +42,23 @@ class TestLinearizer(unittest.TestCase):
def test_load_cache_const_bufs(self):
# make sure const buffers are differentiated from local and mem buffers
a = Tensor([1,2,3,4])
out = a[2] + 2 + a[3] + 3 + 2 + a[0]
si = create_schedule([out.lazydata])[-1]
lin = Linearizer(si.ast)
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)), dtypes.int
VAL = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=2, dtype=DT, st=ST))
# data1[0] + VAL
a = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=DT, st=ST)), VAL))
# (literal const 1) + VAL
b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
lin = Linearizer(ast)
lin.linearize()
cache_keys = lin.load_cache.keys()
assert len(cache_keys) == 5
assert len([k for k in cache_keys if "CONST" in k]) == 2
assert len([k for k in cache_keys if "CONST" not in k]) == 3
a_bufs = [u.uop for u in lin.uops[-2].vin[0].vin]
b_bufs = [u.uop for u in lin.uops[-2].vin[1].vin]
assert a_bufs == [UOps.LOAD, UOps.CONST]
assert b_bufs == [UOps.CONST, UOps.CONST]
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.