mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
olmoe memory usage cleanups
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# https://arxiv.org/pdf/2409.02060
|
||||
import time
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True, linewidth=1000)
|
||||
import functools
|
||||
@@ -53,13 +54,17 @@ if __name__ == "__main__":
|
||||
model_state_dict = nn.state.get_state_dict(model)
|
||||
del model_state_dict['freqs_cis']
|
||||
|
||||
with Timing("fetch and load weights: "):
|
||||
state = fetch_weights()
|
||||
nhf_state = convert_from_huggingface(state, model, 16, 16)
|
||||
with Timing("load weights to GPU: "):
|
||||
nhf_state = convert_from_huggingface(fetch_weights(), model, 16, 16)
|
||||
# NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
|
||||
for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
|
||||
print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
|
||||
|
||||
with Timing("unpack weights: "):
|
||||
nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
|
||||
assert len(nhf_state) == 0
|
||||
Tensor.realize(*list(nn.state.get_state_dict(model).values()))
|
||||
print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
|
||||
|
||||
count = 30
|
||||
temperature = 0
|
||||
@@ -70,13 +75,17 @@ if __name__ == "__main__":
|
||||
|
||||
toks = [12092]
|
||||
start_pos = 0
|
||||
timings = []
|
||||
for i in range(count):
|
||||
GlobalCounters.reset()
|
||||
st = time.perf_counter()
|
||||
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
|
||||
timings.append(time.perf_counter()-st)
|
||||
toks.append(tok)
|
||||
start_pos += 1
|
||||
print(toks)
|
||||
print(tokenizer.decode(toks))
|
||||
print(f"fastest token {min(timings)*1e3:.2f} ms, {1/min(timings):.1f} tok/s")
|
||||
|
||||
if temperature == 0:
|
||||
# Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
|
||||
|
||||
Reference in New Issue
Block a user