mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
multiout at every level
This commit is contained in:
@@ -52,22 +52,25 @@ class TestPcontig(unittest.TestCase):
|
||||
target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
|
||||
|
||||
GlobalCounters.reset()
|
||||
attn = q.scaled_dot_product_attention(k, v)
|
||||
#print("****\n\n\n\n\n")
|
||||
attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
|
||||
attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
|
||||
out = attn_output(attn)
|
||||
loss = (out - target).square().mean()
|
||||
loss.backward()
|
||||
ret = [out, Tensor.stack(q.grad, k.grad, v.grad)]
|
||||
Tensor.realize(*ret)
|
||||
return ret
|
||||
|
||||
Tensor.realize(out, q.grad, k.grad, v.grad)
|
||||
return q,k,v
|
||||
return out, q.grad, k.grad, v.grad
|
||||
|
||||
with Context(PCONTIG=2, DEBUG=2):
|
||||
q,k,v = fa_bw()
|
||||
grads = q.grad, k.grad, v.grad
|
||||
grads= fa_bw()
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
|
||||
with Context(DEBUG=2):
|
||||
q,k,v = fa_bw()
|
||||
cmp_grads = q.grad, k.grad, v.grad
|
||||
cmp_grads = fa_bw()
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
|
||||
with Context(DEBUG=0):
|
||||
|
||||
Reference in New Issue
Block a user