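# Cross-checks a hand-written DLRM sparse arch (SparseArchShark / DLRMShark)
# against torchrec's reference SparseArch / DLRM, then imports the model with
# SharkImporter and runs it through SharkInference on CPU.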
import torch
from torch import nn
from torchrec.datasets.utils import Batch
from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from typing import Dict, List, Optional, Tuple
from torchrec.models.dlrm import (
    choose,
    DenseArch,
    DLRM,
    InteractionArch,
    SparseArch,
    OverArch,
)
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import numpy as np

torch.manual_seed(0)

np.random.seed(0)


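# Flattens a list of per-bag index tensors into one values tensor plus an
# EmbeddingBag-style offsets tensor; prev_values / prev_offsets, when given,
# are prepended so results can be chained.
# Illustrative example: [tensor([1, 2]), tensor([4, 5, 6])] with no previous
# state yields values tensor([1, 2, 4, 5, 6]) and offsets tensor([0, 2]).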
def calculate_offsets(tensor_list, prev_values, prev_offsets):
    offset_init = 0
    offset_list = []

    if prev_offsets is not None:
        offset_init = prev_values.shape[-1]
    for tensor in tensor_list:
        offset_list.append(offset_init)
        offset_init += tensor.shape[0]

    concatenated_tensor_list = torch.cat(tensor_list)

    if prev_values is not None:
        concatenated_tensor_list = torch.cat(
            [prev_values, concatenated_tensor_list]
        )

    concatenated_offsets = torch.tensor(offset_list)

    if prev_offsets is not None:
        concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets])

    return concatenated_tensor_list, concatenated_offsets


# combined_keys must be a dict mapping each feature key to the embedding bag it
# is looked up in, e.g. {"f1": 0, "f3": 0, "f2": 1}.
# The result is a flat list containing a (values, offsets, bag-pointer) triple
# for every key.
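# Illustrative example: with the dict above, the returned list is
# [f1_values, f1_offsets, tensor(0), f3_values, f3_offsets, tensor(0),
#  f2_values, f2_offsets, tensor(1)].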
def to_list(key_jagged, combined_keys):
    key_jagged_dict = key_jagged.to_dict()
    combined_list = []

    for key in combined_keys:
        prev_values, prev_offsets = calculate_offsets(
            key_jagged_dict[key].to_dense(), None, None
        )
        print(prev_values)
        print(prev_offsets)
        combined_list.append(prev_values)
        combined_list.append(prev_offsets)
        combined_list.append(torch.tensor(combined_keys[key]))

    return combined_list


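# A torchrec-free stand-in for SparseArch: the embedding bags are plain
# nn.EmbeddingBag modules and forward takes the flat
# (values, offsets, bag-pointer) triples produced by to_list, which keeps the
# model traceable for the SHARK import below.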
class SparseArchShark(nn.Module):
    def create_emb(self, embedding_dim, num_embeddings_list):
        embedding_list = nn.ModuleList()
        for i in range(0, num_embeddings_list.size):
            num_embeddings = num_embeddings_list[i]
            EE = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
            W = np.random.uniform(
                low=-np.sqrt(1 / num_embeddings),
                high=np.sqrt(1 / num_embeddings),
                size=(num_embeddings, embedding_dim),
            ).astype(np.float32)
            EE.weight.data = torch.tensor(W, requires_grad=True)
            embedding_list.append(EE)
        return embedding_list

    def __init__(
        self,
        embedding_dim,
        total_features,
        num_embeddings_list,
    ):
        super(SparseArchShark, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_features = total_features
        self.embedding_list = self.create_emb(
            embedding_dim, num_embeddings_list
        )

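    # batched_inputs is a flat sequence of (values, offsets, bag-pointer)
    # triples, one per feature, exactly as produced by to_list().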
    def forward(self, *batched_inputs):
        concatenated_list = []
        input_enum = 0

        for _ in range(len(batched_inputs) // 3):
            values = batched_inputs[input_enum]
            input_enum += 1
            offsets = batched_inputs[input_enum]
            input_enum += 1
            embedding_pointer = int(batched_inputs[input_enum])
            input_enum += 1

            E = self.embedding_list[embedding_pointer]
            V = E(values, offsets)
            concatenated_list.append(V)

        return torch.cat(concatenated_list, dim=1).reshape(
            -1, self.num_features, self.embedding_dim
        )


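# Sanity check: SparseArchShark should match torchrec's SparseArch when both
# share the same EmbeddingBagCollection weights.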
def test_sparse_arch() -> None:
    D = 3
    eb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=D,
        num_embeddings=10,
        feature_names=["f1", "f3"],
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=10,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

    w1 = ebc.embedding_bags["t1"].weight
    w2 = ebc.embedding_bags["t2"].weight

    sparse_arch = SparseArch(ebc)

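    # Each consecutive pair of offsets delimits one bag of ids: 5 keys times a
    # batch of 2 gives the 10 bags carved out of the 19 values below.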
    keys = ["f1", "f2", "f3", "f4", "f5"]
    offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19])
    features = KeyedJaggedTensor.from_offsets_sync(
        keys=keys,
        values=torch.tensor(
            [1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]
        ),
        offsets=offsets,
    )
    sparse_archi = SparseArchShark(D, 3, np.array([10, 10]))
    sparse_archi.embedding_list[0].weight = w1
    sparse_archi.embedding_list[1].weight = w2
    inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})

    test_results = sparse_archi(*inputs)
    sparse_features = sparse_arch(features)

    assert torch.allclose(
        sparse_features,
        test_results,
        rtol=1e-4,
        atol=1e-4,
    )


test_sparse_arch()


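# DLRM assembled from torchrec's DenseArch / InteractionArch / OverArch with
# SparseArchShark on the sparse side, so the whole model consumes plain tensors.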
class DLRMShark(nn.Module):
    def __init__(
        self,
        embedding_dim,
        total_features,
        num_embeddings_list,
        dense_in_features: int,
        dense_arch_layer_sizes: List[int],
        over_arch_layer_sizes: List[int],
    ) -> None:
        super().__init__()

        self.sparse_arch: SparseArchShark = SparseArchShark(
            embedding_dim, total_features, num_embeddings_list
        )
        num_sparse_features: int = total_features

        self.dense_arch = DenseArch(
            in_features=dense_in_features,
            layer_sizes=dense_arch_layer_sizes,
        )

        self.inter_arch = InteractionArch(
            num_sparse_features=num_sparse_features,
        )

        over_in_features: int = (
            embedding_dim
            + choose(num_sparse_features, 2)
            + num_sparse_features
        )

        self.over_arch = OverArch(
            in_features=over_in_features,
            layer_sizes=over_arch_layer_sizes,
        )

    def forward(
        self, dense_features: torch.Tensor, *sparse_features
    ) -> torch.Tensor:
        embedded_dense = self.dense_arch(dense_features)
        embedded_sparse = self.sparse_arch(*sparse_features)
        concatenated_dense = self.inter_arch(
            dense_features=embedded_dense, sparse_features=embedded_sparse
        )
        logits = self.over_arch(concatenated_dense)
        return logits


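# End-to-end check: compare DLRMShark against torchrec's DLRM with shared
# weights, then compile the model through SHARK and compare against the golden
# PyTorch output.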
def test_dlrm() -> None:
    B = 2
    D = 8
    dense_in_features = 100

    eb1_config = EmbeddingBagConfig(
        name="t1",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f1", "f3"],
    )
    eb2_config = EmbeddingBagConfig(
        name="t2",
        embedding_dim=D,
        num_embeddings=100,
        feature_names=["f2"],
    )

    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])

    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3", "f2"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
        offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
    )
    ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
    sparse_nn = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=dense_in_features,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )
    sparse_nn_nod = DLRMShark(
        embedding_dim=8,
        total_features=3,
        num_embeddings_list=np.array([100, 100]),
        dense_in_features=dense_in_features,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )

    dense_features = torch.rand((B, dense_in_features))

    x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1})

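    # Share all learned state between the two models so their outputs are
    # directly comparable.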
    w1 = ebc.embedding_bags["t1"].weight
    w2 = ebc.embedding_bags["t2"].weight

    sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1
    sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2

    sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict())
    sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict())
    sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict())

    logits = sparse_nn(
        dense_features=dense_features,
        sparse_features=sparse_features,
    )
    logits_nod = sparse_nn_nod(dense_features, *x)

    # print(logits)
    # print(logits_nod)

    # Import the module with SHARK and check the compiled result against the
    # PyTorch golden output.
    mlir_importer = SharkImporter(
        sparse_nn_nod,
        (dense_features, *x),
        frontend="torch",
    )

    (dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
        tracing_required=True
    )

    shark_module = SharkInference(
        dlrm_mlir, device="cpu", mlir_dialect="linalg"
    )
    shark_module.compile()
    result = shark_module.forward(inputs)
    np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)

    assert torch.allclose(
        logits,
        logits_nod,
        rtol=1e-4,
        atol=1e-4,
    )


test_dlrm()