mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add sharded Falcon support
This commit is contained in:
@@ -85,9 +85,15 @@ class DecoderLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class CompiledDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, shark_decoder_layer_module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
):
|
||||
super().__init__()
|
||||
self.model = shark_decoder_layer_module
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -99,6 +105,26 @@ class CompiledDecoderLayer(torch.nn.Module):
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
@@ -112,6 +138,8 @@ class CompiledDecoderLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
return tuple(
|
||||
[
|
||||
torch.tensor(new_hidden_states),
|
||||
|
||||
@@ -150,7 +150,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
@@ -159,6 +159,25 @@ class ShardedFalcon(SharkLLMBase):
|
||||
return falcon_model
|
||||
|
||||
def compile_layer(self, layer, falconCompileInput, layer_id):
|
||||
# Determine number of available devices
|
||||
import iree.runtime as ireert
|
||||
|
||||
haldriver = ireert.get_driver(self.device)
|
||||
num_devices = len(haldriver.query_available_devices())
|
||||
|
||||
if layer_id == "word_embeddings":
|
||||
device_idx = 0 % num_devices
|
||||
elif layer_id == "ln_f":
|
||||
device_idx = 1 % num_devices
|
||||
elif layer_id == "lm_head":
|
||||
device_idx = 2 % num_devices
|
||||
elif type(layer_id) == int:
|
||||
device_idx = layer_id % num_devices
|
||||
else:
|
||||
raise ValueError("Falcon: Unknow layer encountered")
|
||||
|
||||
device_idx = device_idx if self.device == "rocm" else None
|
||||
|
||||
self.falcon_mlir_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
|
||||
)
|
||||
@@ -177,10 +196,13 @@ class ShardedFalcon(SharkLLMBase):
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path, self.device, "linalg"
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=device_idx,
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
return vmfb, device_idx
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
@@ -257,6 +279,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
@@ -276,7 +299,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
return shark_module, device_idx
|
||||
|
||||
def compile(self):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
@@ -295,7 +318,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
|
||||
print("Compiling Layer lm_head")
|
||||
shark_lm_head = self.compile_layer(
|
||||
shark_lm_head, _ = self.compile_layer(
|
||||
lm_head, [sample_hidden_states], "lm_head"
|
||||
)
|
||||
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
|
||||
@@ -304,7 +327,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
self.src_model.transformer.word_embeddings
|
||||
)
|
||||
print("Compiling Layer word_embeddings")
|
||||
shark_word_embedding = self.compile_layer(
|
||||
shark_word_embedding, _ = self.compile_layer(
|
||||
word_embedding, [sample_input_ids], "word_embeddings"
|
||||
)
|
||||
shark_word_embedding = CompiledWordEmbeddingsLayer(
|
||||
@@ -313,7 +336,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
|
||||
print("Compiling Layer ln_f")
|
||||
shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
|
||||
shark_ln_f, _ = self.compile_layer(
|
||||
ln_f, [sample_hidden_states], "ln_f"
|
||||
)
|
||||
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
|
||||
|
||||
shark_layers = []
|
||||
@@ -321,12 +346,19 @@ class ShardedFalcon(SharkLLMBase):
|
||||
print("Compiling Layer {}".format(i))
|
||||
layer_i = self.src_model.transformer.h[i]
|
||||
pytorch_layer_i = DecoderLayer(layer_i)
|
||||
shark_module = self.compile_layer(
|
||||
shark_module, device_idx = self.compile_layer(
|
||||
pytorch_layer_i,
|
||||
[sample_hidden_states, sample_attention_mask],
|
||||
i,
|
||||
)
|
||||
shark_layer_i = CompiledDecoderLayer(shark_module)
|
||||
del shark_module
|
||||
shark_layer_i = CompiledDecoderLayer(
|
||||
i,
|
||||
device_idx,
|
||||
args.falcon_variant_to_use,
|
||||
self.device,
|
||||
self.precision,
|
||||
)
|
||||
shark_layers.append(shark_layer_i)
|
||||
|
||||
sharded_model = ShardedFalconModel(
|
||||
|
||||
Reference in New Issue
Block a user