mirror of https://github.com/tinygrad/tinygrad.git
remove .float call in llama logit (#11598)
* remove .float call in llama logit
* bfloat item
@@ -1392,11 +1392,12 @@ def train_llama3():
     t = time.perf_counter()
     GlobalCounters.reset()
     loss, lr = train_step(model, tokens, grad_acc)
+    loss = loss.float().item()
     # above as tqdm.write f-string
-    tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
+    tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
-        f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
+        f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")

     if getenv("CKPT") and (i % 200 == 0 or i == 10):
       tqdm.write("saving checkpoint")
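This hunk hoists the scalar materialization out of the f-strings: with the `.float()` gone from the model output (see the second hunk), `train_step` now returns a bfloat16 loss, and the cast to float32 happens once before `.item()` rather than at every formatting site. A minimal sketch of the pattern, assuming a scalar bfloat16 loss; the `Tensor` construction is a hypothetical stand-in for `train_step`'s return value:

```python
from tinygrad import Tensor, dtypes

# hypothetical stand-in for the bfloat16 loss that train_step returns
loss = Tensor([2.3125], dtype=dtypes.bfloat16).sum()

# cast to float32, then pull the Python scalar out exactly once;
# later f-strings format a plain float instead of re-calling .item()
loss_val = loss.float().item()
print(f"{loss_val:.4f} loss")
```

The "bfloat item" bullet presumably also covers scalar readout: casting to float32 before `.item()` avoids handing a bfloat16 scalar to the host-side transfer path, where bfloat16 support is spottier than float32.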
@@ -185,8 +185,7 @@ class Transformer:

     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1) if seqlen > 1 else None
     for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-    # TODO: remove .float()?
-    logits = self.output(self.norm(h)).float()
+    logits = self.output(self.norm(h))
     if math.isnan(temperature): return logits

     return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
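A sketch of what the removed cast changes, assuming a bfloat16 model; the tensors are hypothetical stand-ins for `self.norm(h)` and the `self.output` projection weight:

```python
from tinygrad import Tensor, dtypes

# hypothetical stand-ins for the normed hidden state and output projection
h_norm = Tensor.randn(1, 8, 4096, dtype=dtypes.bfloat16)
w_out  = Tensor.randn(4096, 32000, dtype=dtypes.bfloat16)

logits = h_norm @ w_out   # previously: (h_norm @ w_out).float()
print(logits.dtype)       # dtypes.bfloat16 -- no float32 upcast of the vocab-sized tensor
```

The logits now stay in the model's dtype, so no vocab-sized float32 buffer is allocated per forward pass; any upcast is deferred to the consumer, as in the training loop's `loss.float().item()` above.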