From 6d7e9e0a56a7fa3dddedb56e2afb4a95274a7a89 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 19 Dec 2023 11:40:56 -0500 Subject: [PATCH] hotfix convert Y_train to int before passing into index (#2850) --- examples/hlb_cifar10.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index a4e6d6eea7..1128694e3b 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -263,7 +263,7 @@ def train_cifar(): X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float() Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float() # one-hot encode labels - Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test] + Y_train, Y_test = Tensor.eye(10)[Y_train.cast(dtypes.int32)], Tensor.eye(10)[Y_test.cast(dtypes.int32)] # preprocess data X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)