mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
if you like your transformers twice as slow, use the GPU
This commit is contained in:
@@ -64,7 +64,7 @@ class Transformer:
|
||||
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
xnp = x.cpu().data
|
||||
xnp = x.cpu().data.astype(np.int32)
|
||||
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
|
||||
for i in range(x.shape[1]):
|
||||
onehot[range(bs), i, i] = 1
|
||||
|
||||
Reference in New Issue
Block a user