render colors

This commit is contained in:
George Hotz
2025-10-17 18:58:44 +08:00
parent 05f69b48e9
commit 7c80285fa8
2 changed files with 15 additions and 6 deletions

View File

@@ -48,6 +48,7 @@ class TestPcontig(unittest.TestCase):
attn_output.weight.requires_grad_().realize()
target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
GlobalCounters.reset()
attn = q.scaled_dot_product_attention(k, v)
attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
attn = attn_output(attn)
@@ -57,14 +58,14 @@ class TestPcontig(unittest.TestCase):
return q,k,v
with Context(PCONTIG=2, DEBUG=2):
GlobalCounters.reset()
q,k,v = fa_bw()
grads = q.grad, k.grad, v.grad
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
with Context(DEBUG=2):
GlobalCounters.reset()
q,k,v = fa_bw()
cmp_grads = q.grad, k.grad, v.grad
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
with Context(DEBUG=0):
mses = [((x-y)**2).sum().item() for x,y in zip(grads, cmp_grads)]
@@ -76,14 +77,15 @@ class TestPcontig(unittest.TestCase):
def fa():
Tensor.manual_seed(1337)
with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
GlobalCounters.reset()
return q.scaled_dot_product_attention(k, v).realize()
with Context(PCONTIG=2, DEBUG=2):
GlobalCounters.reset()
ret = fa()
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
with Context(DEBUG=2):
GlobalCounters.reset()
cmp = fa()
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
with Context(DEBUG=0):
mse = ((cmp-ret)**2).sum().item()
print(f"mse: {mse}")