Add support for Falcon-40b-GPTQ

Vivek Khandelwal
2023-11-06 05:08:04 -08:00
parent 322874f7f9
commit 92b694db4d
2 changed files with 131 additions and 6 deletions


@@ -205,6 +205,57 @@ class EightDecoderLayer(torch.nn.Module):
                new_pkv70,
                new_pkv71,
            )
        elif self.falcon_variant == "40b":
            (
                (new_pkv00, new_pkv01),
                (new_pkv10, new_pkv11),
                (new_pkv20, new_pkv21),
                (new_pkv30, new_pkv31),
                (new_pkv40, new_pkv41),
                (new_pkv50, new_pkv51),
                (new_pkv60, new_pkv61),
                (new_pkv70, new_pkv71),
                (new_pkv80, new_pkv81),
                (new_pkv90, new_pkv91),
                (new_pkv100, new_pkv101),
                (new_pkv110, new_pkv111),
                (new_pkv120, new_pkv121),
                (new_pkv130, new_pkv131),
                (new_pkv140, new_pkv141),
            ) = new_pkvs
            result = (
                hidden_states,
                new_pkv00,
                new_pkv01,
                new_pkv10,
                new_pkv11,
                new_pkv20,
                new_pkv21,
                new_pkv30,
                new_pkv31,
                new_pkv40,
                new_pkv41,
                new_pkv50,
                new_pkv51,
                new_pkv60,
                new_pkv61,
                new_pkv70,
                new_pkv71,
                new_pkv80,
                new_pkv81,
                new_pkv90,
                new_pkv91,
                new_pkv100,
                new_pkv101,
                new_pkv110,
                new_pkv111,
                new_pkv120,
                new_pkv121,
                new_pkv130,
                new_pkv131,
                new_pkv140,
                new_pkv141,
            )
        elif self.falcon_variant == "180b":
            (
                (new_pkv00, new_pkv01),
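Note: the 40b branch added above follows the pattern of the existing 7b and 180b branches, flattening each layer group's nested (key, value) cache pairs into one flat result tuple; Falcon-40b's groups carry 15 pairs each. A minimal sketch of the same flattening written generically, assuming new_pkvs is a tuple of (key, value) tensor pairs (flatten_pkvs is a hypothetical helper, not part of this commit):

from typing import Tuple
import torch

def flatten_pkvs(
    hidden_states: torch.Tensor,
    new_pkvs: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
) -> Tuple[torch.Tensor, ...]:
    # One leading hidden_states tensor, then key/value tensors
    # interleaved in layer order: k0, v0, k1, v1, ...
    flat = [hidden_states]
    for key, value in new_pkvs:
        flat.extend([key, value])
    return tuple(flat)

Enumerating the names by hand, as the commit does, presumably keeps the flattened arity explicit for fixed-signature tracing.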
@@ -370,6 +421,70 @@ class CompiledEightDecoderLayer(torch.nn.Module):
                    torch.tensor(output[16]),
                ),
            )
        elif self.falcon_variant == "40b":
            result = (
                torch.tensor(output[0]),
                (
                    torch.tensor(output[1]),
                    torch.tensor(output[2]),
                ),
                (
                    torch.tensor(output[3]),
                    torch.tensor(output[4]),
                ),
                (
                    torch.tensor(output[5]),
                    torch.tensor(output[6]),
                ),
                (
                    torch.tensor(output[7]),
                    torch.tensor(output[8]),
                ),
                (
                    torch.tensor(output[9]),
                    torch.tensor(output[10]),
                ),
                (
                    torch.tensor(output[11]),
                    torch.tensor(output[12]),
                ),
                (
                    torch.tensor(output[13]),
                    torch.tensor(output[14]),
                ),
                (
                    torch.tensor(output[15]),
                    torch.tensor(output[16]),
                ),
                (
                    torch.tensor(output[17]),
                    torch.tensor(output[18]),
                ),
                (
                    torch.tensor(output[19]),
                    torch.tensor(output[20]),
                ),
                (
                    torch.tensor(output[21]),
                    torch.tensor(output[22]),
                ),
                (
                    torch.tensor(output[23]),
                    torch.tensor(output[24]),
                ),
                (
                    torch.tensor(output[25]),
                    torch.tensor(output[26]),
                ),
                (
                    torch.tensor(output[27]),
                    torch.tensor(output[28]),
                ),
                (
                    torch.tensor(output[29]),
                    torch.tensor(output[30]),
                ),
            )
        elif self.falcon_variant == "180b":
            result = (
                torch.tensor(output[0]),
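Note: CompiledEightDecoderLayer performs the inverse mapping, regrouping the compiled module's flat output into nested pairs, with output[0] as the hidden states and the remaining entries interleaved as key/value tensors. A generic sketch of that regrouping (unflatten_output is illustrative, not code from this commit):

import torch

def unflatten_output(output, num_pairs):
    # output[0] -> hidden_states; output[2i+1], output[2i+2] -> pair i
    hidden_states = torch.tensor(output[0])
    pkvs = tuple(
        (torch.tensor(output[2 * i + 1]), torch.tensor(output[2 * i + 2]))
        for i in range(num_pairs)
    )
    return (hidden_states, *pkvs)

The 40b branch above corresponds to unflatten_output(output, 15).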


@@ -112,11 +112,6 @@ class ShardedFalcon(SharkLLMBase):
        super().__init__(model_name, hf_model_path, max_num_tokens)
        print("hf_model_path: ", self.hf_model_path)
        if "40b" in self.model_name:
            raise NotImplementedError(
                "Sharded Falcon not supported for 40b variant"
            )
        if (
            "180b" in self.model_name
            and precision != "int4"
@@ -303,6 +298,10 @@ class ShardedFalcon(SharkLLMBase):
            num_in_features = 4544
            if compressed:
                num_group_layers = 8
        elif "40b" in self.model_name:
            num_in_features = 8192
            if compressed:
                num_group_layers = 15
        else:
            num_in_features = 14848
        sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
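Note: the branch added here wires in Falcon-40b's dimensions. For reference, the per-variant constants in this hunk (the dict is illustrative, not code from the commit; the values come from this diff, the "7b" association is inferred from the surrounding branch, and num_group_layers applies only when compressed):

FALCON_VARIANT_CONFIG = {
    "7b": {"num_in_features": 4544, "num_group_layers": 8},
    "40b": {"num_in_features": 8192, "num_group_layers": 15},
    "180b": {"num_in_features": 14848},
}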
@@ -677,7 +676,7 @@ class UnshardedFalcon(SharkLLMBase):
            quantization_config = GPTQConfig(bits=4, disable_exllama=True)
            kwargs["quantization_config"] = quantization_config
            kwargs["load_gptq_on_cpu"] = True
            kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
            kwargs["device_map"] = "cpu"
        falcon_model = AutoModelForCausalLM.from_pretrained(
            self.hf_model_path, **kwargs
        )
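Note: pinning device_map to "cpu" (replacing the earlier CUDA fallback) materializes the GPTQ checkpoint entirely on the host before compilation; exllama kernels are disabled since they require CUDA. A minimal stand-alone sketch of this loading path (the checkpoint id is an assumption, and SHARK's load_gptq_on_cpu kwarg is omitted since it is handled by SHARK's own loading code):

from transformers import AutoModelForCausalLM, GPTQConfig

# "TheBloke/falcon-40b-instruct-GPTQ" is an assumed checkpoint id,
# not one named in this commit.
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/falcon-40b-instruct-GPTQ",
    quantization_config=quantization_config,
    device_map="cpu",  # keep the dequantized weights on the host, as above
)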
@@ -935,7 +934,11 @@ class UnshardedFalcon(SharkLLMBase):
        all_text = prompt
        start = time.time()
        count = 0
        for i in range(self.max_num_tokens - 1):
            count = count + 1
            next_token = self.generate_new_token()
            new_word = self.tokenizer.decode(
                next_token.cpu().numpy(),
@@ -962,6 +965,13 @@ class UnshardedFalcon(SharkLLMBase):
            ):
                break
        end = time.time()
        print(
            "\n\nTime taken is {:.2f} seconds/token\n".format(
                (end - start) / count
            )
        )
        torch.cuda.empty_cache()
        gc.collect()
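Note: the added count makes the final print an average over the tokens actually generated, including early exits via break. Illustrative arithmetic only (the numbers are assumptions, not measurements from this commit):

# 100 tokens generated in 25.0 seconds of wall-clock time:
elapsed = 25.0  # seconds (assumed)
count = 100     # tokens generated (assumed)
print("Time taken is {:.2f} seconds/token".format(elapsed / count))  # 0.25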