Fix timing calculation for sharded Vicuna: yield per-token times in milliseconds

This commit is contained in:
PhaneeshB
2023-12-06 01:18:31 +05:30
parent 93f583f0be
commit eab2194ca1

View File

@@ -1488,13 +1488,13 @@ class ShardedVicuna(VicunaBase):
generated_token_op = self.generate_new_token(params=params)
prefill_time = time.time() - decode_st_time
decode_time = (time.time() - decode_st_time) * 1000
_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
history.append(_token)
yield self.tokenizer.decode(history), None, prefill_time
yield self.tokenizer.decode(history), None, decode_time
if _token == 2:
break
@@ -2097,14 +2097,14 @@ class UnshardedVicuna(VicunaBase):
generated_token_op = self.generate_new_token(
params=params, sharded=False, cli=cli
)
prefill_time = time.time() - prefill_st_time
prefill_time_ms = (time.time() - prefill_st_time) * 1000
token = generated_token_op["token"]
if "cpu" not in self.device:
logits = generated_token_op["logits"]
pkv = generated_token_op["past_key_values"]
detok = generated_token_op["detok"]
yield detok, None, prefill_time
yield detok, None, prefill_time_ms
res_tokens.append(token)
if cli:
@@ -2439,8 +2439,7 @@ if __name__ == "__main__":
vic.shark_model.shark_runner.iree_config.device.flush_profiling()
if msg is None:
if is_first:
# Note that the prefill time is in seconds, and all the decoded tokens in ms.
prefill_time_ms = exec_time * 1000
prefill_time_ms = exec_time
is_first = False
else:
token_times_ms.append(exec_time)