add a better dedup test for DEFINE_VAR with CONST arg (#6813)

This commit is contained in:
qazal
2024-09-30 15:43:55 +08:00
committed by GitHub
parent e7fcbe1a4d
commit 4a4aa69b84
2 changed files with 9 additions and 31 deletions

View File

@@ -2,12 +2,7 @@ import unittest, pickle
import numpy as np
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable
from tinygrad.codegen.kernel import Kernel
from tinygrad.dtype import PtrDType, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import BinaryOps, TernaryOps, UOp, UOps, UnaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
@@ -71,32 +66,6 @@ class TestPickle(unittest.TestCase):
sched_pk = pickle.loads(pk)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
def test_pickle_define_var(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
x2:=UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(Variable('i', 1, 10), 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
x12:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
x2,)),
UOp.define_var("i", dtypes.int, 1, 10),
x14:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)),
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
x12,
UOp(UOps.CONST, dtypes.int, arg=3, src=()),
x14,)),)),)),)),)),)),))
p = Kernel(ast).to_program(name_override="test")
ps = Kernel(pickle.loads(pickle.dumps(ast))).to_program(name_override="test")
self.assertEqual(ps.src, p.src)
class TestPickleJIT(unittest.TestCase):
@classmethod
def setUpClass(cls):