multiout at every level

This commit is contained in:
George Hotz
2025-10-17 19:32:38 +08:00
parent bc9048ccca
commit 28efb4395c

View File

@@ -52,22 +52,25 @@ class TestPcontig(unittest.TestCase):
target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
GlobalCounters.reset()
attn = q.scaled_dot_product_attention(k, v)
#print("****\n\n\n\n\n")
attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
out = attn_output(attn)
loss = (out - target).square().mean()
loss.backward()
ret = [out, Tensor.stack(q.grad, k.grad, v.grad)]
Tensor.realize(*ret)
return ret
Tensor.realize(out, q.grad, k.grad, v.grad)
return q,k,v
return out, q.grad, k.grad, v.grad
with Context(PCONTIG=2, DEBUG=2):
q,k,v = fa_bw()
grads = q.grad, k.grad, v.grad
grads= fa_bw()
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
with Context(DEBUG=2):
q,k,v = fa_bw()
cmp_grads = q.grad, k.grad, v.grad
cmp_grads = fa_bw()
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
with Context(DEBUG=0):