ResNet: match implementation with Nvidia and PyTorch (#770)

* Match ResNet implementation with pytorch and nvidia

* Reduce number of Epochs
This commit is contained in:
Jacky Lee
2023-05-10 09:01:22 -07:00
committed by GitHub
parent b80cf9220c
commit d13629cb26
2 changed files with 4 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ if __name__ == "__main__":
lambda x: x / 255.0,
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
])
for _ in range(10):
for _ in range(5):
optimizer = optim.SGD(optim.get_parameters(model), lr=lr, momentum=0.9)
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)