diff --git a/apps/language_models/src/model_wrappers/falcon_sharded_model.py b/apps/language_models/src/model_wrappers/falcon_sharded_model.py index adb58c5c..b5a16c3d 100644 --- a/apps/language_models/src/model_wrappers/falcon_sharded_model.py +++ b/apps/language_models/src/model_wrappers/falcon_sharded_model.py @@ -69,91 +69,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module): return torch.tensor(new_hidden_states) -class DecoderLayer(torch.nn.Module): - def __init__(self, decoder_layer_model, falcon_variant): - super().__init__() - self.model = decoder_layer_model - - def forward(self, hidden_states, attention_mask): - output = self.model.forward( - hidden_states=hidden_states, - alibi=None, - attention_mask=attention_mask, - use_cache=True, - ) - return (output[0], output[1][0], output[1][1]) - - -class CompiledDecoderLayer(torch.nn.Module): - def __init__( - self, layer_id, device_idx, falcon_variant, device, precision - ): - super().__init__() - self.layer_id = layer_id - self.device_index = device_idx - self.falcon_variant = falcon_variant - self.device = device - self.precision = precision - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - alibi: torch.Tensor = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - import gc - - torch.cuda.empty_cache() - gc.collect() - from pathlib import Path - from apps.language_models.utils import get_vmfb_from_path - - self.falcon_vmfb_path = Path( - f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb" - ) - print("vmfb path for layer: ", self.falcon_vmfb_path) - self.model = get_vmfb_from_path( - self.falcon_vmfb_path, - self.device, - "linalg", - device_id=self.device_index, - ) - if self.model is None: - raise ValueError("Layer vmfb not found") - - hidden_states = hidden_states.to(torch.float32).detach().numpy() - attention_mask = attention_mask.detach().numpy() - - if alibi is not None or layer_past is not None: - raise ValueError("Past Key Values and alibi should be None") - else: - new_hidden_states, pkv1, pkv2 = self.model( - "forward", - ( - hidden_states, - attention_mask, - ), - ) - del self.model - - return tuple( - [ - torch.tensor(new_hidden_states), - tuple( - [ - torch.tensor(pkv1), - torch.tensor(pkv2), - ] - ), - ] - ) - - -class EightDecoderLayer(torch.nn.Module): +class FourWayShardingDecoderLayer(torch.nn.Module): def __init__(self, decoder_layer_model, falcon_variant): super().__init__() self.model = decoder_layer_model @@ -175,163 +91,78 @@ class EightDecoderLayer(torch.nn.Module): outputs[-1][1], ) ) - if self.falcon_variant == "7b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - ) - elif self.falcon_variant == "40b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, 
new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - ) - elif self.falcon_variant == "180b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - (new_pkv150, new_pkv151), - (new_pkv160, new_pkv161), - (new_pkv170, new_pkv171), - (new_pkv180, new_pkv181), - (new_pkv190, new_pkv191), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - new_pkv150, - new_pkv151, - new_pkv160, - new_pkv161, - new_pkv170, - new_pkv171, - new_pkv180, - new_pkv181, - new_pkv190, - new_pkv191, - ) - else: - raise ValueError( - "Unsupported Falcon variant: ", self.falcon_variant - ) + + ( + (new_pkv00, new_pkv01), + (new_pkv10, new_pkv11), + (new_pkv20, new_pkv21), + (new_pkv30, new_pkv31), + (new_pkv40, new_pkv41), + (new_pkv50, new_pkv51), + (new_pkv60, new_pkv61), + (new_pkv70, new_pkv71), + (new_pkv80, new_pkv81), + (new_pkv90, new_pkv91), + (new_pkv100, new_pkv101), + (new_pkv110, new_pkv111), + (new_pkv120, new_pkv121), + (new_pkv130, new_pkv131), + (new_pkv140, new_pkv141), + (new_pkv150, new_pkv151), + (new_pkv160, new_pkv161), + (new_pkv170, new_pkv171), + (new_pkv180, new_pkv181), + (new_pkv190, new_pkv191), + ) = new_pkvs + result = ( + hidden_states, + new_pkv00, + new_pkv01, + new_pkv10, + new_pkv11, + new_pkv20, + new_pkv21, + new_pkv30, + new_pkv31, + new_pkv40, + new_pkv41, + new_pkv50, + new_pkv51, + new_pkv60, + new_pkv61, + new_pkv70, + new_pkv71, + new_pkv80, + new_pkv81, + new_pkv90, + new_pkv91, + new_pkv100, + new_pkv101, + new_pkv110, + new_pkv111, + new_pkv120, + new_pkv121, + new_pkv130, + new_pkv131, + new_pkv140, + new_pkv141, + new_pkv150, + new_pkv151, + new_pkv160, + new_pkv161, + new_pkv170, + new_pkv171, + new_pkv180, + new_pkv181, + new_pkv190, + new_pkv191, + ) return result -class CompiledEightDecoderLayer(torch.nn.Module): +class CompiledFourWayShardingDecoderLayer(torch.nn.Module): def __init__( - self, layer_id, device_idx, falcon_variant, device, precision + self, layer_id, device_idx, falcon_variant, device, precision, model ): super().__init__() self.layer_id = layer_id @@ -339,6 +170,7 @@ class CompiledEightDecoderLayer(torch.nn.Module): self.falcon_variant = falcon_variant self.device = device self.precision = precision + self.model = model def forward( self, @@ -354,19 +186,7 @@ class CompiledEightDecoderLayer(torch.nn.Module): 
torch.cuda.empty_cache() gc.collect() - from pathlib import Path - from apps.language_models.utils import get_vmfb_from_path - self.falcon_vmfb_path = Path( - f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb" - ) - print("vmfb path for layer: ", self.falcon_vmfb_path) - self.model = get_vmfb_from_path( - self.falcon_vmfb_path, - self.device, - "linalg", - device_id=self.device_index, - ) if self.model is None: raise ValueError("Layer vmfb not found") @@ -383,200 +203,94 @@ class CompiledEightDecoderLayer(torch.nn.Module): attention_mask, ), ) - del self.model - if self.falcon_variant == "7b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ) - elif self.falcon_variant == "40b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ) - elif self.falcon_variant == "180b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ( - torch.tensor(output[31]), - torch.tensor(output[32]), - ), - ( - torch.tensor(output[33]), - torch.tensor(output[34]), - ), - ( - torch.tensor(output[35]), - torch.tensor(output[36]), - ), - ( - torch.tensor(output[37]), - torch.tensor(output[38]), - ), - ( - 
torch.tensor(output[39]), - torch.tensor(output[40]), - ), - ) - else: - raise ValueError( - "Unsupported Falcon variant: ", self.falcon_variant - ) + result = ( + torch.tensor(output[0]), + ( + torch.tensor(output[1]), + torch.tensor(output[2]), + ), + ( + torch.tensor(output[3]), + torch.tensor(output[4]), + ), + ( + torch.tensor(output[5]), + torch.tensor(output[6]), + ), + ( + torch.tensor(output[7]), + torch.tensor(output[8]), + ), + ( + torch.tensor(output[9]), + torch.tensor(output[10]), + ), + ( + torch.tensor(output[11]), + torch.tensor(output[12]), + ), + ( + torch.tensor(output[13]), + torch.tensor(output[14]), + ), + ( + torch.tensor(output[15]), + torch.tensor(output[16]), + ), + ( + torch.tensor(output[17]), + torch.tensor(output[18]), + ), + ( + torch.tensor(output[19]), + torch.tensor(output[20]), + ), + ( + torch.tensor(output[21]), + torch.tensor(output[22]), + ), + ( + torch.tensor(output[23]), + torch.tensor(output[24]), + ), + ( + torch.tensor(output[25]), + torch.tensor(output[26]), + ), + ( + torch.tensor(output[27]), + torch.tensor(output[28]), + ), + ( + torch.tensor(output[29]), + torch.tensor(output[30]), + ), + ( + torch.tensor(output[31]), + torch.tensor(output[32]), + ), + ( + torch.tensor(output[33]), + torch.tensor(output[34]), + ), + ( + torch.tensor(output[35]), + torch.tensor(output[36]), + ), + ( + torch.tensor(output[37]), + torch.tensor(output[38]), + ), + ( + torch.tensor(output[39]), + torch.tensor(output[40]), + ), + ) return result -class EightDecoderLayer2(torch.nn.Module): +class TwoWayShardingDecoderLayer(torch.nn.Module): def __init__(self, decoder_layer_model, falcon_variant): super().__init__() self.model = decoder_layer_model @@ -598,142 +312,138 @@ class EightDecoderLayer2(torch.nn.Module): outputs[-1][1], ) ) - if self.falcon_variant == "180b": - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - (new_pkv150, new_pkv151), - (new_pkv160, new_pkv161), - (new_pkv170, new_pkv171), - (new_pkv180, new_pkv181), - (new_pkv190, new_pkv191), - (new_pkv200, new_pkv201), - (new_pkv210, new_pkv211), - (new_pkv220, new_pkv221), - (new_pkv230, new_pkv231), - (new_pkv240, new_pkv241), - (new_pkv250, new_pkv251), - (new_pkv260, new_pkv261), - (new_pkv270, new_pkv271), - (new_pkv280, new_pkv281), - (new_pkv290, new_pkv291), - (new_pkv300, new_pkv301), - (new_pkv310, new_pkv311), - (new_pkv320, new_pkv321), - (new_pkv330, new_pkv331), - (new_pkv340, new_pkv341), - (new_pkv350, new_pkv351), - (new_pkv360, new_pkv361), - (new_pkv370, new_pkv371), - (new_pkv380, new_pkv381), - (new_pkv390, new_pkv391), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - new_pkv150, - new_pkv151, - new_pkv160, - new_pkv161, - new_pkv170, - new_pkv171, - new_pkv180, - new_pkv181, - new_pkv190, - new_pkv191, - new_pkv200, - new_pkv201, - 
new_pkv210, - new_pkv211, - new_pkv220, - new_pkv221, - new_pkv230, - new_pkv231, - new_pkv240, - new_pkv241, - new_pkv250, - new_pkv251, - new_pkv260, - new_pkv261, - new_pkv270, - new_pkv271, - new_pkv280, - new_pkv281, - new_pkv290, - new_pkv291, - new_pkv300, - new_pkv301, - new_pkv310, - new_pkv311, - new_pkv320, - new_pkv321, - new_pkv330, - new_pkv331, - new_pkv340, - new_pkv341, - new_pkv350, - new_pkv351, - new_pkv360, - new_pkv361, - new_pkv370, - new_pkv371, - new_pkv380, - new_pkv381, - new_pkv390, - new_pkv391, - ) - else: - raise ValueError( - "Unsupported Falcon variant: ", self.falcon_variant - ) + + ( + (new_pkv00, new_pkv01), + (new_pkv10, new_pkv11), + (new_pkv20, new_pkv21), + (new_pkv30, new_pkv31), + (new_pkv40, new_pkv41), + (new_pkv50, new_pkv51), + (new_pkv60, new_pkv61), + (new_pkv70, new_pkv71), + (new_pkv80, new_pkv81), + (new_pkv90, new_pkv91), + (new_pkv100, new_pkv101), + (new_pkv110, new_pkv111), + (new_pkv120, new_pkv121), + (new_pkv130, new_pkv131), + (new_pkv140, new_pkv141), + (new_pkv150, new_pkv151), + (new_pkv160, new_pkv161), + (new_pkv170, new_pkv171), + (new_pkv180, new_pkv181), + (new_pkv190, new_pkv191), + (new_pkv200, new_pkv201), + (new_pkv210, new_pkv211), + (new_pkv220, new_pkv221), + (new_pkv230, new_pkv231), + (new_pkv240, new_pkv241), + (new_pkv250, new_pkv251), + (new_pkv260, new_pkv261), + (new_pkv270, new_pkv271), + (new_pkv280, new_pkv281), + (new_pkv290, new_pkv291), + (new_pkv300, new_pkv301), + (new_pkv310, new_pkv311), + (new_pkv320, new_pkv321), + (new_pkv330, new_pkv331), + (new_pkv340, new_pkv341), + (new_pkv350, new_pkv351), + (new_pkv360, new_pkv361), + (new_pkv370, new_pkv371), + (new_pkv380, new_pkv381), + (new_pkv390, new_pkv391), + ) = new_pkvs + result = ( + hidden_states, + new_pkv00, + new_pkv01, + new_pkv10, + new_pkv11, + new_pkv20, + new_pkv21, + new_pkv30, + new_pkv31, + new_pkv40, + new_pkv41, + new_pkv50, + new_pkv51, + new_pkv60, + new_pkv61, + new_pkv70, + new_pkv71, + new_pkv80, + new_pkv81, + new_pkv90, + new_pkv91, + new_pkv100, + new_pkv101, + new_pkv110, + new_pkv111, + new_pkv120, + new_pkv121, + new_pkv130, + new_pkv131, + new_pkv140, + new_pkv141, + new_pkv150, + new_pkv151, + new_pkv160, + new_pkv161, + new_pkv170, + new_pkv171, + new_pkv180, + new_pkv181, + new_pkv190, + new_pkv191, + new_pkv200, + new_pkv201, + new_pkv210, + new_pkv211, + new_pkv220, + new_pkv221, + new_pkv230, + new_pkv231, + new_pkv240, + new_pkv241, + new_pkv250, + new_pkv251, + new_pkv260, + new_pkv261, + new_pkv270, + new_pkv271, + new_pkv280, + new_pkv281, + new_pkv290, + new_pkv291, + new_pkv300, + new_pkv301, + new_pkv310, + new_pkv311, + new_pkv320, + new_pkv321, + new_pkv330, + new_pkv331, + new_pkv340, + new_pkv341, + new_pkv350, + new_pkv351, + new_pkv360, + new_pkv361, + new_pkv370, + new_pkv371, + new_pkv380, + new_pkv381, + new_pkv390, + new_pkv391, + ) return result -class CompiledEightDecoderLayer2(torch.nn.Module): +class CompiledTwoWayShardingDecoderLayer(torch.nn.Module): def __init__( - self, layer_id, device_idx, falcon_variant, device, precision + self, layer_id, device_idx, falcon_variant, device, precision, model ): super().__init__() self.layer_id = layer_id @@ -741,6 +451,7 @@ class CompiledEightDecoderLayer2(torch.nn.Module): self.falcon_variant = falcon_variant self.device = device self.precision = precision + self.model = model def forward( self, @@ -756,19 +467,7 @@ class CompiledEightDecoderLayer2(torch.nn.Module): torch.cuda.empty_cache() gc.collect() - from pathlib import Path - from 
apps.language_models.utils import get_vmfb_from_path - self.falcon_vmfb_path = Path( - f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb" - ) - print("vmfb path for layer: ", self.falcon_vmfb_path) - self.model = get_vmfb_from_path( - self.falcon_vmfb_path, - self.device, - "linalg", - device_id=self.device_index, - ) if self.model is None: raise ValueError("Layer vmfb not found") @@ -785,176 +484,170 @@ class CompiledEightDecoderLayer2(torch.nn.Module): attention_mask, ), ) - del self.model - if self.falcon_variant == "180b": - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ( - torch.tensor(output[31]), - torch.tensor(output[32]), - ), - ( - torch.tensor(output[33]), - torch.tensor(output[34]), - ), - ( - torch.tensor(output[35]), - torch.tensor(output[36]), - ), - ( - torch.tensor(output[37]), - torch.tensor(output[38]), - ), - ( - torch.tensor(output[39]), - torch.tensor(output[40]), - ), - ( - torch.tensor(output[41]), - torch.tensor(output[42]), - ), - ( - torch.tensor(output[43]), - torch.tensor(output[44]), - ), - ( - torch.tensor(output[45]), - torch.tensor(output[46]), - ), - ( - torch.tensor(output[47]), - torch.tensor(output[48]), - ), - ( - torch.tensor(output[49]), - torch.tensor(output[50]), - ), - ( - torch.tensor(output[51]), - torch.tensor(output[52]), - ), - ( - torch.tensor(output[53]), - torch.tensor(output[54]), - ), - ( - torch.tensor(output[55]), - torch.tensor(output[56]), - ), - ( - torch.tensor(output[57]), - torch.tensor(output[58]), - ), - ( - torch.tensor(output[59]), - torch.tensor(output[60]), - ), - ( - torch.tensor(output[61]), - torch.tensor(output[62]), - ), - ( - torch.tensor(output[63]), - torch.tensor(output[64]), - ), - ( - torch.tensor(output[65]), - torch.tensor(output[66]), - ), - ( - torch.tensor(output[67]), - torch.tensor(output[68]), - ), - ( - torch.tensor(output[69]), - torch.tensor(output[70]), - ), - ( - torch.tensor(output[71]), - torch.tensor(output[72]), - ), - ( - torch.tensor(output[73]), - torch.tensor(output[74]), - ), - ( - torch.tensor(output[75]), - torch.tensor(output[76]), - ), - ( - torch.tensor(output[77]), - torch.tensor(output[78]), - ), - ( - torch.tensor(output[79]), - torch.tensor(output[80]), - ), - ) - else: - raise ValueError( - "Unsupported Falcon variant: ", self.falcon_variant - ) + result = ( + torch.tensor(output[0]), + ( + torch.tensor(output[1]), + torch.tensor(output[2]), + ), + ( + torch.tensor(output[3]), + torch.tensor(output[4]), + ), + ( + torch.tensor(output[5]), + torch.tensor(output[6]), + ), + ( + torch.tensor(output[7]), + 
torch.tensor(output[8]), + ), + ( + torch.tensor(output[9]), + torch.tensor(output[10]), + ), + ( + torch.tensor(output[11]), + torch.tensor(output[12]), + ), + ( + torch.tensor(output[13]), + torch.tensor(output[14]), + ), + ( + torch.tensor(output[15]), + torch.tensor(output[16]), + ), + ( + torch.tensor(output[17]), + torch.tensor(output[18]), + ), + ( + torch.tensor(output[19]), + torch.tensor(output[20]), + ), + ( + torch.tensor(output[21]), + torch.tensor(output[22]), + ), + ( + torch.tensor(output[23]), + torch.tensor(output[24]), + ), + ( + torch.tensor(output[25]), + torch.tensor(output[26]), + ), + ( + torch.tensor(output[27]), + torch.tensor(output[28]), + ), + ( + torch.tensor(output[29]), + torch.tensor(output[30]), + ), + ( + torch.tensor(output[31]), + torch.tensor(output[32]), + ), + ( + torch.tensor(output[33]), + torch.tensor(output[34]), + ), + ( + torch.tensor(output[35]), + torch.tensor(output[36]), + ), + ( + torch.tensor(output[37]), + torch.tensor(output[38]), + ), + ( + torch.tensor(output[39]), + torch.tensor(output[40]), + ), + ( + torch.tensor(output[41]), + torch.tensor(output[42]), + ), + ( + torch.tensor(output[43]), + torch.tensor(output[44]), + ), + ( + torch.tensor(output[45]), + torch.tensor(output[46]), + ), + ( + torch.tensor(output[47]), + torch.tensor(output[48]), + ), + ( + torch.tensor(output[49]), + torch.tensor(output[50]), + ), + ( + torch.tensor(output[51]), + torch.tensor(output[52]), + ), + ( + torch.tensor(output[53]), + torch.tensor(output[54]), + ), + ( + torch.tensor(output[55]), + torch.tensor(output[56]), + ), + ( + torch.tensor(output[57]), + torch.tensor(output[58]), + ), + ( + torch.tensor(output[59]), + torch.tensor(output[60]), + ), + ( + torch.tensor(output[61]), + torch.tensor(output[62]), + ), + ( + torch.tensor(output[63]), + torch.tensor(output[64]), + ), + ( + torch.tensor(output[65]), + torch.tensor(output[66]), + ), + ( + torch.tensor(output[67]), + torch.tensor(output[68]), + ), + ( + torch.tensor(output[69]), + torch.tensor(output[70]), + ), + ( + torch.tensor(output[71]), + torch.tensor(output[72]), + ), + ( + torch.tensor(output[73]), + torch.tensor(output[74]), + ), + ( + torch.tensor(output[75]), + torch.tensor(output[76]), + ), + ( + torch.tensor(output[77]), + torch.tensor(output[78]), + ), + ( + torch.tensor(output[79]), + torch.tensor(output[80]), + ), + ) return result diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py index c06b8978..586f822b 100644 --- a/apps/language_models/src/pipelines/falcon_pipeline.py +++ b/apps/language_models/src/pipelines/falcon_pipeline.py @@ -6,12 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import ( CompiledLNFEmbeddingLayer, LMHeadEmbeddingLayer, CompiledLMHeadEmbeddingLayer, - DecoderLayer, - EightDecoderLayer, - EightDecoderLayer2, - CompiledDecoderLayer, - CompiledEightDecoderLayer, - CompiledEightDecoderLayer2, + FourWayShardingDecoderLayer, + TwoWayShardingDecoderLayer, + CompiledFourWayShardingDecoderLayer, + CompiledTwoWayShardingDecoderLayer, ShardedFalconModel, ) from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase @@ -97,9 +95,10 @@ parser.add_argument( help="Run model as sharded", ) parser.add_argument( - "--num-shards", + "--num_shards", type=int, default=4, + choices=[2, 4], help="Number of shards.", ) @@ -130,6 +129,10 @@ class ShardedFalcon(SharkLLMBase): --hf_auth_token flag. 
You can ask for the access to the model here: https://huggingface.co/tiiuae/falcon-180B-chat.""" ) + + if args.sharded and "180b" not in self.model_name: + raise ValueError("Sharding supported only for Falcon-180B") + self.hf_auth_token = hf_auth_token self.max_padding_length = 100 self.device = device @@ -139,7 +142,7 @@ class ShardedFalcon(SharkLLMBase): self.debug = debug self.tokenizer = self.get_tokenizer() self.src_model = self.get_src_model() - self.shark_model = self.compile(compressed=args.compressed) + self.shark_model = self.compile() def get_tokenizer(self): tokenizer = AutoTokenizer.from_pretrained( @@ -154,20 +157,17 @@ class ShardedFalcon(SharkLLMBase): def get_src_model(self): print("Loading src model: ", self.model_name) kwargs = { - "torch_dtype": torch.float, + "torch_dtype": torch.float32, "trust_remote_code": True, "token": self.hf_auth_token, } if self.precision == "int4": quantization_config = GPTQConfig(bits=4, disable_exllama=True) kwargs["quantization_config"] = quantization_config - kwargs["load_gptq_on_cpu"] = True kwargs["device_map"] = "cpu" falcon_model = AutoModelForCausalLM.from_pretrained( self.hf_model_path, **kwargs ) - if self.precision == "int4": - falcon_model = falcon_model.to(torch.float32) return falcon_model def compile_layer( @@ -296,30 +296,14 @@ class ShardedFalcon(SharkLLMBase): return shark_module, device_idx - def compile(self, compressed=False): + def compile(self): sample_input_ids = torch.zeros([100], dtype=torch.int64) - sample_attention_mask = torch.zeros( - [1, 1, 100, 100], dtype=torch.float32 - ) - num_group_layers = 1 - if "7b" in self.model_name: - num_in_features = 4544 - if compressed: - num_group_layers = 8 - elif "40b" in self.model_name: - num_in_features = 8192 - if compressed: - num_group_layers = 15 - else: - num_in_features = 14848 - sample_attention_mask = sample_attention_mask.to(dtype=torch.bool) - if compressed: - num_group_layers = int( - 20 * (4 / args.num_shards) - ) # 4 is the number of default shards - + sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool) + num_group_layers = int( + 20 * (4 / args.num_shards) + ) # 4 is the number of default shards sample_hidden_states = torch.zeros( - [1, 100, num_in_features], dtype=torch.float32 + [1, 100, 14848], dtype=torch.float32 ) # Determine number of available devices @@ -329,6 +313,10 @@ class ShardedFalcon(SharkLLMBase): haldriver = ireert.get_driver(self.device) num_devices = len(haldriver.query_available_devices()) + if num_devices < 2: + raise ValueError( + "Cannot run Falcon-180B on a single ROCM device." 
+ ) lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head) print("Compiling Layer lm_head") @@ -376,27 +364,21 @@ class ShardedFalcon(SharkLLMBase): ): device_idx = i % num_devices if self.device == "rocm" else None layer_id = i - pytorch_class = DecoderLayer - compiled_class = CompiledDecoderLayer - if compressed: - layer_id = ( - str(i * num_group_layers) - + "_" - + str((i + 1) * num_group_layers) - ) - pytorch_class = EightDecoderLayer - compiled_class = CompiledEightDecoderLayer - if args.num_shards == 2: - pytorch_class = EightDecoderLayer2 - compiled_class = CompiledEightDecoderLayer2 + layer_id = ( + str(i * num_group_layers) + + "_" + + str((i + 1) * num_group_layers) + ) + pytorch_class = FourWayShardingDecoderLayer + compiled_class = CompiledFourWayShardingDecoderLayer + if args.num_shards == 2: + pytorch_class = TwoWayShardingDecoderLayer + compiled_class = CompiledTwoWayShardingDecoderLayer print("Compiling Layer {}".format(layer_id)) - if compressed: - layer_i = self.src_model.transformer.h[ - i * num_group_layers : (i + 1) * num_group_layers - ] - else: - layer_i = self.src_model.transformer.h[i] + layer_i = self.src_model.transformer.h[ + i * num_group_layers : (i + 1) * num_group_layers + ] pytorch_layer_i = pytorch_class( layer_i, args.falcon_variant_to_use @@ -407,13 +389,13 @@ class ShardedFalcon(SharkLLMBase): layer_id, device_idx=device_idx, ) - del shark_module shark_layer_i = compiled_class( layer_id, device_idx, args.falcon_variant_to_use, self.device, self.precision, + shark_module, ) shark_layers.append(shark_layer_i) diff --git a/requirements.txt b/requirements.txt index f43adae8..76498241 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,4 +50,8 @@ pefile pyinstaller # vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea \ No newline at end of file +brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea + +# For quantized GPTQ models +optimum +auto_gptq
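
For reference, a minimal sketch of the layer-grouping arithmetic the refactored pipeline relies on (not part of the patch; the helper name group_bounds and the hard-coded total of 80 Falcon-180B decoder layers are illustrative assumptions). With the default of 4 shards each compiled module wraps 20 consecutive decoder layers, and with --num_shards 2 it wraps 40; the compiled shark_module is then passed once into CompiledFourWayShardingDecoderLayer / CompiledTwoWayShardingDecoderLayer instead of being reloaded from a vmfb on every forward call.

def group_bounds(num_shards: int, total_layers: int = 80):
    # 4 is the default number of shards in the pipeline above.
    num_group_layers = int(20 * (4 / num_shards))
    for i in range(total_layers // num_group_layers):
        # Matches the layer_id string "{start}_{end}" built in compile().
        yield i * num_group_layers, (i + 1) * num_group_layers

print(list(group_bounds(4)))  # [(0, 20), (20, 40), (40, 60), (60, 80)]
print(list(group_bounds(2)))  # [(0, 40), (40, 80)]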