mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fix Sharded Falcon-180b
This commit is contained in:
@@ -175,8 +175,9 @@ class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
@@ -191,7 +192,7 @@ class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
@@ -456,8 +457,9 @@ class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
@@ -472,7 +474,7 @@ class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
|
||||
@@ -233,7 +233,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
elif layer_id in ["ln_f", "lm_head"]:
|
||||
f16_input_mask = [True]
|
||||
elif "_" in layer_id or type(layer_id) == int:
|
||||
f16_input_mask = [True, False]
|
||||
f16_input_mask = [True, True]
|
||||
else:
|
||||
raise ValueError("Unsupported layer: ", layer_id)
|
||||
|
||||
@@ -298,7 +298,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
def compile(self):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool)
|
||||
sample_attention_mask = torch.zeros(
|
||||
[1, 1, 100, 100], dtype=torch.float32
|
||||
)
|
||||
num_group_layers = int(
|
||||
20 * (4 / args.num_shards)
|
||||
) # 4 is the number of default shards
|
||||
|
||||
Reference in New Issue
Block a user