move gpt2/llama sampling inside the model call (#3013)

* move gpt2/llama sampling inside the model call

* argmax uses one more kernel
chenyu
2024-01-04 17:01:50 -05:00
committed by GitHub
parent c2a044ed83
commit f88506e630
6 changed files with 21 additions and 24 deletions
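
What the change amounts to: the sampling step that callers previously ran on the returned logits now lives inside each model's forward pass. With temperature near zero it takes an argmax over the vocabulary (the "one more kernel" noted above); otherwise it scales the logits by the temperature, applies softmax, and draws a token with multinomial, so callers just read the result with .item() or .numpy(). Below is a minimal sketch of that pattern using the same Tensor ops the diff uses; the standalone sample helper and the randn logits are illustrative, not code from the repo.

from tinygrad.tensor import Tensor

def sample(logits: Tensor, temperature: float) -> Tensor:
  # logits: (batch, vocab) taken from the last sequence position
  if temperature < 1e-6:
    ret = logits.argmax(-1)                               # greedy: one extra argmax kernel
  else:
    ret = (logits / temperature).softmax().multinomial()  # temperature sampling on-device
  return ret.flatten().realize()                          # (batch,) token ids

# callers now read the sampled token directly, e.g.
tok = sample(Tensor.randn(1, 50257), temperature=0.8).item()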

View File

@@ -74,7 +74,7 @@ if __name__ == "__main__":
   turn = not turn
   old_output_len = len(outputted)
   while 1:
-    tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
+    tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
     start_pos = len(toks)
     toks.append(tok)
     outputted = output(outputted, toks, "blue" if not turn else "cyan")

View File

@@ -91,12 +91,12 @@ class Transformer:
     for hi in self.h: h = hi(h, start_pos, mask)
-    logits = self.lm_head(self.ln_f(h))[:, -1, :].flatten()
+    logits = self.lm_head(self.ln_f(h))[:, -1, :]
     if temperature < 1e-6:
-      ret = (logits == logits.max())
+      ret = logits.argmax(-1)
     else:
-      ret = (logits / temperature).softmax()
-    return (ret.half() if HALF else ret).realize()
+      ret = (logits / temperature).softmax().multinomial()
+    return ret.flatten().realize()
   # TODO: fix empty token
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
@@ -140,17 +140,14 @@ class GPT2:
       GlobalCounters.reset()
       if timing: print("")
       st = GlobalCounters.time_sum_s
-      with Timing("total ", enabled=timing):
-        with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
-                    f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
-                    (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
-          if batch_size == 1 and len(toks[0][start_pos:]) == 1:
-            tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
-          else:
-            tokens = Tensor([x[start_pos:] for x in toks])
-          probs = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
-        # TODO: fix JIT rand so we can put this in the JIT
-        tok = probs.multinomial().flatten().numpy().tolist()
+      with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
+                  f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
+                  (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
+        if batch_size == 1 and len(toks[0][start_pos:]) == 1:
+          tokens = Variable("tokens", 0, VOCAB_SIZE).bind(toks[0][start_pos])
+        else:
+          tokens = Tensor([x[start_pos:] for x in toks])
+        tok = self.model(tokens, Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature).numpy().tolist()
       start_pos = len(toks[0])
       for i,t in enumerate(tok): toks[i].append(t)
     return [self.tokenizer.decode(x) for x in toks]

View File

@@ -413,9 +413,7 @@ After you are done speaking, output [EOS]. You are not Chad.
       with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                   f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                   (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
-        probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
-      # TODO: fix JIT rand so we can put this in the JIT
-      tok = probs.multinomial().item()
+        tok = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).item()
       # use the kv cache
       start_pos = len(toks)
View File

@@ -50,7 +50,7 @@ if __name__ == "__main__":
   for i in range(args.count):
     GlobalCounters.reset()
     with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
-      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).multinomial().item()
+      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).item()
     toks.append(tok)
     start_pos += 1
     print(spp.decode(toks))

View File

@@ -118,10 +118,12 @@ class Transformer:
     h = self.tok_embeddings(tokens)
     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype).triu(start_pos+1).realize() if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-    logits = self.output(self.norm(h))[:, -1, :].flatten()
+    logits = self.output(self.norm(h))[:, -1, :]
     if temperature < 1e-6:
-      return (logits == logits.max()).half().realize()
-    return (logits / temperature).softmax().half().realize()
+      ret = logits.argmax(-1)
+    else:
+      ret = (logits / temperature).softmax().multinomial()
+    return ret.realize()
   def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
     # TODO: better way to handle the first call v.s. the rest?

View File

@@ -77,7 +77,7 @@ class TestRealWorld(unittest.TestCase):
     @TinyJit
     def test(t): return model(t, 0).realize()
     # TODO: test first token vs rest properly
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 190 if CI else 718, all_jitted=True)
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 14.9, 191 if CI else 719, all_jitted=True)
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):