mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make gpt2 decode output just once at the end (#2869)
also updated function name from greedy_until to generate, as it's not greedy nor until
This commit is contained in:
@@ -132,7 +132,7 @@ class GPT2:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
|
||||
def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
|
||||
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
|
||||
toks = [prompt_tokens[:] for _ in range(batch_size)]
|
||||
start_pos = 0
|
||||
@@ -153,8 +153,7 @@ class GPT2:
|
||||
tok = probs.multinomial().flatten().numpy().tolist()
|
||||
start_pos = len(toks[0])
|
||||
for i,t in enumerate(tok): toks[i].append(t)
|
||||
output = [self.tokenizer.decode(x) for x in toks]
|
||||
return output
|
||||
return [self.tokenizer.decode(x) for x in toks]
|
||||
|
||||
# **** main code ****
|
||||
|
||||
@@ -189,7 +188,7 @@ if __name__ == "__main__":
|
||||
if args.benchmark != -1:
|
||||
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
|
||||
else:
|
||||
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
|
||||
texts = gpt2.generate(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
|
||||
if not args.noshow:
|
||||
print('Generating text...')
|
||||
if len(texts) == 1: print(texts[0])
|
||||
|
||||
Reference in New Issue
Block a user