hotfix convert Y_train to int before passing into index (#2850)

This commit is contained in:
chenyu
2023-12-19 11:40:56 -05:00
committed by GitHub
parent fec8e9060c
commit 6d7e9e0a56

View File

@@ -263,7 +263,7 @@ def train_cifar():
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
# one-hot encode labels
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
Y_train, Y_test = Tensor.eye(10)[Y_train.cast(dtypes.int32)], Tensor.eye(10)[Y_test.cast(dtypes.int32)]
# preprocess data
X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)