if you like your transformers twice as slow, use the GPU

This commit is contained in:
George Hotz
2020-12-29 17:14:23 -05:00
parent 6a6a82e999
commit f9170505b3
4 changed files with 7 additions and 5 deletions

View File

@@ -64,7 +64,7 @@ class Transformer:
def forward(self, x):
bs = x.shape[0]
-    xnp = x.cpu().data
+    xnp = x.cpu().data.astype(np.int32)
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
for i in range(x.shape[1]):
onehot[range(bs), i, i] = 1