diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index 4c56b32ddf..7c55d204da 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -21,6 +21,9 @@ class Model: if __name__ == "__main__": X_train, Y_train, X_test, Y_test = mnist() + # TODO: remove this when HIP is fixed + X_train, X_test = X_train.float(), X_test.float() + model = Model() opt = nn.optim.Adam(nn.state.get_parameters(model))