mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 21:38:04 -05:00
fix time calc for sharded
This commit is contained in:
@@ -1488,13 +1488,13 @@ class ShardedVicuna(VicunaBase):
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params)
|
||||
|
||||
prefill_time = time.time() - decode_st_time
|
||||
decode_time = (time.time() - decode_st_time) * 1000
|
||||
|
||||
_token = generated_token_op["token"]
|
||||
_past_key_values = generated_token_op["past_key_values"]
|
||||
_detok = generated_token_op["detok"]
|
||||
history.append(_token)
|
||||
yield self.tokenizer.decode(history), None, prefill_time
|
||||
yield self.tokenizer.decode(history), None, decode_time
|
||||
|
||||
if _token == 2:
|
||||
break
|
||||
@@ -2097,14 +2097,14 @@ class UnshardedVicuna(VicunaBase):
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
)
|
||||
prefill_time = time.time() - prefill_st_time
|
||||
prefill_time_ms = (time.time() - prefill_st_time) * 1000
|
||||
|
||||
token = generated_token_op["token"]
|
||||
if "cpu" not in self.device:
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok, None, prefill_time
|
||||
yield detok, None, prefill_time_ms
|
||||
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
@@ -2439,8 +2439,7 @@ if __name__ == "__main__":
|
||||
vic.shark_model.shark_runner.iree_config.device.flush_profiling()
|
||||
if msg is None:
|
||||
if is_first:
|
||||
# Note that the prefill time is in seconds, and all the decoded tokens in ms.
|
||||
prefill_time_ms = exec_time * 1000
|
||||
prefill_time_ms = exec_time
|
||||
is_first = False
|
||||
else:
|
||||
token_times_ms.append(exec_time)
|
||||
|
||||
Reference in New Issue
Block a user