beam=16 makes gpt2 gpu-time < 5ms on 3090 (#2154)

This commit is contained in:
nimlgen
2023-10-27 23:21:27 +03:00
committed by GitHub
parent 5204864eca
commit 8d41b3eb3f
2 changed files with 3 additions and 5 deletions

View File

@@ -50,7 +50,7 @@ class Attention:
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
# save the cache
cache_k, cache_v = keys, values
cache_k, cache_v = keys.realize(), values.realize()
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
output = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.c_proj(output), cache_k, cache_v

View File

@@ -9,13 +9,11 @@ from collections import defaultdict
from tinygrad.codegen.optimizer import Opt, OptOps
# Build `actions`: a flat list of candidate Opt entries, one per (op, axis, amt) combination listed below.
# Presumably these are the kernel optimizations a search procedure tries; confirm against the code that consumes `actions`.
# `flatten` is not defined in this view; it is assumed to merge the per-axis lists into a single list.
actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)])
actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)])
# NOTE(review): this view looks like a diff rendered without +/- markers. The next two LOCAL lines differ only in
# their amt list ([2,3,4,8,16] vs [2,3,4,8,13,16,29]) and look like the before/after of a single change. If both
# were executed, the LOCAL entries they share would be appended twice. Confirm which line the real file contains.
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,16]] for axis in range(5)])
actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)])
# GROUPTOP candidates for axes 0-2 with amt in 13, 16, 29, 32 and 256.
actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)])
# Individually listed entries appended after the generated ones.
actions += [
# LOCAL with amt=32 is not part of either LOCAL amt list above, so it is listed explicitly.
Opt(op=OptOps.LOCAL, axis=0, amt=32),
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
# NOTE(review): the six GROUPTOP entries below (axes 0-2, amt 16 and 256) are also produced by the GROUPTOP
# comprehension above; they are likely the removed lines of this diff. Confirm against the real file.
Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.GROUPTOP, axis=0, amt=256),
Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256),
Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256),
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
]