mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
hotfix convert Y_train to int before passing into index (#2850)
This commit is contained in:
@@ -263,7 +263,7 @@ def train_cifar():
|
||||
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
|
||||
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
|
||||
# one-hot encode labels
|
||||
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
|
||||
Y_train, Y_test = Tensor.eye(10)[Y_train.cast(dtypes.int32)], Tensor.eye(10)[Y_test.cast(dtypes.int32)]
|
||||
# preprocess data
|
||||
X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user