Remove sharding support for non-180B falcon variants

Vivek Khandelwal
2023-11-14 03:55:11 -08:00
parent ca58908e5b
commit 666e601dd9
3 changed files with 487 additions and 808 deletions

File diff suppressed because it is too large.


@@ -6,12 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
     CompiledLNFEmbeddingLayer,
     LMHeadEmbeddingLayer,
     CompiledLMHeadEmbeddingLayer,
-    DecoderLayer,
-    EightDecoderLayer,
-    EightDecoderLayer2,
-    CompiledDecoderLayer,
-    CompiledEightDecoderLayer,
-    CompiledEightDecoderLayer2,
+    FourWayShardingDecoderLayer,
+    TwoWayShardingDecoderLayer,
+    CompiledFourWayShardingDecoderLayer,
+    CompiledTwoWayShardingDecoderLayer,
     ShardedFalconModel,
 )
 from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -97,9 +95,10 @@ parser.add_argument(
     help="Run model as sharded",
 )
 parser.add_argument(
-    "--num-shards",
+    "--num_shards",
     type=int,
     default=4,
+    choices=[2, 4],
     help="Number of shards.",
 )
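With the flag renamed to --num_shards and constrained by choices=[2, 4], argparse now rejects unsupported shard counts before any model work begins. A minimal standalone sketch of that behavior (illustrative only, not pipeline code):

import argparse

# Reproduce the updated flag: values outside the supported shard
# counts are rejected by argparse itself.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_shards",
    type=int,
    default=4,
    choices=[2, 4],
    help="Number of shards.",
)

print(parser.parse_args(["--num_shards", "2"]).num_shards)  # -> 2
# parser.parse_args(["--num_shards", "3"]) exits with:
#   error: argument --num_shards: invalid choice: 3 (choose from 2, 4)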
@@ -130,6 +129,10 @@ class ShardedFalcon(SharkLLMBase):
                 --hf_auth_token flag. You can ask for the access to the model
                 here: https://huggingface.co/tiiuae/falcon-180B-chat."""
             )
+        if args.sharded and "180b" not in self.model_name:
+            raise ValueError(
+                "Sharding supported only for Falcon-180B"
+            )
         self.hf_auth_token = hf_auth_token
         self.max_padding_length = 100
         self.device = device
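The new guard makes the 180B-only policy explicit: with --sharded set, any model name that does not contain "180b" fails fast in the constructor. A tiny sketch of the check in isolation (the function name is illustrative):

def check_sharding_supported(model_name: str, sharded: bool) -> None:
    # Mirrors the guard added above: sharding is only wired up for
    # Falcon-180B, so reject every other variant early.
    if sharded and "180b" not in model_name:
        raise ValueError("Sharding supported only for Falcon-180B")

check_sharding_supported("falcon-180b-chat", sharded=True)  # passes
# check_sharding_supported("falcon-7b", sharded=True) raises ValueError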
@@ -139,7 +142,7 @@ class ShardedFalcon(SharkLLMBase):
         self.debug = debug
         self.tokenizer = self.get_tokenizer()
         self.src_model = self.get_src_model()
-        self.shark_model = self.compile(compressed=args.compressed)
+        self.shark_model = self.compile()

     def get_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(
@@ -154,20 +157,17 @@
     def get_src_model(self):
         print("Loading src model: ", self.model_name)
         kwargs = {
-            "torch_dtype": torch.float,
+            "torch_dtype": torch.float32,
             "trust_remote_code": True,
             "token": self.hf_auth_token,
         }
         if self.precision == "int4":
             quantization_config = GPTQConfig(bits=4, disable_exllama=True)
             kwargs["quantization_config"] = quantization_config
             kwargs["load_gptq_on_cpu"] = True
             kwargs["device_map"] = "cpu"
         falcon_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path, **kwargs
         )
-        if self.precision == "int4":
-            falcon_model = falcon_model.to(torch.float32)
         return falcon_model

     def compile_layer(
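In the int4 branch the model is now loaded via transformers' GPTQ integration and used as-is, instead of being re-cast to float32 afterwards. A hedged sketch of that load path, assuming an already-GPTQ-quantized Falcon checkpoint (the repo id and token below are placeholders, not values from this codebase):

import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# GPTQConfig here controls how the quantized checkpoint is loaded;
# disable_exllama avoids the GPU-only ExLlama kernels, matching
# device_map="cpu" above.
kwargs = {
    "torch_dtype": torch.float32,
    "trust_remote_code": True,
    "token": "hf_xxx",  # placeholder HF auth token
    "quantization_config": GPTQConfig(bits=4, disable_exllama=True),
    "device_map": "cpu",
}
falcon_model = AutoModelForCausalLM.from_pretrained(
    "someorg/falcon-180B-chat-gptq",  # placeholder quantized repo id
    **kwargs,
)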
@@ -296,30 +296,14 @@
         return shark_module, device_idx

-    def compile(self, compressed=False):
+    def compile(self):
         sample_input_ids = torch.zeros([100], dtype=torch.int64)
-        sample_attention_mask = torch.zeros(
-            [1, 1, 100, 100], dtype=torch.float32
-        )
-        num_group_layers = 1
-        if "7b" in self.model_name:
-            num_in_features = 4544
-            if compressed:
-                num_group_layers = 8
-        elif "40b" in self.model_name:
-            num_in_features = 8192
-            if compressed:
-                num_group_layers = 15
-        else:
-            num_in_features = 14848
-            sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
-            if compressed:
-                num_group_layers = int(
-                    20 * (4 / args.num_shards)
-                )  # 4 is the number of default shards
+        sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool)
+        num_group_layers = int(
+            20 * (4 / args.num_shards)
+        )  # 4 is the number of default shards
         sample_hidden_states = torch.zeros(
-            [1, 100, num_in_features], dtype=torch.float32
+            [1, 100, 14848], dtype=torch.float32
         )

         # Determine number of available devices
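The simplified compile() hard-codes Falcon-180B's shapes: hidden size 14848 and, implicitly, 80 decoder layers, which the formula 20 * (4 / num_shards) splits into equally sized groups. A quick sketch of the arithmetic (standalone, assuming Falcon-180B's 80-layer depth):

NUM_LAYERS = 80  # Falcon-180B decoder depth

for num_shards in (2, 4):
    num_group_layers = int(20 * (4 / num_shards))
    num_groups = NUM_LAYERS // num_group_layers
    print(num_shards, num_group_layers, num_groups)
    # 4 shards -> 20 layers per group, 4 compiled groups
    # 2 shards -> 40 layers per group, 2 compiled groups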
@@ -329,6 +313,10 @@
         haldriver = ireert.get_driver(self.device)
         num_devices = len(haldriver.query_available_devices())
+        if num_devices < 2:
+            raise ValueError(
+                "Cannot run Falcon-180B on a single ROCM device."
+            )

         lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
         print("Compiling Layer lm_head")
@@ -376,27 +364,21 @@
         ):
             device_idx = i % num_devices if self.device == "rocm" else None
-            layer_id = i
-            pytorch_class = DecoderLayer
-            compiled_class = CompiledDecoderLayer
-            if compressed:
-                layer_id = (
-                    str(i * num_group_layers)
-                    + "_"
-                    + str((i + 1) * num_group_layers)
-                )
-                pytorch_class = EightDecoderLayer
-                compiled_class = CompiledEightDecoderLayer
-                if args.num_shards == 2:
-                    pytorch_class = EightDecoderLayer2
-                    compiled_class = CompiledEightDecoderLayer2
+            layer_id = (
+                str(i * num_group_layers)
+                + "_"
+                + str((i + 1) * num_group_layers)
+            )
+            pytorch_class = FourWayShardingDecoderLayer
+            compiled_class = CompiledFourWayShardingDecoderLayer
+            if args.num_shards == 2:
+                pytorch_class = TwoWayShardingDecoderLayer
+                compiled_class = CompiledTwoWayShardingDecoderLayer
             print("Compiling Layer {}".format(layer_id))
-            if compressed:
-                layer_i = self.src_model.transformer.h[
-                    i * num_group_layers : (i + 1) * num_group_layers
-                ]
-            else:
-                layer_i = self.src_model.transformer.h[i]
+            layer_i = self.src_model.transformer.h[
+                i * num_group_layers : (i + 1) * num_group_layers
+            ]

             pytorch_layer_i = pytorch_class(
                 layer_i, args.falcon_variant_to_use
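With the per-layer path gone, every iteration slices a contiguous group of decoder layers and assigns it to a device round-robin. A sketch of that loop in isolation (plain integers stand in for self.src_model.transformer.h):

layers = list(range(80))   # stand-in for self.src_model.transformer.h
num_group_layers = 20      # 4-way sharding on Falcon-180B
num_devices = 4

for i in range(len(layers) // num_group_layers):
    layer_id = f"{i * num_group_layers}_{(i + 1) * num_group_layers}"
    group = layers[i * num_group_layers : (i + 1) * num_group_layers]
    device_idx = i % num_devices  # round-robin device assignment
    print(f"group {layer_id} -> device {device_idx} ({len(group)} layers)")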
@@ -407,13 +389,13 @@
                 layer_id,
                 device_idx=device_idx,
             )
-            del shark_module
             shark_layer_i = compiled_class(
                 layer_id,
                 device_idx,
                 args.falcon_variant_to_use,
                 self.device,
                 self.precision,
+                shark_module,
             )
             shark_layers.append(shark_layer_i)