mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
beam=16 makes gpt2 gpu-time < 5ms on 3090 (#2154)
This commit is contained in:
@@ -50,7 +50,7 @@ class Attention:
 keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
 
 # save the cache
-cache_k, cache_v = keys, values
+cache_k, cache_v = keys.realize(), values.realize()
 xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
 output = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
 return self.c_proj(output), cache_k, cache_v
@@ -9,13 +9,11 @@ from collections import defaultdict
 from tinygrad.codegen.optimizer import Opt, OptOps
 actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
 actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
-actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,16]] for axis in range(5)])
+actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
+actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
 actions += [
 Opt(op=OptOps.LOCAL, axis=0, amt=32),
 Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
-Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.GROUPTOP, axis=0, amt=256),
-Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256),
-Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256),
 Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
 ]
 
Reference in New Issue
Block a user