make gpt2 decode output just once at the end (#2869)

Also renamed the function from greedy_until to generate, since it is neither greedy (it samples from the distribution) nor does it run "until" a stop condition.
This commit is contained in:
chenyu
2023-12-20 12:14:55 -05:00
committed by GitHub
parent e92069fb1c
commit 857c35d256

View File

@@ -132,7 +132,7 @@ class GPT2:
self.model = model
self.tokenizer = tokenizer
def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
toks = [prompt_tokens[:] for _ in range(batch_size)]
start_pos = 0
@@ -153,8 +153,7 @@ class GPT2:
tok = probs.multinomial().flatten().numpy().tolist()
start_pos = len(toks[0])
for i,t in enumerate(tok): toks[i].append(t)
output = [self.tokenizer.decode(x) for x in toks]
return output
return [self.tokenizer.decode(x) for x in toks]
# **** main code ****
@@ -189,7 +188,7 @@ if __name__ == "__main__":
if args.benchmark != -1:
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
else:
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
if not args.noshow:
print('Generating text...')
if len(texts) == 1: print(texts[0])