Add sharded Falcon support

Vivek Khandelwal
2023-10-26 07:23:15 -07:00
parent 486202377a
commit ea920f2955
5 changed files with 91 additions and 14 deletions


@@ -85,9 +85,15 @@ class DecoderLayer(torch.nn.Module):
class CompiledDecoderLayer(torch.nn.Module):
def __init__(self, shark_decoder_layer_module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.model = shark_decoder_layer_module
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
@@ -99,6 +105,26 @@ class CompiledDecoderLayer(torch.nn.Module):
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
@@ -112,6 +138,8 @@ class CompiledDecoderLayer(torch.nn.Module):
attention_mask,
),
)
del self.model
return tuple(
[
torch.tensor(new_hidden_states),
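The rewritten forward() keeps at most one decoder layer resident at a time: it reloads the layer's vmfb from disk by naming convention, runs it, and deletes the module before returning. A minimal standalone sketch of that load-run-release cycle, assuming get_vmfb_from_path behaves as in the diff and returns a module callable as module("forward", inputs); the helper name run_decoder_layer is hypothetical:

import gc
from pathlib import Path

import torch


def run_decoder_layer(layer_id, device_idx, falcon_variant, device, precision,
                      hidden_states, attention_mask):
    # Hypothetical helper mirroring CompiledDecoderLayer.forward above.
    from apps.language_models.utils import get_vmfb_from_path

    torch.cuda.empty_cache()  # no-op unless CUDA is initialized
    gc.collect()
    vmfb_path = Path(
        f"falcon_{falcon_variant}_layer_{layer_id}_{precision}_{device}.vmfb"
    )
    module = get_vmfb_from_path(vmfb_path, device, "linalg", device_id=device_idx)
    if module is None:
        raise ValueError(f"Layer vmfb not found at {vmfb_path}")
    try:
        return module(
            "forward",
            (
                hidden_states.to(torch.float32).detach().numpy(),
                attention_mask.detach().numpy(),
            ),
        )
    finally:
        # Drop the module so the next shard can reuse device memory.
        del module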


@@ -150,7 +150,7 @@ class ShardedFalcon(SharkLLMBase):
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
@@ -159,6 +159,25 @@ class ShardedFalcon(SharkLLMBase):
return falcon_model
def compile_layer(self, layer, falconCompileInput, layer_id):
# Determine number of available devices
import iree.runtime as ireert
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
if layer_id == "word_embeddings":
device_idx = 0 % num_devices
elif layer_id == "ln_f":
device_idx = 1 % num_devices
elif layer_id == "lm_head":
device_idx = 2 % num_devices
elif type(layer_id) == int:
device_idx = layer_id % num_devices
else:
raise ValueError("Falcon: Unknown layer encountered")
device_idx = device_idx if self.device == "rocm" else None
self.falcon_mlir_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
)
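The placement logic above spreads shards round-robin over however many devices the IREE driver reports: word_embeddings, ln_f, and lm_head take fixed slots, each decoder layer i lands on device i % num_devices, and an explicit index is only kept for ROCm. An isolated sketch of that mapping (the function name pick_device_index is illustrative):

import iree.runtime as ireert


def pick_device_index(layer_id, device="rocm"):
    # Illustrative restatement of the shard-placement rule in compile_layer.
    haldriver = ireert.get_driver(device)
    num_devices = len(haldriver.query_available_devices())
    fixed_slots = {"word_embeddings": 0, "ln_f": 1, "lm_head": 2}
    if layer_id in fixed_slots:
        device_idx = fixed_slots[layer_id] % num_devices
    elif isinstance(layer_id, int):
        device_idx = layer_id % num_devices
    else:
        raise ValueError("Falcon: Unknown layer encountered")
    # Only ROCm runs pin an explicit device; other backends use the default.
    return device_idx if device == "rocm" else None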
@@ -177,10 +196,13 @@ class ShardedFalcon(SharkLLMBase):
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=device_idx,
)
if vmfb is not None:
return vmfb
return vmfb, device_idx
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
@@ -257,6 +279,7 @@ class ShardedFalcon(SharkLLMBase):
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
device_idx=device_idx,
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
@@ -276,7 +299,7 @@ class ShardedFalcon(SharkLLMBase):
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
return shark_module, device_idx
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
@@ -295,7 +318,7 @@ class ShardedFalcon(SharkLLMBase):
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
shark_lm_head = self.compile_layer(
shark_lm_head, _ = self.compile_layer(
lm_head, [sample_hidden_states], "lm_head"
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -304,7 +327,7 @@ class ShardedFalcon(SharkLLMBase):
self.src_model.transformer.word_embeddings
)
print("Compiling Layer word_embeddings")
shark_word_embedding = self.compile_layer(
shark_word_embedding, _ = self.compile_layer(
word_embedding, [sample_input_ids], "word_embeddings"
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
@@ -313,7 +336,9 @@ class ShardedFalcon(SharkLLMBase):
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
print("Compiling Layer ln_f")
shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
shark_ln_f, _ = self.compile_layer(
ln_f, [sample_hidden_states], "ln_f"
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
shark_layers = []
@@ -321,12 +346,19 @@ class ShardedFalcon(SharkLLMBase):
print("Compiling Layer {}".format(i))
layer_i = self.src_model.transformer.h[i]
pytorch_layer_i = DecoderLayer(layer_i)
shark_module = self.compile_layer(
shark_module, device_idx = self.compile_layer(
pytorch_layer_i,
[sample_hidden_states, sample_attention_mask],
i,
)
shark_layer_i = CompiledDecoderLayer(shark_module)
del shark_module
shark_layer_i = CompiledDecoderLayer(
i,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
)
shark_layers.append(shark_layer_i)
sharded_model = ShardedFalconModel(


@@ -392,6 +392,9 @@ def load_vmfb_using_mmap(
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
config.id = haldriver.query_available_devices()[device_idx][
"device_id"
]
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
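With multiple shards, load_vmfb_using_mmap now also records which physical device the runtime config is bound to by copying the driver-reported device_id onto the Config object; get_results reads it back later to target the same device. A small sketch of that lookup, assuming the iree.runtime Python bindings and at least one visible device for the chosen driver:

import iree.runtime as ireert


def resolve_device_id(driver_name="rocm", device_idx=0):
    # Map a shard index to the driver's stable device_id (sketch only).
    haldriver = ireert.get_driver(driver_name)
    available = haldriver.query_available_devices()
    if device_idx >= len(available):
        raise IndexError(f"only {len(available)} {driver_name} device(s) visible")
    return available[device_idx]["device_id"]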
@@ -574,10 +577,17 @@ def get_results(
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
device: str = None,
):
"""Runs a .vmfb file given inputs and config and returns output."""
with DetailLogger(debug_timeout) as dl:
device_inputs = []
if device == "rocm":
haldriver = ireert.get_driver("rocm")
haldevice = haldriver.create_device(
config.id,
allocators=shark_args.device_allocator,
)
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(
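On the ROCm path, get_results rebuilds the HAL device from the id stored on the config so the inputs are uploaded to the same shard the module was loaded onto. A hedged sketch of that step; the allocators argument mirrors the shark_args.device_allocator usage in the diff:

import iree.runtime as ireert


def rocm_device_for_config(config, allocators=None):
    # Sketch: recreate the HAL device recorded by load_vmfb_using_mmap above.
    haldriver = ireert.get_driver("rocm")
    return haldriver.create_device(config.id, allocators=allocators)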


@@ -150,11 +150,15 @@ class SharkInference:
# inputs are considered to be tuple of np.array.
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
return self.shark_runner.run(
function_name, inputs, send_to_host, device=self.device
)
# forward function.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run("forward", inputs, send_to_host)
return self.shark_runner.run(
"forward", inputs, send_to_host, device=self.device
)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
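The per-layer module handles created above behave like the SharkInference wrapper in this file, so running a single compiled shard only needs the usual entry points; the device string now travels down to the runner automatically. A hypothetical usage sketch, where the vmfb file name follows the naming scheme above and the device index and tensor shapes are purely illustrative:

from pathlib import Path

import numpy as np

from apps.language_models.utils import get_vmfb_from_path

# Hypothetical: load decoder layer 7 of the 180b variant onto ROCm device 0.
shard = get_vmfb_from_path(
    Path("falcon_180b_layer_7_fp16_rocm.vmfb"), "rocm", "linalg", device_id=0
)
hidden_states = np.zeros((1, 100, 14848), dtype=np.float32)    # illustrative shape
attention_mask = np.zeros((1, 1, 100, 100), dtype=np.float32)  # illustrative shape
new_hidden_states = shard.forward((hidden_states, attention_mask))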


@@ -109,7 +109,9 @@ class SharkRunner:
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def run(self, function_name, inputs: tuple, send_to_host=False):
def run(
self, function_name, inputs: tuple, send_to_host=False, device=None
):
return get_results(
self.iree_compilation_module,
function_name,
@@ -117,6 +119,7 @@ class SharkRunner:
self.iree_config,
self.mlir_dialect,
send_to_host,
device=device,
)
# Get all function names defined within the compiled module.