Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 06:58:11 -05:00)
move gpt2/llama sampling inside the model call (#3013)
* move gpt2/llama sampling inside the model call
* argmax uses one more kernel
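For context, the sampling that this change moves inside the Transformer forward pass looks roughly like the sketch below. It is a paraphrase of the gpt2/llama hunks that follow, not the exact file contents: `logits` stands for the last-position output of the lm head, and the kernel-count remark comes from the commit message.

from tinygrad.tensor import Tensor

def sample(logits: Tensor, temperature: float) -> Tensor:
  # greedy decode for (near-)zero temperature; per the commit message, argmax costs one extra kernel
  if temperature < 1e-6:
    ret = logits.argmax(-1)
  else:
    # temperature sampling: scale, softmax, then draw one token with multinomial
    ret = (logits / temperature).softmax().multinomial()
  return ret.realize()

Callers then read the sampled token id straight off the returned tensor with `.item()` or `.numpy()`, rather than receiving probabilities and running `multinomial()` on the host outside the JIT.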
@@ -74,7 +74,7 @@ if __name__ == "__main__":
     turn = not turn
     old_output_len = len(outputted)
     while 1:
-      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
+      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
       start_pos = len(toks)
       toks.append(tok)
       outputted = output(outputted, toks, "blue" if not turn else "cyan")
@@ -91,12 +91,12 @@ class Transformer:
 
     for hi in self.h: h = hi(h, start_pos, mask)
 
-    logits = self.lm_head(self.ln_f(h))[:, -1, :].flatten()
+    logits = self.lm_head(self.ln_f(h))[:, -1, :]
     if temperature < 1e-6:
-      ret = (logits == logits.max())
+      ret = logits.argmax(-1)
     else:
-      ret = (logits / temperature).softmax()
-    return (ret.half() if HALF else ret).realize()
+      ret = (logits / temperature).softmax().multinomial()
+    return ret.flatten().realize()
 
   # TODO: fix empty token
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
@@ -140,17 +140,14 @@ class GPT2:
       GlobalCounters.reset()
       if timing: print("")
       st = GlobalCounters.time_sum_s
-      with Timing("total ", enabled=timing):
-        with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
-                    f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
-                    (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
-          if batch_size == 1 and len(toks[0][start_pos:]) == 1:
-            tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
-          else:
-            tokens = Tensor([x[start_pos:] for x in toks])
-          probs = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
-        # TODO: fix JIT rand so we can put this in the JIT
-        tok = probs.multinomial().flatten().numpy().tolist()
+      with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                  f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                  (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
+        if batch_size == 1 and len(toks[0][start_pos:]) == 1:
+          tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
+        else:
+          tokens = Tensor([x[start_pos:] for x in toks])
+        tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature).numpy().tolist()
       start_pos = len(toks[0])
       for i,t in enumerate(tok): toks[i].append(t)
     return [self.tokenizer.decode(x) for x in toks]
@@ -413,9 +413,7 @@ After you are done speaking, output [EOS]. You are not Chad.
     with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                 f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                 (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
-      probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
-    # TODO: fix JIT rand so we can put this in the JIT
-    tok = probs.multinomial().item()
+      tok = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).item()
 
     # use the kv cache
     start_pos = len(toks)
@@ -50,7 +50,7 @@ if __name__ == "__main__":
   for i in range(args.count):
     GlobalCounters.reset()
    with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
-      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).multinomial().item()
+      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).item()
    toks.append(tok)
    start_pos += 1
    print(spp.decode(toks))
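With sampling inside the model call, the single-stream decode loops in the example scripts above collapse to a pattern like the following hedged sketch. The names `model`, `toks`, and `start_pos` follow the surrounding hunks; this is an illustration of the new calling convention, not the exact file contents.

from tinygrad.tensor import Tensor

def decode_step(model, toks: list, start_pos: int, temperature: float = 0.0) -> int:
  # the model now samples internally, so the host just pulls out the chosen token id
  tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
  start_pos = len(toks)  # everything already fed is in the kv cache
  toks.append(tok)
  return start_pos       # where the next call should start reading toks

The jitted variants additionally bind start_pos to a symbolic Variable, as the GPT2.generate hunk shows, but in every caller the host-side multinomial and its "# TODO: fix JIT rand" workaround are gone.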
@@ -118,10 +118,12 @@ class Transformer:
     h = self.tok_embeddings(tokens)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos+1).realize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-    logits = self.output(self.norm(h))[:, -1, :].flatten()
+    logits = self.output(self.norm(h))[:, -1, :]
     if temperature < 1e-6:
-      return (logits == logits.max()).half().realize()
-    return (logits / temperature).softmax().half().realize()
+      ret = logits.argmax(-1)
+    else:
+      ret = (logits / temperature).softmax().multinomial()
+    return ret.realize()
 
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?
@@ -77,7 +77,7 @@ class TestRealWorld(unittest.TestCase):
     @TinyJit
     def test(t): return model(t, 0).realize()
     # TODO: test first token vs rest properly
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 190 if CI else 718, all_jitted=True)
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True)
 
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):