Fix Sharded Falcon-180b

This commit is contained in:
Vivek Khandelwal
2023-11-29 07:42:26 -08:00
parent 5c66948d4f
commit 396a054856
2 changed files with 10 additions and 6 deletions

View File

@@ -175,8 +175,9 @@ class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
- alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
+ alibi: torch.Tensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -191,7 +192,7 @@ class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
- attention_mask = attention_mask.detach().numpy()
+ attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
@@ -456,8 +457,9 @@ class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
- alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
+ alibi: torch.Tensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -472,7 +474,7 @@ class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
- attention_mask = attention_mask.detach().numpy()
+ attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")

View File

@@ -233,7 +233,7 @@ class ShardedFalcon(SharkLLMBase):
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif "_" in layer_id or type(layer_id) == int:
- f16_input_mask = [True, False]
+ f16_input_mask = [True, True]
else:
raise ValueError("Unsupported layer: ", layer_id)
@@ -298,7 +298,9 @@ class ShardedFalcon(SharkLLMBase):
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
- sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool)
+ sample_attention_mask = torch.zeros(
+     [1, 1, 100, 100], dtype=torch.float32
+ )
num_group_layers = int(
20 * (4 / args.num_shards)
) # 4 is the number of default shards