tensor inference (#6156)

* tensor inference

* test is even better name
This commit is contained in:
George Hotz
2024-08-18 00:19:28 -07:00
committed by GitHub
parent 9db2d0d5c6
commit 17a043edad
3 changed files with 12 additions and 11 deletions

View File

@@ -24,16 +24,17 @@ if __name__ == "__main__":
opt = nn.optim.Adam(nn.state.get_parameters(model))
@TinyJit
@Tensor.train()
def train_step() -> Tensor:
with Tensor.train():
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss
@TinyJit
@Tensor.test()
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
test_acc = float('nan')

View File

@@ -642,11 +642,11 @@ class TestTrainMode(unittest.TestCase):
assert not Tensor.training
class TestInferenceMode(unittest.TestCase):
def test_inference_mode(self):
def test_inference(self):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
with Tensor.inference_mode():
with Tensor.test():
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
@@ -663,7 +663,7 @@ class TestInferenceMode(unittest.TestCase):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
@Tensor.inference_mode()
@Tensor.test()
def f(x, m, W):
tmp = x.mul(m)
mm = tmp.matmul(W)

View File

@@ -157,7 +157,7 @@ class Tensor:
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
class inference_mode(ContextDecorator):
class test(ContextDecorator):
def __init__(self, mode:bool = True): self.mode = mode
def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode
def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev