diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 61282653..f4aa8224 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1488,13 +1488,13 @@ class ShardedVicuna(VicunaBase): generated_token_op = self.generate_new_token(params=params) - prefill_time = time.time() - decode_st_time + decode_time = (time.time() - decode_st_time) * 1000 _token = generated_token_op["token"] _past_key_values = generated_token_op["past_key_values"] _detok = generated_token_op["detok"] history.append(_token) - yield self.tokenizer.decode(history), None, prefill_time + yield self.tokenizer.decode(history), None, decode_time if _token == 2: break @@ -2097,14 +2097,14 @@ class UnshardedVicuna(VicunaBase): generated_token_op = self.generate_new_token( params=params, sharded=False, cli=cli ) - prefill_time = time.time() - prefill_st_time + prefill_time_ms = (time.time() - prefill_st_time) * 1000 token = generated_token_op["token"] if "cpu" not in self.device: logits = generated_token_op["logits"] pkv = generated_token_op["past_key_values"] detok = generated_token_op["detok"] - yield detok, None, prefill_time + yield detok, None, prefill_time_ms res_tokens.append(token) if cli: @@ -2439,8 +2439,7 @@ if __name__ == "__main__": vic.shark_model.shark_runner.iree_config.device.flush_profiling() if msg is None: if is_first: - # Note that the prefill time is in seconds, and all the decoded tokens in ms. - prefill_time_ms = exec_time * 1000 + prefill_time_ms = exec_time is_first = False else: token_times_ms.append(exec_time)