mirror of https://github.com/tinygrad/tinygrad.git
add LUNIQUE op (#13554)
@@ -1,8 +1,21 @@
 import itertools
-from examples.beautiful_mnist import Model
+from typing import Callable
 from tinygrad import nn, Tensor, dtypes, Device
 from tinygrad.helpers import getenv, trange, partition
+
+class Model:
+  def __init__(self):
+    self.layers: list[Callable[[Tensor], Tensor]] = [
+      nn.Conv2d(1, 32, 5), Tensor.relu,
+      nn.Conv2d(32, 32, 5), Tensor.relu,
+      nn.BatchNorm(32), Tensor.max_pool2d,
+      nn.Conv2d(32, 64, 3), Tensor.relu,
+      nn.Conv2d(64, 64, 3), Tensor.relu,
+      nn.BatchNorm(64), Tensor.max_pool2d,
+      lambda x: x.flatten(1), nn.Linear(576, 10)]
+
+  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
 
 # TODO: refactor this into optim/onnx
 def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0.001, b1=0.9, b2=0.999, eps=1e-6) -> Tensor:
   b1_t *= b1
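For context, a minimal smoke-test sketch (not part of this commit) of how the inlined `Model` is driven. Only `Model`, its layers, and `x.sequential` come from the diff above; the `Tensor.randn` dummy input and the 28x28 MNIST-style shape are assumptions.

```python
from tinygrad import Tensor

# hypothetical usage, assuming single-channel 28x28 inputs as in beautiful_mnist
model = Model()
x = Tensor.randn(4, 1, 28, 28)   # batch of 4 dummy images
logits = model(x)                # __call__ runs x.sequential(self.layers): each layer/function in order
print(logits.shape)              # (4, 10)

# spatial trace: 28 -> conv5 24 -> conv5 20 -> pool 10 -> conv3 8 -> conv3 6 -> pool 3,
# so flatten(1) yields 64*3*3 = 576 features, which is why the last layer is nn.Linear(576, 10)
```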
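The hunk is cut off right after `b1_t *= b1`, so the rest of `functional_adam` is not shown. For reference only, here is a sketch of the textbook Adam update that the signature suggests; the name `adam_delta_sketch` and the functional (non-in-place) style are assumptions, not the body shipped in this commit.

```python
from tinygrad import Tensor

def adam_delta_sketch(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor,
                      lr=0.001, b1=0.9, b2=0.999, eps=1e-6):
  # b1_t/b2_t are assumed to already hold b1**t and b2**t for the current step t
  # (the commit's functional_adam appears to maintain them in place, e.g. `b1_t *= b1`)
  new_m = b1 * m + (1.0 - b1) * g        # EMA of the gradient (first moment)
  new_v = b2 * v + (1.0 - b2) * g * g    # EMA of the squared gradient (second moment)
  m_hat = new_m / (1.0 - b1_t)           # bias-corrected moment estimates
  v_hat = new_v / (1.0 - b2_t)
  return lr * m_hat / (v_hat.sqrt() + eps), new_m, new_v   # delta to subtract from the weights
```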