mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test the behavior, not the implementation (#3419)
This commit is contained in:
@@ -4,7 +4,7 @@ import unittest
|
||||
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node
|
||||
from tinygrad.device import Compiled, Device, Buffer
|
||||
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
|
||||
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
|
||||
@@ -42,16 +42,23 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
def test_load_cache_const_bufs(self):
|
||||
# make sure const buffers are differentiated from local and mem buffers
|
||||
a = Tensor([1,2,3,4])
|
||||
out = a[2] + 2 + a[3] + 3 + 2 + a[0]
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
lin = Linearizer(si.ast)
|
||||
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)), dtypes.int
|
||||
VAL = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=2, dtype=DT, st=ST))
|
||||
|
||||
# data1[0] + VAL
|
||||
a = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=DT, st=ST)), VAL))
|
||||
# (literal const 1) + VAL
|
||||
b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))
|
||||
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
|
||||
lin = Linearizer(ast)
|
||||
lin.linearize()
|
||||
|
||||
cache_keys = lin.load_cache.keys()
|
||||
assert len(cache_keys) == 5
|
||||
assert len([k for k in cache_keys if "CONST" in k]) == 2
|
||||
assert len([k for k in cache_keys if "CONST" not in k]) == 3
|
||||
a_bufs = [u.uop for u in lin.uops[-2].vin[0].vin]
|
||||
b_bufs = [u.uop for u in lin.uops[-2].vin[1].vin]
|
||||
|
||||
assert a_bufs == [UOps.LOAD, UOps.CONST]
|
||||
assert b_bufs == [UOps.CONST, UOps.CONST]
|
||||
|
||||
def test_upcast_cse(self):
|
||||
# when upcasting, within a subtree, there may be common expressions.
|
||||
|
||||
Reference in New Issue
Block a user