mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -24,16 +24,17 @@ if __name__ == "__main__":
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step() -> Tensor:
|
||||
with Tensor.train():
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss
|
||||
|
||||
@TinyJit
|
||||
@Tensor.test()
|
||||
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
||||
test_acc = float('nan')
|
||||
|
||||
@@ -642,11 +642,11 @@ class TestTrainMode(unittest.TestCase):
|
||||
assert not Tensor.training
|
||||
|
||||
class TestInferenceMode(unittest.TestCase):
|
||||
def test_inference_mode(self):
|
||||
def test_inference(self):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
with Tensor.inference_mode():
|
||||
with Tensor.test():
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
@@ -663,7 +663,7 @@ class TestInferenceMode(unittest.TestCase):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
@Tensor.inference_mode()
|
||||
@Tensor.test()
|
||||
def f(x, m, W):
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
|
||||
@@ -157,7 +157,7 @@ class Tensor:
|
||||
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
|
||||
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
|
||||
|
||||
class inference_mode(ContextDecorator):
|
||||
class test(ContextDecorator):
|
||||
def __init__(self, mode:bool = True): self.mode = mode
|
||||
def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
|
||||
def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
|
||||
|
||||
Reference in New Issue
Block a user