mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add a better dedup test for DEFINE_VAR with CONST arg (#6813)
This commit is contained in:
@@ -2,12 +2,7 @@ import unittest, pickle
|
||||
import numpy as np
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import Tensor, TinyJit, Variable
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.dtype import PtrDType, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UOp, UOps, UnaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
def test_pickle_realized_tensor(self):
|
||||
@@ -71,32 +66,6 @@ class TestPickle(unittest.TestCase):
|
||||
sched_pk = pickle.loads(pk)
|
||||
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
|
||||
|
||||
def test_pickle_define_var(self):
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
||||
x2:=UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
||||
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(Variable('i', 1, 10), 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
|
||||
x12:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
|
||||
x2,)),
|
||||
UOp.define_var("i", dtypes.int, 1, 10),
|
||||
x14:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)),
|
||||
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
|
||||
x12,
|
||||
UOp(UOps.CONST, dtypes.int, arg=3, src=()),
|
||||
x14,)),)),)),)),)),)),))
|
||||
p = Kernel(ast).to_program(name_override="test")
|
||||
ps = Kernel(pickle.loads(pickle.dumps(ast))).to_program(name_override="test")
|
||||
self.assertEqual(ps.src, p.src)
|
||||
|
||||
class TestPickleJIT(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user