Add Falcon-GPTQ Support for 2-way sharding

Vivek Khandelwal
2023-11-10 09:03:01 -08:00
parent 1f5b39f56e
commit ca58908e5b
2 changed files with 405 additions and 4 deletions


@@ -576,6 +576,388 @@ class CompiledEightDecoderLayer(torch.nn.Module):
        return result
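# EightDecoderLayer2 wraps one half of the decoder stack for 2-way sharding:
# 40 layers per group for falcon-180b, versus 20 per group in the existing
# compressed 4-way EightDecoderLayer path. Its forward pass returns the
# hidden states followed by every layer's (key, value) cache, flattened
# (presumably so the whole group can be exported as one vmfb with flat
# tensor outputs).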
class EightDecoderLayer2(torch.nn.Module):
    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model
        self.falcon_variant = falcon_variant

    def forward(self, hidden_states, attention_mask):
        new_pkvs = []
        for layer in self.model:
            outputs = layer(
                hidden_states=hidden_states,
                alibi=None,
                attention_mask=attention_mask,
                use_cache=True,
            )
            hidden_states = outputs[0]
            # outputs[-1] is this layer's (key, value) cache pair.
            new_pkvs.append((outputs[-1][0], outputs[-1][1]))
        if self.falcon_variant == "180b":
            # Flatten the 40 (key, value) pairs into one tuple in layer
            # order: (hidden_states, k0, v0, ..., k39, v39). Behaviorally
            # identical to the hand-written 40-way unpacking it replaces.
            result = (hidden_states,) + tuple(
                pkv_tensor for pkv in new_pkvs for pkv_tensor in pkv
            )
        else:
            raise ValueError(
                f"Unsupported Falcon variant: {self.falcon_variant}"
            )
        return result
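# CompiledEightDecoderLayer2 is the runtime counterpart of the class above:
# instead of holding weights, it loads the group's pre-compiled vmfb on
# demand, runs it, and frees it again, so only one group's module needs to
# be resident at a time.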
class CompiledEightDecoderLayer2(torch.nn.Module):
    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision
    ):
        super().__init__()
        self.layer_id = layer_id
        self.device_index = device_idx
        self.falcon_variant = falcon_variant
        self.device = device
        self.precision = precision
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        import gc
        from pathlib import Path

        from apps.language_models.utils import get_vmfb_from_path

        # Release cached device memory before loading this group's vmfb.
        torch.cuda.empty_cache()
        gc.collect()

        self.falcon_vmfb_path = Path(
            f"falcon_{self.falcon_variant}_layer_{self.layer_id}"
            f"_{self.precision}_{self.device}.vmfb"
        )
        print("vmfb path for layer: ", self.falcon_vmfb_path)
        self.model = get_vmfb_from_path(
            self.falcon_vmfb_path,
            self.device,
            "linalg",
            device_id=self.device_index,
        )
        if self.model is None:
            raise ValueError("Layer vmfb not found")

        hidden_states = hidden_states.to(torch.float32).detach().numpy()
        attention_mask = attention_mask.detach().numpy()
        if alibi is not None or layer_past is not None:
            raise ValueError("Past Key Values and alibi should be None")
        output = self.model(
            "forward",
            (
                hidden_states,
                attention_mask,
            ),
        )
        del self.model
        if self.falcon_variant == "180b":
            # Repack the 81 flat vmfb outputs: output[0] is the hidden
            # states, followed by 40 (key, value) pairs in layer order.
            # Equivalent to the hand-written indexing it replaces.
            result = (torch.tensor(output[0]),) + tuple(
                (torch.tensor(output[i]), torch.tensor(output[i + 1]))
                for i in range(1, 81, 2)
            )
        else:
            raise ValueError(
                f"Unsupported Falcon variant: {self.falcon_variant}"
            )
        return result
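# Hypothetical usage sketch (argument values are illustrative, not taken
# from this commit): the pipeline runs each compiled group in sequence.
#
#   layer = CompiledEightDecoderLayer2(
#       layer_id=0, device_idx=0, falcon_variant="180b",
#       device="rocm", precision="int4",
#   )
#   out = layer(hidden_states, attention_mask)
#   hidden_states, pkv_pairs = out[0], out[1:]
#
# Each call loads falcon_180b_layer_0_int4_rocm.vmfb, executes it, and
# deletes the module before the next group is loaded.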
class ShardedFalconModel:
    def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
        super().__init__()

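The pipeline changes below wire the new classes in. As a sanity check of the arithmetic, assuming falcon-180b has 80 decoder layers (inferred from the 40 key/value pairs handled above, not stated in this commit):

    # Worked arithmetic for 2-way sharding of falcon-180b.
    default_shards = 4
    num_shards = 2  # from the new --num-shards flag
    num_group_layers = int(20 * (default_shards / num_shards))  # 40 per group
    flat_outputs = 1 + 2 * num_group_layers  # hidden_states plus 40 (k, v) pairs
    assert flat_outputs == 81  # matches output[0] .. output[80] above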

@@ -8,8 +8,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
     CompiledLMHeadEmbeddingLayer,
     DecoderLayer,
     EightDecoderLayer,
+    EightDecoderLayer2,
     CompiledDecoderLayer,
     CompiledEightDecoderLayer,
+    CompiledEightDecoderLayer2,
     ShardedFalconModel,
 )
 from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -94,6 +96,12 @@ parser.add_argument(
     action=argparse.BooleanOptionalAction,
     help="Run model as sharded",
 )
+parser.add_argument(
+    "--num-shards",
+    type=int,
+    default=4,
+    help="Number of shards.",
+)
class ShardedFalcon(SharkLLMBase):
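The remaining hunks scale the compressed layer-group size with the new flag and keep the auxiliary layers' device indices within the shard count.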
@@ -306,7 +314,9 @@ class ShardedFalcon(SharkLLMBase):
         num_in_features = 14848
         sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
         if compressed:
-            num_group_layers = 20
+            num_group_layers = int(
+                20 * (4 / args.num_shards)
+            )  # 4 is the default number of shards
             sample_hidden_states = torch.zeros(
                 [1, 100, num_in_features], dtype=torch.float32
@@ -326,7 +336,9 @@ class ShardedFalcon(SharkLLMBase):
             lm_head,
             [sample_hidden_states],
             "lm_head",
-            device_idx=0 % num_devices if self.device == "rocm" else None,
+            device_idx=(0 % num_devices) % args.num_shards
+            if self.device == "rocm"
+            else None,
         )
         shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -338,7 +350,9 @@ class ShardedFalcon(SharkLLMBase):
             word_embedding,
             [sample_input_ids],
             "word_embeddings",
-            device_idx=1 % num_devices if self.device == "rocm" else None,
+            device_idx=(1 % num_devices) % args.num_shards
+            if self.device == "rocm"
+            else None,
         )
         shark_word_embedding = CompiledWordEmbeddingsLayer(
             shark_word_embedding
@@ -350,7 +364,9 @@ class ShardedFalcon(SharkLLMBase):
             ln_f,
             [sample_hidden_states],
             "ln_f",
-            device_idx=2 % num_devices if self.device == "rocm" else None,
+            device_idx=(2 % num_devices) % args.num_shards
+            if self.device == "rocm"
+            else None,
         )
         shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
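Wrapping each device index modulo args.num_shards pins the lm_head, word-embedding, and ln_f shards to the first args.num_shards devices, so a 2-way run never touches devices beyond index 1.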
@@ -370,6 +386,9 @@ class ShardedFalcon(SharkLLMBase):
             )
             pytorch_class = EightDecoderLayer
             compiled_class = CompiledEightDecoderLayer
+            if args.num_shards == 2:
+                pytorch_class = EightDecoderLayer2
+                compiled_class = CompiledEightDecoderLayer2
             print("Compiling Layer {}".format(layer_id))
             if compressed: