mirror of https://github.com/nod-ai/SHARK-Studio.git
Add sharded Falcon support
@@ -85,9 +85,15 @@ class DecoderLayer(torch.nn.Module):
 
 
 class CompiledDecoderLayer(torch.nn.Module):
-    def __init__(self, shark_decoder_layer_module):
+    def __init__(
+        self, layer_id, device_idx, falcon_variant, device, precision
+    ):
         super().__init__()
-        self.model = shark_decoder_layer_module
+        self.layer_id = layer_id
+        self.device_index = device_idx
+        self.falcon_variant = falcon_variant
+        self.device = device
+        self.precision = precision
 
     def forward(
         self,
@@ -99,6 +105,26 @@ class CompiledDecoderLayer(torch.nn.Module):
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
+        import gc
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        from pathlib import Path
+        from apps.language_models.utils import get_vmfb_from_path
+
+        self.falcon_vmfb_path = Path(
+            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
+        )
+        print("vmfb path for layer: ", self.falcon_vmfb_path)
+        self.model = get_vmfb_from_path(
+            self.falcon_vmfb_path,
+            self.device,
+            "linalg",
+            device_id=self.device_index,
+        )
+        if self.model is None:
+            raise ValueError("Layer vmfb not found")
+
         hidden_states = hidden_states.to(torch.float32).detach().numpy()
         attention_mask = attention_mask.detach().numpy()
 
@@ -112,6 +138,8 @@ class CompiledDecoderLayer(torch.nn.Module):
                 attention_mask,
             ),
         )
+        del self.model
+
         return tuple(
             [
                 torch.tensor(new_hidden_states),
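Note: forward() above no longer holds a live SHARK module per layer; it reloads the layer's vmfb from disk on every call and frees it afterwards, so only one layer's compiled module has to be resident on a device at a time. A minimal standalone sketch of that load/run/release pattern (the call form and the default argument values are assumptions based on the surrounding code):

import gc
from pathlib import Path

import torch
from apps.language_models.utils import get_vmfb_from_path


def run_layer_once(layer_id, hidden_states, attention_mask,
                   falcon_variant="7b", precision="fp16",
                   device="rocm", device_idx=0):
    # Locate the per-layer vmfb written out by compile_layer().
    vmfb_path = Path(
        f"falcon_{falcon_variant}_layer_{layer_id}_{precision}_{device}.vmfb"
    )
    module = get_vmfb_from_path(vmfb_path, device, "linalg", device_id=device_idx)
    if module is None:
        raise ValueError("Layer vmfb not found")

    # Assumed call form, mirroring CompiledDecoderLayer.forward().
    new_hidden_states = module(
        "forward",
        (
            hidden_states.to(torch.float32).detach().numpy(),
            attention_mask.detach().numpy(),
        ),
    )

    # Drop the module so the next layer can reuse the device memory.
    del module
    gc.collect()
    return torch.tensor(new_hidden_states)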
@@ -150,7 +150,7 @@ class ShardedFalcon(SharkLLMBase):
             quantization_config = GPTQConfig(bits=4, disable_exllama=True)
             kwargs["quantization_config"] = quantization_config
             kwargs["load_gptq_on_cpu"] = True
-            kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
+            kwargs["device_map"] = "cpu"
         falcon_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path, **kwargs
         )
@@ -159,6 +159,25 @@ class ShardedFalcon(SharkLLMBase):
         return falcon_model
 
     def compile_layer(self, layer, falconCompileInput, layer_id):
+        # Determine number of available devices
+        import iree.runtime as ireert
+
+        haldriver = ireert.get_driver(self.device)
+        num_devices = len(haldriver.query_available_devices())
+
+        if layer_id == "word_embeddings":
+            device_idx = 0 % num_devices
+        elif layer_id == "ln_f":
+            device_idx = 1 % num_devices
+        elif layer_id == "lm_head":
+            device_idx = 2 % num_devices
+        elif type(layer_id) == int:
+            device_idx = layer_id % num_devices
+        else:
+            raise ValueError("Falcon: Unknown layer encountered")
+
+        device_idx = device_idx if self.device == "rocm" else None
+
         self.falcon_mlir_path = Path(
             f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
         )
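Note: the branch above spreads the shards round-robin over however many devices the HAL driver reports, with the word embeddings, final layer norm, and LM head pinned to fixed slots; on anything other than rocm the index collapses to None and everything lands on the default device. A small standalone sketch of the same modulo scheme (the driver name and example values are assumptions):

import iree.runtime as ireert


def pick_device_index(layer_id, driver_name="rocm"):
    # Enumerate the devices the HAL driver can see, as compile_layer() does.
    haldriver = ireert.get_driver(driver_name)
    num_devices = len(haldriver.query_available_devices())

    # Embedding / final-norm / head layers get fixed slots; decoder layers rotate.
    fixed = {"word_embeddings": 0, "ln_f": 1, "lm_head": 2}
    if layer_id in fixed:
        return fixed[layer_id] % num_devices
    if isinstance(layer_id, int):
        return layer_id % num_devices
    raise ValueError(f"Falcon: unknown layer {layer_id!r}")


# With e.g. 4 visible devices, decoder layer 5 maps to device 1,
# "lm_head" to device 2, and layer 8 wraps back around to device 0.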
@@ -177,10 +196,13 @@ class ShardedFalcon(SharkLLMBase):
             single_file=True,
         )
         vmfb = get_vmfb_from_path(
-            self.falcon_vmfb_path, self.device, "linalg"
+            self.falcon_vmfb_path,
+            self.device,
+            "linalg",
+            device_id=device_idx,
         )
         if vmfb is not None:
-            return vmfb
+            return vmfb, device_idx
 
         print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
         if self.falcon_mlir_path.exists():
@@ -257,6 +279,7 @@ class ShardedFalcon(SharkLLMBase):
             mlir_module=self.falcon_mlir_path,
             device=self.device,
             mlir_dialect="linalg",
+            device_idx=device_idx,
         )
         path = shark_module.save_module(
             self.falcon_vmfb_path.parent.absolute(),
@@ -276,7 +299,7 @@ class ShardedFalcon(SharkLLMBase):
         print("Saved falcon vmfb at ", str(path))
         shark_module.load_module(path)
 
-        return shark_module
+        return shark_module, device_idx
 
     def compile(self):
         sample_input_ids = torch.zeros([100], dtype=torch.int64)
@@ -295,7 +318,7 @@ class ShardedFalcon(SharkLLMBase):
 
         lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
         print("Compiling Layer lm_head")
-        shark_lm_head = self.compile_layer(
+        shark_lm_head, _ = self.compile_layer(
             lm_head, [sample_hidden_states], "lm_head"
         )
         shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -304,7 +327,7 @@ class ShardedFalcon(SharkLLMBase):
             self.src_model.transformer.word_embeddings
         )
         print("Compiling Layer word_embeddings")
-        shark_word_embedding = self.compile_layer(
+        shark_word_embedding, _ = self.compile_layer(
             word_embedding, [sample_input_ids], "word_embeddings"
         )
         shark_word_embedding = CompiledWordEmbeddingsLayer(
@@ -313,7 +336,9 @@ class ShardedFalcon(SharkLLMBase):
 
         ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
         print("Compiling Layer ln_f")
-        shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
+        shark_ln_f, _ = self.compile_layer(
+            ln_f, [sample_hidden_states], "ln_f"
+        )
         shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
 
         shark_layers = []
@@ -321,12 +346,19 @@ class ShardedFalcon(SharkLLMBase):
             print("Compiling Layer {}".format(i))
             layer_i = self.src_model.transformer.h[i]
             pytorch_layer_i = DecoderLayer(layer_i)
-            shark_module = self.compile_layer(
+            shark_module, device_idx = self.compile_layer(
                 pytorch_layer_i,
                 [sample_hidden_states, sample_attention_mask],
                 i,
             )
-            shark_layer_i = CompiledDecoderLayer(shark_module)
+            del shark_module
+            shark_layer_i = CompiledDecoderLayer(
+                i,
+                device_idx,
+                args.falcon_variant_to_use,
+                self.device,
+                self.precision,
+            )
             shark_layers.append(shark_layer_i)
 
         sharded_model = ShardedFalconModel(
@@ -392,6 +392,9 @@ def load_vmfb_using_mmap(
             )
             dl.log(f"ireert.create_device()")
             config = ireert.Config(device=haldevice)
+            config.id = haldriver.query_available_devices()[device_idx][
+                "device_id"
+            ]
             dl.log(f"ireert.Config()")
         else:
             config = get_iree_runtime_config(device)
@@ -574,10 +577,17 @@ def get_results(
     frontend="torch",
     send_to_host=True,
     debug_timeout: float = 5.0,
+    device: str = None,
 ):
     """Runs a .vmfb file given inputs and config and returns output."""
     with DetailLogger(debug_timeout) as dl:
         device_inputs = []
+        if device == "rocm":
+            haldriver = ireert.get_driver("rocm")
+            haldevice = haldriver.create_device(
+                config.id,
+                allocators=shark_args.device_allocator,
+            )
         for input_array in input:
            dl.log(f"Load to device: {input_array.shape}")
             device_inputs.append(
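Note: on rocm, load_vmfb_using_mmap() now stashes the driver's opaque device_id on the runtime config, and get_results() recreates that exact HAL device before uploading inputs, so each shard's tensors land on the device its vmfb was loaded onto. A rough sketch of that index-to-device resolution (the helper name and the optional allocators handling are assumptions):

import iree.runtime as ireert


def rocm_config_for_index(device_idx, allocators=None):
    # Resolve the logical index to the driver's device_id, as
    # load_vmfb_using_mmap() does, and build a Config bound to that device.
    haldriver = ireert.get_driver("rocm")
    device_info = haldriver.query_available_devices()[device_idx]

    create_kwargs = {"allocators": allocators} if allocators else {}
    haldevice = haldriver.create_device(device_info["device_id"], **create_kwargs)

    config = ireert.Config(device=haldevice)
    # Stash the id so a later get_results() call can rebuild the same device.
    config.id = device_info["device_id"]
    return config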
@@ -150,11 +150,15 @@ class SharkInference:
 
     # inputs are considered to be tuple of np.array.
     def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run(function_name, inputs, send_to_host)
+        return self.shark_runner.run(
+            function_name, inputs, send_to_host, device=self.device
+        )
 
     # forward function.
     def forward(self, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run("forward", inputs, send_to_host)
+        return self.shark_runner.run(
+            "forward", inputs, send_to_host, device=self.device
+        )
 
     # Get all function names defined within the compiled module.
     def get_functions_in_module(self):
@@ -109,7 +109,9 @@ class SharkRunner:
             self.temp_file_to_unlink = params["temp_file_to_unlink"]
         del params
 
-    def run(self, function_name, inputs: tuple, send_to_host=False):
+    def run(
+        self, function_name, inputs: tuple, send_to_host=False, device=None
+    ):
         return get_results(
             self.iree_compilation_module,
             function_name,
@@ -117,6 +119,7 @@ class SharkRunner:
             self.iree_config,
             self.mlir_dialect,
             send_to_host,
+            device=device,
         )
 
     # Get all function names defined within the compiled module.
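Note: with the last three hunks, the device string chosen at compile time is threaded from SharkInference.__call__ through SharkRunner.run() into get_results(), so rocm inputs are uploaded to the HAL device recorded in config.id rather than the default device. A rough usage sketch under those assumptions (the mlir path, tensor shape, and device index are hypothetical):

import numpy as np
from shark.shark_inference import SharkInference

# Hypothetical sharded-Falcon layer pinned to the second rocm device.
shark_module = SharkInference(
    mlir_module="falcon_7b_layer_0_fp16.mlir",  # hypothetical path
    device="rocm",
    mlir_dialect="linalg",
    device_idx=1,
)
vmfb_path = shark_module.save_module(".")  # compile and write the .vmfb
shark_module.load_module(vmfb_path)

# __call__ now forwards device="rocm" down the stack, so these inputs are
# copied to the same HAL device the module was loaded on.
hidden_states = np.zeros((1, 100, 4544), dtype=np.float32)
outputs = shark_module("forward", (hidden_states,))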