olmoe memory usage cleanups

This commit is contained in:
George Hotz
2025-03-19 12:28:18 +08:00
parent 2c87a22cf2
commit 865f23dd7b

View File

@@ -1,4 +1,5 @@
# https://arxiv.org/pdf/2409.02060
import time
import numpy as np
np.set_printoptions(suppress=True, linewidth=1000)
import functools
@@ -53,13 +54,17 @@ if __name__ == "__main__":
model_state_dict = nn.state.get_state_dict(model)
del model_state_dict['freqs_cis']
with Timing("fetch and load weights: "):
state = fetch_weights()
nhf_state = convert_from_huggingface(state, model, 16, 16)
with Timing("load weights to GPU: "):
nhf_state = convert_from_huggingface(fetch_weights(), model, 16, 16)
# NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
with Timing("unpack weights: "):
nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
assert len(nhf_state) == 0
Tensor.realize(*list(nn.state.get_state_dict(model).values()))
print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
count = 30
temperature = 0
@@ -70,13 +75,17 @@ if __name__ == "__main__":
toks = [12092]
start_pos = 0
timings = []
for i in range(count):
GlobalCounters.reset()
st = time.perf_counter()
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
timings.append(time.perf_counter()-st)
toks.append(tok)
start_pos += 1
print(toks)
print(tokenizer.decode(toks))
print(f"fastest token {min(timings)*1e3:.2f} ms, {1/min(timings):.1f} tok/s")
if temperature == 0:
# Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a