From 1ea4876dfa2646123a9cbff5a82a76582909165f Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 18 Mar 2025 15:25:45 -0400 Subject: [PATCH] olmoe touchups (#9499) GlobalCounters.reset() and only validate if temperature is 0 --- examples/olmoe.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/olmoe.py b/examples/olmoe.py index c642498193..e14b4d9508 100644 --- a/examples/olmoe.py +++ b/examples/olmoe.py @@ -1,10 +1,10 @@ # https://arxiv.org/pdf/2409.02060 import numpy as np np.set_printoptions(suppress=True, linewidth=1000) -import functools, collections, json -from tinygrad import Tensor, nn, Device -from tinygrad.helpers import tqdm, CI, Profiling, Timing, fetch, getenv -from extra.models.llama import Transformer, Variable, convert_from_huggingface +import functools +from tinygrad import Tensor, nn, Device, GlobalCounters +from tinygrad.helpers import Timing, getenv +from extra.models.llama import Transformer, convert_from_huggingface class MixtureFeedForward: def __init__(self, num_experts:int, activated_experts:int, dim:int, hidden_dim:int, linear=nn.Linear): @@ -71,13 +71,15 @@ if __name__ == "__main__": toks = [12092] start_pos = 0 for i in range(count): + GlobalCounters.reset() tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item() toks.append(tok) start_pos += 1 print(toks) print(tokenizer.decode(toks)) - # Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a - assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755, - 247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!" + if temperature == 0: + # Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a + assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755, + 247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!"