mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
test const pattern [pr] (#8304)
* test const pattern [pr] * add model to test_tiny
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# basic self-contained tests of the external functionality of tinygrad
|
||||
import unittest, random
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device, nn
|
||||
from tinygrad.helpers import IMAGE
|
||||
|
||||
class TestTiny(unittest.TestCase):
|
||||
@@ -79,6 +79,25 @@ class TestTiny(unittest.TestCase):
|
||||
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
|
||||
self.assertEqual(ret.item(), s)
|
||||
|
||||
# *** a model ***
|
||||
|
||||
def test_mnist_model(self):
|
||||
layers = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
# pre-realize random weights
|
||||
for p in nn.state.get_parameters(layers): p.realize()
|
||||
|
||||
# run model inference
|
||||
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
|
||||
self.assertEqual(len(probs[0]), 10)
|
||||
|
||||
# *** image ***
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
|
||||
|
||||
@@ -3,6 +3,7 @@ from tinygrad import Tensor
|
||||
from tinygrad.ops import UPat, Ops
|
||||
|
||||
realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
|
||||
const_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.CONST)))
|
||||
def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {})
|
||||
|
||||
class TestTensorUopRepresentation(unittest.TestCase):
|
||||
@@ -18,6 +19,11 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
||||
print(c.lazydata)
|
||||
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
|
||||
|
||||
def test_const_pattern(self):
|
||||
a = Tensor(1)
|
||||
print(a.lazydata)
|
||||
is_pattern(a, const_pattern)
|
||||
|
||||
def test_consts_do_not_realize(self):
|
||||
a = Tensor(1)
|
||||
print(a.lazydata)
|
||||
|
||||
Reference in New Issue
Block a user