mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
render colors
This commit is contained in:
@@ -48,6 +48,7 @@ class TestPcontig(unittest.TestCase):
|
||||
attn_output.weight.requires_grad_().realize()
|
||||
target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
|
||||
|
||||
GlobalCounters.reset()
|
||||
attn = q.scaled_dot_product_attention(k, v)
|
||||
attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
|
||||
attn = attn_output(attn)
|
||||
@@ -57,14 +58,14 @@ class TestPcontig(unittest.TestCase):
|
||||
return q,k,v
|
||||
|
||||
with Context(PCONTIG=2, DEBUG=2):
|
||||
GlobalCounters.reset()
|
||||
q,k,v = fa_bw()
|
||||
grads = q.grad, k.grad, v.grad
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
|
||||
with Context(DEBUG=2):
|
||||
GlobalCounters.reset()
|
||||
q,k,v = fa_bw()
|
||||
cmp_grads = q.grad, k.grad, v.grad
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
|
||||
with Context(DEBUG=0):
|
||||
mses = [((x-y)**2).sum().item() for x,y in zip(grads, cmp_grads)]
|
||||
@@ -76,14 +77,15 @@ class TestPcontig(unittest.TestCase):
|
||||
def fa():
|
||||
Tensor.manual_seed(1337)
|
||||
with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
|
||||
GlobalCounters.reset()
|
||||
return q.scaled_dot_product_attention(k, v).realize()
|
||||
|
||||
with Context(PCONTIG=2, DEBUG=2):
|
||||
GlobalCounters.reset()
|
||||
ret = fa()
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
with Context(DEBUG=2):
|
||||
GlobalCounters.reset()
|
||||
cmp = fa()
|
||||
print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
|
||||
with Context(DEBUG=0):
|
||||
mse = ((cmp-ret)**2).sum().item()
|
||||
print(f"mse: {mse}")
|
||||
|
||||
Reference in New Issue
Block a user