diff --git a/apps/language_models/src/model_wrappers/falcon_sharded_model.py b/apps/language_models/src/model_wrappers/falcon_sharded_model.py
index b2b53a2c..4b4410d2 100644
--- a/apps/language_models/src/model_wrappers/falcon_sharded_model.py
+++ b/apps/language_models/src/model_wrappers/falcon_sharded_model.py
@@ -85,9 +85,15 @@ class DecoderLayer(torch.nn.Module):
 
 
 class CompiledDecoderLayer(torch.nn.Module):
-    def __init__(self, shark_decoder_layer_module):
+    def __init__(
+        self, layer_id, device_idx, falcon_variant, device, precision
+    ):
         super().__init__()
-        self.model = shark_decoder_layer_module
+        self.layer_id = layer_id
+        self.device_index = device_idx
+        self.falcon_variant = falcon_variant
+        self.device = device
+        self.precision = precision
 
     def forward(
         self,
@@ -99,6 +105,26 @@ class CompiledDecoderLayer(torch.nn.Module):
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
+        import gc
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        from pathlib import Path
+        from apps.language_models.utils import get_vmfb_from_path
+
+        self.falcon_vmfb_path = Path(
+            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
+        )
+        print("vmfb path for layer: ", self.falcon_vmfb_path)
+        self.model = get_vmfb_from_path(
+            self.falcon_vmfb_path,
+            self.device,
+            "linalg",
+            device_id=self.device_index,
+        )
+        if self.model is None:
+            raise ValueError("Layer vmfb not found")
+
         hidden_states = hidden_states.to(torch.float32).detach().numpy()
         attention_mask = attention_mask.detach().numpy()
 
@@ -112,6 +138,8 @@ class CompiledDecoderLayer(torch.nn.Module):
                 attention_mask,
             ),
         )
+        del self.model
+
         return tuple(
             [
                 torch.tensor(new_hidden_states),
diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py
index f8b8f631..d7ab5dc6 100644
--- a/apps/language_models/src/pipelines/falcon_pipeline.py
+++ b/apps/language_models/src/pipelines/falcon_pipeline.py
@@ -150,7 +150,7 @@ class ShardedFalcon(SharkLLMBase):
             quantization_config = GPTQConfig(bits=4, disable_exllama=True)
             kwargs["quantization_config"] = quantization_config
             kwargs["load_gptq_on_cpu"] = True
-            kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
+            kwargs["device_map"] = "cpu"
         falcon_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path, **kwargs
         )
@@ -159,6 +159,25 @@ class ShardedFalcon(SharkLLMBase):
         return falcon_model
 
     def compile_layer(self, layer, falconCompileInput, layer_id):
+        # Determine the number of available devices
+        import iree.runtime as ireert
+
+        haldriver = ireert.get_driver(self.device)
+        num_devices = len(haldriver.query_available_devices())
+
+        if layer_id == "word_embeddings":
+            device_idx = 0 % num_devices
+        elif layer_id == "ln_f":
+            device_idx = 1 % num_devices
+        elif layer_id == "lm_head":
+            device_idx = 2 % num_devices
+        elif isinstance(layer_id, int):
+            device_idx = layer_id % num_devices
+        else:
+            raise ValueError("Falcon: Unknown layer encountered")
+
+        device_idx = device_idx if self.device == "rocm" else None
+
         self.falcon_mlir_path = Path(
             f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
         )
@@ -177,10 +196,13 @@ class ShardedFalcon(SharkLLMBase):
             single_file=True,
         )
         vmfb = get_vmfb_from_path(
-            self.falcon_vmfb_path, self.device, "linalg"
+            self.falcon_vmfb_path,
+            self.device,
+            "linalg",
+            device_id=device_idx,
         )
         if vmfb is not None:
-            return vmfb
+            return vmfb, device_idx
 
         print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
         if self.falcon_mlir_path.exists():
@@ -257,6 +279,7 @@ class ShardedFalcon(SharkLLMBase):
             mlir_module=self.falcon_mlir_path,
             device=self.device,
             mlir_dialect="linalg",
+            device_idx=device_idx,
         )
         path = shark_module.save_module(
             self.falcon_vmfb_path.parent.absolute(),
@@ -276,7 +299,7 @@ class ShardedFalcon(SharkLLMBase):
         print("Saved falcon vmfb at ", str(path))
         shark_module.load_module(path)
 
-        return shark_module
+        return shark_module, device_idx
 
     def compile(self):
         sample_input_ids = torch.zeros([100], dtype=torch.int64)
@@ -295,7 +318,7 @@ class ShardedFalcon(SharkLLMBase):
         lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
         print("Compiling Layer lm_head")
-        shark_lm_head = self.compile_layer(
+        shark_lm_head, _ = self.compile_layer(
             lm_head, [sample_hidden_states], "lm_head"
         )
         shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
 
@@ -304,7 +327,7 @@ class ShardedFalcon(SharkLLMBase):
             self.src_model.transformer.word_embeddings
         )
         print("Compiling Layer word_embeddings")
-        shark_word_embedding = self.compile_layer(
+        shark_word_embedding, _ = self.compile_layer(
             word_embedding, [sample_input_ids], "word_embeddings"
         )
         shark_word_embedding = CompiledWordEmbeddingsLayer(
@@ -313,7 +336,9 @@ class ShardedFalcon(SharkLLMBase):
 
         ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
         print("Compiling Layer ln_f")
-        shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
+        shark_ln_f, _ = self.compile_layer(
+            ln_f, [sample_hidden_states], "ln_f"
+        )
         shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
 
         shark_layers = []
@@ -321,12 +346,19 @@ class ShardedFalcon(SharkLLMBase):
             print("Compiling Layer {}".format(i))
             layer_i = self.src_model.transformer.h[i]
             pytorch_layer_i = DecoderLayer(layer_i)
-            shark_module = self.compile_layer(
+            shark_module, device_idx = self.compile_layer(
                 pytorch_layer_i,
                 [sample_hidden_states, sample_attention_mask],
                 i,
             )
-            shark_layer_i = CompiledDecoderLayer(shark_module)
+            del shark_module
+            shark_layer_i = CompiledDecoderLayer(
+                i,
+                device_idx,
+                args.falcon_variant_to_use,
+                self.device,
+                self.precision,
+            )
             shark_layers.append(shark_layer_i)
 
         sharded_model = ShardedFalconModel(
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index 0900b38a..a6d9350c 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -392,6 +392,9 @@ def load_vmfb_using_mmap(
             )
             dl.log(f"ireert.create_device()")
             config = ireert.Config(device=haldevice)
+            config.id = haldriver.query_available_devices()[device_idx][
+                "device_id"
+            ]
             dl.log(f"ireert.Config()")
         else:
             config = get_iree_runtime_config(device)
@@ -574,10 +577,17 @@ def get_results(
     frontend="torch",
     send_to_host=True,
     debug_timeout: float = 5.0,
+    device: str = None,
 ):
     """Runs a .vmfb file given inputs and config and returns output."""
     with DetailLogger(debug_timeout) as dl:
         device_inputs = []
+        if device == "rocm":
+            haldriver = ireert.get_driver("rocm")
+            haldevice = haldriver.create_device(
+                config.id,
+                allocators=shark_args.device_allocator,
+            )
         for input_array in input:
             dl.log(f"Load to device: {input_array.shape}")
             device_inputs.append(
diff --git a/shark/shark_inference.py b/shark/shark_inference.py
index df0dab7a..032137c0 100644
--- a/shark/shark_inference.py
+++ b/shark/shark_inference.py
@@ -150,11 +150,15 @@ class SharkInference:
 
     # inputs are considered to be tuple of np.array.
     def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run(function_name, inputs, send_to_host)
+        return self.shark_runner.run(
+            function_name, inputs, send_to_host, device=self.device
+        )
 
     # forward function.
     def forward(self, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run("forward", inputs, send_to_host)
+        return self.shark_runner.run(
+            "forward", inputs, send_to_host, device=self.device
+        )
 
     # Get all function names defined within the compiled module.
     def get_functions_in_module(self):
diff --git a/shark/shark_runner.py b/shark/shark_runner.py
index b64c7d97..9f24409b 100644
--- a/shark/shark_runner.py
+++ b/shark/shark_runner.py
@@ -109,7 +109,9 @@ class SharkRunner:
             self.temp_file_to_unlink = params["temp_file_to_unlink"]
             del params
 
-    def run(self, function_name, inputs: tuple, send_to_host=False):
+    def run(
+        self, function_name, inputs: tuple, send_to_host=False, device=None
+    ):
         return get_results(
             self.iree_compilation_module,
             function_name,
@@ -117,6 +119,7 @@ class SharkRunner:
             self.iree_config,
             self.mlir_dialect,
             send_to_host,
+            device=device,
         )
 
     # Get all function names defined within the compiled module.
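
Note on the placement scheme (editor's sketch, not part of the diff): compile_layer pins word_embeddings, ln_f, and lm_head to device indices 0, 1, and 2, round-robins the decoder layers across however many devices the HAL driver reports, and CompiledDecoderLayer re-loads its per-layer vmfb on its assigned device at each forward call, deleting it afterwards to cap resident memory. Below is a minimal, runnable sketch of the index rule only; pick_device_index is a hypothetical helper, and the num_devices argument stands in for len(ireert.get_driver(device).query_available_devices()).

    def pick_device_index(layer_id, num_devices, device="rocm"):
        """Map a Falcon layer to a device index, cycling across devices."""
        named = {"word_embeddings": 0, "ln_f": 1, "lm_head": 2}
        if layer_id in named:
            idx = named[layer_id] % num_devices
        elif isinstance(layer_id, int):
            idx = layer_id % num_devices
        else:
            raise ValueError("Falcon: Unknown layer encountered")
        # The patch only wires per-device placement for rocm; other
        # backends fall back to the runtime's default device (None).
        return idx if device == "rocm" else None

    if __name__ == "__main__":
        # Eight decoder layers on four devices land on 0,1,2,3,0,1,2,3.
        print([pick_device_index(i, num_devices=4) for i in range(8)])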