olmoe touchups (#9499)

GlobalCounters.reset() and only validate if temperature is 0
This commit is contained in:
chenyu
2025-03-18 15:25:45 -04:00
committed by GitHub
parent f7506c6c25
commit 1ea4876dfa

View File

@@ -1,10 +1,10 @@
# https://arxiv.org/pdf/2409.02060
import numpy as np
np.set_printoptions(suppress=True, linewidth=1000)
import functools, collections, json
from tinygrad import Tensor, nn, Device
from tinygrad.helpers import tqdm, CI, Profiling, Timing, fetch, getenv
from extra.models.llama import Transformer, Variable, convert_from_huggingface
import functools
from tinygrad import Tensor, nn, Device, GlobalCounters
from tinygrad.helpers import Timing, getenv
from extra.models.llama import Transformer, convert_from_huggingface
class MixtureFeedForward:
def __init__(self, num_experts:int, activated_experts:int, dim:int, hidden_dim:int, linear=nn.Linear):
@@ -71,13 +71,15 @@ if __name__ == "__main__":
toks = [12092]
start_pos = 0
for i in range(count):
GlobalCounters.reset()
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
toks.append(tok)
start_pos += 1
print(toks)
print(tokenizer.decode(toks))
# Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755,
247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!"
if temperature == 0:
# Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
assert toks == [12092, 13, 309, 717, 247, 747, 17782, 281, 436, 12209, 285, 309, 717, 2820, 281, 755,
247, 1805, 4685, 273, 253, 1027, 3510, 273, 941, 326, 476, 320, 7141, 275, 247], "BAD OUTPUT!"