better test demonstration (#3077)

* a better test demonstration

* fix white space
This commit is contained in:
Yixiang Gao
2024-01-10 10:50:52 -08:00
committed by GitHub
parent 507e0afba0
commit 6842476ca6

View File

@@ -160,33 +160,38 @@ class TestMultiTensor(unittest.TestCase):
x = Tensor.rand((B, T, embed_size)).contiguous().realize()
y = layer_norm(x)
# for norm layers, the weights are duplicated
# for norm layers, the correct way to shard weights is duplication
layer_norm_sharded = RMSNorm(embed_size)
layer_norm_sharded.weight.shard_((d0, d1), axis=None).realize()
# if x is being sharded then all reduce is involved
# if x is being sharded, then all-reduce is involved
x_sharded = x.shard((d0, d1), axis=2).realize()
y_shard = layer_norm_sharded(x_sharded).realize()
np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)
# if x is copyed, then the operations remain inside each GPU
# if x is being duplicated, then the operations remain inside each GPU
# which is the common case
x_sharded = x.shard((d0, d1), axis=None).realize()
y_shard = layer_norm_sharded(x_sharded).realize()
np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)
def test_scaled_product_attention(self):
q = Tensor.rand(32, 8, 16, 64).contiguous().realize()
k = Tensor.rand(32, 8, 16, 64).contiguous().realize()
v = Tensor.rand(32, 8, 16, 64).contiguous().realize()
bs, n_heads, seq_len, head_dim = 1, 8, 4, 32
q = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
k = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
v = Tensor.rand(bs, n_heads, seq_len, head_dim).contiguous().realize()
y = Tensor.scaled_dot_product_attention(q, k, v)
# scaled dot product attention performs k.transpose(-2, -1) internally, which
# prevents you from sharding those axes; more importantly, we can avoid all-reduce
# if we shard along the n_heads axis
q_sharded = q.shard((d0, d1), axis=None).realize()
k_sharded = k.shard((d0, d1), axis=1).realize()
v_sharded = v.shard((d0, d1), axis=1).realize()
y_sharded = Tensor.scaled_dot_product_attention(q_sharded, k_sharded, v_sharded)
np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-6, rtol=1e-6)
m = Tensor.rand(32, 8, 16, 16).contiguous().realize()
m = Tensor.rand(1, 8, 4, 4).contiguous().realize()
y = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=m)
m_sharded = m.shard((d0, d1), axis=None).realize()