better test demonstration (#3077)
* a better test demonstration
* fix white space
@@ -160,33 +160,38 @@ class TestMultiTensor(unittest.TestCase):
     x = Tensor.rand((B, T, embed_size)).contiguous().realize()
     y = layer_norm(x)
 
-    # for norm layers, the weights are duplicated
+    # for norm layers, the correct way to shard weights is duplication
     layer_norm_sharded = RMSNorm(embed_size)
     layer_norm_sharded.weight.shard_((d0, d1), axis=None).realize()
 
-    # if x is being sharded then all reduce is involved
+    # if x is being sharded, then all-reduce is involved
     x_sharded = x.shard((d0, d1), axis=2).realize()
     y_shard = layer_norm_sharded(x_sharded).realize()
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)
 
-    # if x is copyed, then the operations remain inside each GPU
+    # if x is being duplicated, then the operations remain inside each GPU
+    # which is the common case
     x_sharded = x.shard((d0, d1), axis=None).realize()
     y_shard = layer_norm_sharded(x_sharded).realize()
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)
 
   def test_scaled_product_attention(self):
-    q = Tensor.rand(32, 8, 16, 64).contiguous().realize()
-    k = Tensor.rand(32, 8, 16, 64).contiguous().realize()
-    v = Tensor.rand(32, 8, 16, 64).contiguous().realize()
+    bs, n_heads, seq_len, head_dim = 1, 8, 4, 32
+    q = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
+    k = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
+    v = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
     y = Tensor.scaled_dot_product_attention(q, k, v)
 
+    # scaled dot product attention performs k.transpose(-2, -1) internally which
+    # prevent you from sharding those axis but more importantly we can avoid all-reduce
+    # if we shard along the n_heads axis
     q_sharded = q.shard((d0, d1), axis=None).realize()
     k_sharded = k.shard((d0, d1), axis=1).realize()
     v_sharded = v.shard((d0, d1), axis=1).realize()
     y_sharded = Tensor.scaled_dot_product_attention(q_sharded, k_sharded, v_sharded)
     np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-6, rtol=1e-6)
 
-    m = Tensor.rand(32, 8, 16, 16).contiguous().realize()
+    m = Tensor.rand(1, 8, 4, 4).contiguous().realize()
     y = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=m)
 
     m_sharded = m.shard((d0, d1), axis=None).realize()
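To make the norm-layer point self-contained, here is a minimal sketch (not part of the commit) of the pattern the new comments describe: the norm weights are duplicated on every device, and the input is either sharded along the normalized axis (which pulls in an all-reduce) or duplicated (which keeps all work local). It assumes two usable devices derived from Device.DEFAULT and substitutes tinygrad.nn.LayerNorm for the RMSNorm used in the test, since RMSNorm's import is not visible in this hunk; shapes and variable names are illustrative only.

# Minimal sketch, not commit code: duplicated norm weights, sharded vs duplicated input.
# Device names and the LayerNorm stand-in are assumptions; adjust to your setup.
import numpy as np
from tinygrad import Tensor, Device
from tinygrad.nn import LayerNorm  # stand-in for the RMSNorm used in the test

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
B, T, embed_size = 2, 8, 32

x = Tensor.rand(B, T, embed_size).contiguous().realize()
norm = LayerNorm(embed_size)
y = norm(x).numpy()  # single-device reference

# norm weights are small, so the right way to "shard" them is to duplicate them on every device
norm.weight.shard_((d0, d1), axis=None).realize()
norm.bias.shard_((d0, d1), axis=None).realize()

# input sharded along the normalized axis: the reduction needs an all-reduce across devices
x_split = x.shard((d0, d1), axis=2).realize()
np.testing.assert_allclose(y, norm(x_split).numpy(), atol=1e-6, rtol=1e-6)

# input duplicated: each device normalizes its own full copy, no cross-device traffic
x_dup = x.shard((d0, d1), axis=None).realize()
np.testing.assert_allclose(y, norm(x_dup).numpy(), atol=1e-6, rtol=1e-6)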
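A similar sketch (again an assumption-laden illustration, not commit code) for the attention case: k and v are split across devices along the n_heads axis, q stays duplicated as in the test, and each head's softmax and matmuls run on a single device, so no all-reduce is needed. The hunk also shows that the attention mask m, like the norm weights, is duplicated with axis=None rather than sharded.

# Minimal sketch, not commit code: shard attention along the heads axis so each device
# computes its own heads end to end. Device names are assumptions.
import numpy as np
from tinygrad import Tensor, Device

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
bs, n_heads, seq_len, head_dim = 1, 8, 4, 32

q = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
k = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
v = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
y = Tensor.scaled_dot_product_attention(q, k, v).numpy()  # single-device reference

# k and v are split across devices along n_heads (axis=1); q is kept duplicated as in the test.
# seq_len and head_dim are used by the internal k.transpose(-2, -1), so they are not sharded.
q_s = q.shard((d0, d1), axis=None).realize()
k_s = k.shard((d0, d1), axis=1).realize()
v_s = v.shard((d0, d1), axis=1).realize()
y_s = Tensor.scaled_dot_product_attention(q_s, k_s, v_s)
np.testing.assert_allclose(y, y_s.numpy(), atol=1e-6, rtol=1e-6)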