add int64 as supported dtype from numpy (#699)

* add int64 as supported dtype from numpy

Without this, examples/transformer.py didn't run. With this change it runs successfully.

* Update helpers.py

* Update transformer.py

* Update training.py
This commit is contained in:
Fernando Vidal
2023-03-18 17:15:04 -07:00
committed by GitHub
parent f355b02987
commit 73bd0b217b
2 changed files with 2 additions and 2 deletions

View File

@@ -14,7 +14,7 @@ def make_dataset():
s = i+j
ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
random.shuffle(ds)
ds = np.array(ds)
ds = np.array(ds).astype(np.float32)
ds_X = ds[:, 0:6]
ds_Y = np.copy(ds[:, 1:])
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]