remove .float call in llama logit (#11598)

* remove .float call in llama logit

* bfloat item
Author: chenyu
Date: 2025-08-09 21:02:18 -07:00
Committed by: GitHub
Parent: dd3d2eb36c
Commit: ef17af85c6
2 changed files with 4 additions and 4 deletions


@@ -1392,11 +1392,12 @@ def train_llama3():
 t = time.perf_counter()
 GlobalCounters.reset()
 loss, lr = train_step(model, tokens, grad_acc)
+loss = loss.float().item()
-tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
+tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
 if (fname:=getenv("LOSS_FILE", "")):
   with open(fname, "a") as f:
-    f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
+    f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
 if getenv("CKPT") and (i % 200 == 0 or i == 10):
   tqdm.write("saving checkpoint")


@@ -185,8 +185,7 @@ class Transformer:
 mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1) if seqlen > 1 else None
 for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
-# TODO: remove .float()?
-logits = self.output(self.norm(h)).float()
+logits = self.output(self.norm(h))
 if math.isnan(temperature): return logits
 return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
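
On the model side, the logits now keep the transformer's own dtype and sample() consumes them as-is. A rough sketch of the downstream effect, assuming half-precision logits and a greedy pick standing in for the full sample() call (shapes and names here are illustrative only):

from tinygrad import Tensor, dtypes

logits = Tensor.randn(1, 8, 128).cast(dtypes.half)  # stand-in for self.output(self.norm(h)) without .float()
last = logits[:, -1, :].flatten()                   # the slice handed to sample(), still half precision
token_id = last.argmax().item()                     # greedy pick for illustration; returns a plain int
print(token_id, last.dtype)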