Files
AMD-SHARK-Studio/shark/examples/shark_inference/sparse_arch.py
Huang Qi 66abee8e5b SharkInference: Fix various examples and README.md (#1903)
Follow https://github.com/nod-ai/SHARK/pull/708, remove parameter 'func_name'
for SharkInference.
2023-10-19 09:28:36 -05:00

312 lines
8.9 KiB
Python

import torch
from torch import nn
from torchrec.datasets.utils import Batch
from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from typing import Dict, List, Optional, Tuple
from torchrec.models.dlrm import (
choose,
DenseArch,
DLRM,
InteractionArch,
SparseArch,
OverArch,
)
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import numpy as np
# Seed the global torch and NumPy random generators so the randomly drawn
# weights and inputs used below are reproducible from run to run.
torch.manual_seed(0)
np.random.seed(0)
def calculate_offsets(tensor_list, prev_values, prev_offsets):
    """Flatten a list of per-bag tensors into one values tensor plus offsets.

    Every tensor in ``tensor_list`` is treated as one bag. The bags are
    concatenated along dim 0 and, for each bag, the index at which it starts
    inside the returned values tensor is recorded.

    Args:
        tensor_list: Non-empty list/tuple of tensors, one per bag
            (``torch.cat`` rejects an empty sequence).
        prev_values: Optional tensor of already-flattened values to prepend,
            or ``None``.
        prev_offsets: Optional tensor of offsets belonging to ``prev_values``
            to prepend, or ``None``.

    Returns:
        A ``(values, offsets)`` tuple of tensors. ``values`` is
        ``prev_values`` (if given) followed by the concatenated bags;
        ``offsets`` is ``prev_offsets`` (if given) followed by the start
        index of each new bag within ``values``.
    """
    # New bags start right after any prepended values. The guard is keyed on
    # ``prev_values`` because that is the tensor whose length is read here.
    next_offset = 0 if prev_values is None else prev_values.shape[-1]

    offset_list = []
    for tensor in tensor_list:
        offset_list.append(next_offset)
        next_offset += tensor.shape[0]

    concatenated_values = torch.cat(tensor_list)
    if prev_values is not None:
        concatenated_values = torch.cat([prev_values, concatenated_values])

    concatenated_offsets = torch.tensor(offset_list)
    if prev_offsets is not None:
        concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets])

    return concatenated_values, concatenated_offsets
# combined_keys must be a dict mapping each feature key to the index of the
# embedding bag it looks up, e.g. {"f1": 0, "f3": 0, "f2": 1}.
# The result is a flat list holding, for every key, a triple of tensors:
# values, offsets and the embedding-bag index (pointer).
def to_list(key_jagged, combined_keys):
    """Flatten selected features of a keyed jagged tensor into a tensor list.

    Args:
        key_jagged: Object exposing ``to_dict()``; each value of that dict
            must expose ``to_dense()`` (the callers in this file pass a
            ``KeyedJaggedTensor``).
        combined_keys: Dict mapping a feature key to the index of the
            embedding bag it should use. Iteration order decides the order
            of the output.

    Returns:
        A flat list containing, for every key in ``combined_keys``, three
        consecutive tensors: values, offsets and the embedding-bag index.
    """
    per_feature = key_jagged.to_dict()
    flattened = []

    for feature, bag_index in combined_keys.items():
        values, offsets = calculate_offsets(
            per_feature[feature].to_dense(), None, None
        )
        # NOTE(review): these prints look like leftover debugging output;
        # they are kept because they are part of the current behaviour.
        print(values)
        print(offsets)
        flattened.extend((values, offsets, torch.tensor(bag_index)))

    return flattened
class SparseArchShark(nn.Module):
    """Sparse feature arch built from plain ``nn.EmbeddingBag`` tables.

    The forward pass takes a flat sequence of tensors grouped in triples
    ``(values, offsets, embedding_pointer)``, one triple per sparse feature,
    and returns the pooled embeddings reshaped to
    ``(-1, num_features, embedding_dim)``.
    """

    def create_emb(self, embedding_dim, num_embeddings_list):
        """Build one sum-mode ``nn.EmbeddingBag`` per requested table size.

        Args:
            embedding_dim: Width of every embedding vector.
            num_embeddings_list: Iterable of table sizes (a 1-D numpy array,
                list or tuple of positive integers).

        Returns:
            An ``nn.ModuleList`` with one ``nn.EmbeddingBag`` per entry, in
            the order given.
        """
        embedding_list = nn.ModuleList()
        # Iterate the sizes directly so any iterable works, not only numpy
        # arrays exposing ``.size``.
        for num_embeddings in num_embeddings_list:
            bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
            # Uniform init in [-sqrt(1/n), sqrt(1/n)] drawn from NumPy's
            # global generator.
            bound = np.sqrt(1 / num_embeddings)
            weights = np.random.uniform(
                low=-bound,
                high=bound,
                size=(num_embeddings, embedding_dim),
            ).astype(np.float32)
            bag.weight.data = torch.tensor(weights, requires_grad=True)
            embedding_list.append(bag)
        return embedding_list

    def __init__(
        self,
        embedding_dim,
        total_features,
        num_embeddings_list,
    ):
        """Create the embedding tables.

        Args:
            embedding_dim: Width of every embedding vector.
            total_features: Number of sparse features; used as the middle
                dimension of the tensor returned by ``forward``.
            num_embeddings_list: Iterable of table sizes, one per table.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_features = total_features
        self.embedding_list = self.create_emb(
            embedding_dim, num_embeddings_list
        )

    def forward(self, *batched_inputs):
        """Pool every feature's values with the table it points to.

        Args:
            *batched_inputs: Flat sequence of tensors in groups of three:
                values, offsets, and a tensor convertible with ``int()``
                giving the index into ``self.embedding_list``. Trailing
                tensors that do not form a full triple are ignored.

        Returns:
            The pooled embeddings concatenated along dim 1 and reshaped to
            ``(-1, self.num_features, self.embedding_dim)``.
        """
        pooled = []
        for triple in range(len(batched_inputs) // 3):
            start = 3 * triple
            values, offsets, pointer = batched_inputs[start : start + 3]
            # The pointer selects a table, so it must be a concrete Python
            # int rather than a tensor.
            table = self.embedding_list[int(pointer)]
            pooled.append(table(values, offsets))

        return torch.cat(pooled, dim=1).reshape(
            -1, self.num_features, self.embedding_dim
        )
def test_sparse_arch() -> None:
    """Check that SparseArchShark reproduces torchrec's SparseArch output.

    Both modules share the same embedding weights and are fed the same
    features ("f1" and "f3" from table "t1", "f2" from table "t2").

    Raises:
        AssertionError: If the two outputs differ beyond rtol/atol of 1e-4.
    """
    D = 3
    eb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=D,
        num_embeddings=10,
        feature_names=["f1", "f3"],
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=10,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

    w1 = ebc.embedding_bags["t1"].weight
    w2 = ebc.embedding_bags["t2"].weight

    sparse_arch = SparseArch(ebc)

    keys = ["f1", "f2", "f3", "f4", "f5"]
    offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19])
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=keys,
        values=torch.tensor(
            [1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]
        ),
        offsets=offsets,
    )

    sparse_archi = SparseArchShark(D, 3, np.array([10, 10]))
    # Reuse the torchrec weights so both modules compute with identical
    # tables.
    sparse_archi.embedding_list[0].weight = w1
    sparse_archi.embedding_list[1].weight = w2
    inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})

    test_results = sparse_archi(*inputs)
    sparse_features = sparse_arch(features)

    # The comparison result used to be discarded, so a mismatch could never
    # fail the test.
    assert torch.allclose(
        sparse_features,
        test_results,
        rtol=1e-4,
        atol=1e-4,
    ), "SparseArchShark output does not match torchrec SparseArch"
# Run the sparse-arch comparison as soon as this point of the script is
# reached.
test_sparse_arch()
class DLRMShark(nn.Module):
    """DLRM-style model whose sparse arch is a ``SparseArchShark``.

    The dense, interaction and over arches are the torchrec ``DenseArch``,
    ``InteractionArch`` and ``OverArch`` modules; only the sparse arch is
    replaced, so that sparse inputs arrive as a flat list of plain tensors.
    """

    def __init__(
        self,
        embedding_dim,
        total_features,
        num_embeddings_list,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
    ) -> None:
        """Assemble the four sub-architectures.

        Args:
            embedding_dim: Width of every embedding vector.
            total_features: Number of sparse features.
            num_embeddings_list: Table sizes forwarded to SparseArchShark.
            dense_in_features: Input width of the dense arch.
            dense_arch_layer_sizes: Layer sizes of the dense arch.
            over_arch_layer_sizes: Layer sizes of the over arch.
        """
        super().__init__()

        # Input width of the over arch: the embedded dense vector plus one
        # value per pair of sparse features plus one per sparse feature.
        over_arch_width: int = (
            embedding_dim + choose(total_features, 2) + total_features
        )

        # Sub-modules are created in the same order as before so that
        # parameter registration and random initialisation are unchanged.
        self.sparse_arch: SparseArchShark = SparseArchShark(
            embedding_dim, total_features, num_embeddings_list
        )
        self.dense_arch = DenseArch(
            in_features=dense_in_features,
            layer_sizes=dense_arch_layer_sizes,
        )
        self.inter_arch = InteractionArch(
            num_sparse_features=total_features,
        )
        self.over_arch = OverArch(
            in_features=over_arch_width,
            layer_sizes=over_arch_layer_sizes,
        )

    def forward(
        self, dense_features: torch.Tensor, *sparse_features
    ) -> torch.Tensor:
        """Compute logits from dense features and flattened sparse inputs.

        Args:
            dense_features: Tensor fed to the dense arch.
            *sparse_features: Flat (values, offsets, pointer) triples as
                expected by ``SparseArchShark.forward``.

        Returns:
            The output of the over arch.
        """
        dense_embedding = self.dense_arch(dense_features)
        sparse_embedding = self.sparse_arch(*sparse_features)
        interactions = self.inter_arch(
            dense_features=dense_embedding,
            sparse_features=sparse_embedding,
        )
        return self.over_arch(interactions)
def test_dlrm() -> None:
    """Check DLRMShark against torchrec's DLRM and its SHARK-compiled form.

    A torchrec ``DLRM`` and a ``DLRMShark`` are given identical weights and
    inputs. ``DLRMShark`` is then imported through ``SharkImporter``,
    compiled for CPU with ``SharkInference`` and its result compared with
    the importer's golden output; finally the two eager outputs are
    compared with each other.

    Raises:
        AssertionError: If the compiled result differs from the golden
            output, or if the two eager models disagree.
    """
    B = 2
    D = 8
    dense_in_features = 100

    eb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f1", "f3"],
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )
    # Built once; a second, identical collection used to be created and the
    # first one thrown away.
    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3", "f2"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
        offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
    )

    sparse_nn = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=dense_in_features,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )
    sparse_nn_nod = DLRMShark(
        embedding_dim=D,
        total_features=3,
        num_embeddings_list=np.array([100, 100]),
        dense_in_features=dense_in_features,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )

    dense_features = torch.rand((B, dense_in_features))
    x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1})

    # Give both models the same parameters so their outputs are comparable.
    w1 = ebc.embedding_bags["t1"].weight
    w2 = ebc.embedding_bags["t2"].weight
    sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1
    sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2
    sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict())
    sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict())
    sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict())

    logits = sparse_nn(
        dense_features=dense_features,
        sparse_features=sparse_features,
    )
    logits_nod = sparse_nn_nod(dense_features, *x)

    # Import the module through SHARK, compile it and run it.
    mlir_importer = SharkImporter(
        sparse_nn_nod,
        (dense_features, *x),
        frontend="torch",
    )
    # The returned function name is not needed by SharkInference.
    (dlrm_mlir, _), inputs, golden_out = mlir_importer.import_debug(
        tracing_required=True
    )
    shark_module = SharkInference(
        dlrm_mlir, device="cpu", mlir_dialect="linalg"
    )
    shark_module.compile()
    result = shark_module.forward(inputs)
    np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)

    # The comparison result used to be discarded, so a mismatch could never
    # fail the test.
    assert torch.allclose(
        logits,
        logits_nod,
        rtol=1e-4,
        atol=1e-4,
    ), "DLRMShark logits do not match torchrec DLRM logits"
# Run the end-to-end DLRM comparison (including the SHARK compile step) as
# soon as this point of the script is reached.
test_dlrm()