This commit is contained in:
George Hotz
2025-10-17 19:11:19 +08:00
parent 7c80285fa8
commit bc9048ccca
2 changed files with 8 additions and 5 deletions

View File

@@ -28,7 +28,10 @@ class TestRangeifyEdgeCase(unittest.TestCase):
res = Tensor.cat(a, c, dim=0)
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
if getenv("BIG") > 1:
if getenv("BIG") > 2:
# llama 8B (8192)
BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128
elif getenv("BIG") > 1:
# llama 8B
BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
elif getenv("BIG") > 0:
@@ -51,10 +54,10 @@ class TestPcontig(unittest.TestCase):
GlobalCounters.reset()
attn = q.scaled_dot_product_attention(k, v)
attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
attn = attn_output(attn)
loss = (attn - target).square().mean()
out = attn_output(attn)
loss = (out - target).square().mean()
loss.backward()
Tensor.realize(q.grad, k.grad, v.grad)
Tensor.realize(out, q.grad, k.grad, v.grad)
return q,k,v
with Context(PCONTIG=2, DEBUG=2):

View File

@@ -185,7 +185,7 @@ class ExecItem:
mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \
colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')
print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
f" {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
f" {self.prg.display_name+' '*(46-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")
self.prg.first_run = False