diff --git a/apps/language_models/src/model_wrappers/falcon_sharded_model.py b/apps/language_models/src/model_wrappers/falcon_sharded_model.py
index aefd9dc1..adb58c5c 100644
--- a/apps/language_models/src/model_wrappers/falcon_sharded_model.py
+++ b/apps/language_models/src/model_wrappers/falcon_sharded_model.py
@@ -576,6 +576,388 @@ class CompiledEightDecoderLayer(torch.nn.Module):
         return result
 
 
+class EightDecoderLayer2(torch.nn.Module):
+    def __init__(self, decoder_layer_model, falcon_variant):
+        super().__init__()
+        self.model = decoder_layer_model
+        self.falcon_variant = falcon_variant
+
+    def forward(self, hidden_states, attention_mask):
+        new_pkvs = []
+        for layer in self.model:
+            outputs = layer(
+                hidden_states=hidden_states,
+                alibi=None,
+                attention_mask=attention_mask,
+                use_cache=True,
+            )
+            hidden_states = outputs[0]
+            new_pkvs.append(
+                (
+                    outputs[-1][0],
+                    outputs[-1][1],
+                )
+            )
+        if self.falcon_variant == "180b":
+            # Flatten the 40 per-layer (key, value) cache pairs into one
+            # flat tuple of tensors, (hidden_states, k0, v0, ..., k39, v39),
+            # so the traced module returns only tensors.
+            result = (hidden_states,) + tuple(
+                t for pkv in new_pkvs for t in pkv
+            )
+        else:
+            raise ValueError(
+                f"Unsupported Falcon variant: {self.falcon_variant}"
+            )
+        return result
+
+
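+# Runtime counterpart of EightDecoderLayer2: forward() lazily loads the
+# per-shard vmfb from disk, runs it, and re-nests the flat outputs into
+# (hidden_states, (key, value) * 40) for the sharded pipeline.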
+class CompiledEightDecoderLayer2(torch.nn.Module):
+    def __init__(
+        self, layer_id, device_idx, falcon_variant, device, precision
+    ):
+        super().__init__()
+        self.layer_id = layer_id
+        self.device_index = device_idx
+        self.falcon_variant = falcon_variant
+        self.device = device
+        self.precision = precision
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        import gc
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        from pathlib import Path
+        from apps.language_models.utils import get_vmfb_from_path
+
+        self.falcon_vmfb_path = Path(
+            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
+        )
+        print("vmfb path for layer: ", self.falcon_vmfb_path)
+        self.model = get_vmfb_from_path(
+            self.falcon_vmfb_path,
+            self.device,
+            "linalg",
+            device_id=self.device_index,
+        )
+        if self.model is None:
+            raise ValueError("Layer vmfb not found")
+
+        hidden_states = hidden_states.to(torch.float32).detach().numpy()
+        attention_mask = attention_mask.detach().numpy()
+
+        if alibi is not None or layer_past is not None:
+            raise ValueError("Past key values and alibi should be None")
+        else:
+            output = self.model(
+                "forward",
+                (
+                    hidden_states,
+                    attention_mask,
+                ),
+            )
+        del self.model
+
+        if self.falcon_variant == "180b":
+            # Re-nest the flat vmfb outputs into the layout the pipeline
+            # expects: (hidden_states, (k0, v0), ..., (k39, v39)).
+            result = (torch.tensor(output[0]),) + tuple(
+                (torch.tensor(output[i]), torch.tensor(output[i + 1]))
+                for i in range(1, 81, 2)
+            )
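+            # Shard-count arithmetic for falcon-180b (80 decoder layers):
+            # with the default 4 shards each group wraps 20 layers
+            # (EightDecoderLayer); with --num-shards=2 each group wraps 40
+            # layers, so the classes in this file emit and consume
+            # 1 + 40 * 2 = 81 tensors per forward call.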
+        else:
+            raise ValueError(
+                f"Unsupported Falcon variant: {self.falcon_variant}"
+            )
+        return result
+
+
 class ShardedFalconModel:
     def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
         super().__init__()
diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py
index e6e43d13..c06b8978 100644
--- a/apps/language_models/src/pipelines/falcon_pipeline.py
+++ b/apps/language_models/src/pipelines/falcon_pipeline.py
@@ -8,8 +8,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
     CompiledLMHeadEmbeddingLayer,
     DecoderLayer,
     EightDecoderLayer,
+    EightDecoderLayer2,
     CompiledDecoderLayer,
     CompiledEightDecoderLayer,
+    CompiledEightDecoderLayer2,
     ShardedFalconModel,
 )
 from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -94,6 +96,12 @@ parser.add_argument(
     action=argparse.BooleanOptionalAction,
     help="Run model as sharded",
 )
+parser.add_argument(
+    "--num-shards",
+    type=int,
+    default=4,
+    help="Number of shards.",
+)
 
 
 class ShardedFalcon(SharkLLMBase):
@@ -306,7 +314,9 @@ class ShardedFalcon(SharkLLMBase):
             num_in_features = 14848
         sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
         if compressed:
-            num_group_layers = 20
+            num_group_layers = int(
+                20 * (4 / args.num_shards)
+            )  # 4 is the default number of shards
 
         sample_hidden_states = torch.zeros(
             [1, 100, num_in_features], dtype=torch.float32
@@ -326,7 +336,9 @@
             lm_head,
             [sample_hidden_states],
             "lm_head",
-            device_idx=0 % num_devices if self.device == "rocm" else None,
+            device_idx=(0 % num_devices) % args.num_shards
+            if self.device == "rocm"
+            else None,
         )
         shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
 
@@ -338,7 +350,9 @@
             word_embedding,
             [sample_input_ids],
             "word_embeddings",
-            device_idx=1 % num_devices if self.device == "rocm" else None,
+            device_idx=(1 % num_devices) % args.num_shards
+            if self.device == "rocm"
+            else None,
         )
         shark_word_embedding = CompiledWordEmbeddingsLayer(
             shark_word_embedding
@@ -350,7 +364,9 @@
             ln_f,
             [sample_hidden_states],
             "ln_f",
-            device_idx=2 % num_devices if self.device == "rocm" else None,
+            device_idx=(2 % num_devices) % args.num_shards
+            if self.device == "rocm"
+            else None,
         )
         shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
 
@@ -370,6 +386,9 @@
             )
             pytorch_class = EightDecoderLayer
             compiled_class = CompiledEightDecoderLayer
+            if args.num_shards == 2:
+                pytorch_class = EightDecoderLayer2
+                compiled_class = CompiledEightDecoderLayer2
             print("Compiling Layer {}".format(layer_id))
             if compressed:
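
Usage sketch (hedged: only --num-shards appears in this diff; the sharded
flag's exact name is not visible in the hunk above and is assumed from its
help text, "Run model as sharded"):

    python apps/language_models/src/pipelines/falcon_pipeline.py \
        --sharded --num-shards=2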