mmap shards + disable sharing of device arrays across devices
@@ -827,7 +827,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
@@ -915,7 +915,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
@@ -996,7 +996,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
@@ -1201,7 +1201,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             device_idx=device_idx,
             mlir_dialect="tm_tensor",
-            mmap=False,
+            mmap=True,
         )
         module.load_module(vmfb_path)
     else:
@@ -1219,7 +1219,7 @@ class ShardedVicuna(VicunaBase):
             device=device,
             device_idx=device_idx,
             mlir_dialect="tm_tensor",
-            mmap=False,
+            mmap=True,
         )
         module.save_module(
             module_name=f"{self.dir_name}/{idx}_full",
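
The five hunks above make the same one-line change at each SharkInference construction site: each shard's compiled .vmfb is now memory-mapped (mmap=True) instead of being read wholly into process memory. A minimal sketch of the difference in terms of iree.runtime's public API, not the SharkInference wrapper itself (the file name is hypothetical):

    import iree.runtime as ireert

    config = ireert.Config("local-task")

    # mmap=False path: the whole .vmfb is copied into process memory.
    with open("shard_0.vmfb", "rb") as f:
        module_copy = ireert.VmModule.from_buffer(config.vm_instance, f.read())

    # mmap=True path: pages are faulted in on demand, so a many-shard model
    # does not pin one full module copy in RAM per shard.
    module_mmap = ireert.VmModule.mmap(config.vm_instance, "shard_0.vmfb")
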
@@ -113,7 +113,6 @@ class LMHeadCompiled(torch.nn.Module):
     def forward(self, hidden_states):
         hidden_states_sample = hidden_states.detach()
-

         output = self.model("forward", (hidden_states,))
         output = torch.tensor(output)
@@ -136,7 +135,6 @@ class VicunaNormCompiled(torch.nn.Module):
         self.model = shark_module

     def forward(self, hidden_states):
-
         try:
             hidden_states.detach()
         except:
@@ -163,7 +161,6 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
         self.model = shark_module

     def forward(self, input_ids):
-
         input_ids.detach()
         output = self.model("forward", (input_ids,), send_to_host=True)
         output = torch.tensor(output)
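
The three classes above (LMHeadCompiled, VicunaNormCompiled, VicunaEmbeddingCompiled) all follow the same wrapper pattern: a compiled SHARK module hidden behind torch.nn.Module so it drops into the sharded pipeline. A simplified sketch of that pattern with illustrative names, assuming self.model is a SharkInference-style callable:

    import torch

    class CompiledWrapper(torch.nn.Module):
        def __init__(self, shark_module):
            super().__init__()
            self.model = shark_module  # compiled SHARK module

        def forward(self, input_ids):
            input_ids = input_ids.detach()  # shed autograd state before IREE
            # send_to_host=True copies the result back from the device,
            output = self.model("forward", (input_ids,), send_to_host=True)
            return torch.tensor(output)  # so torch owns a fresh host copy
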
@@ -185,11 +182,10 @@ class CompiledVicunaLayer(torch.nn.Module):
         output_attentions=False,
         use_cache=True,
     ):
-
         if past_key_value is None:
-            #hidden_states = hidden_states.detach()
-            #attention_mask = attention_mask.detach()
-            #position_ids = position_ids.detach()
+            # hidden_states = hidden_states.detach()
+            # attention_mask = attention_mask.detach()
+            # position_ids = position_ids.detach()

             output = self.model(
                 "first_vicuna_forward",
@@ -198,17 +194,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     attention_mask,
                     position_ids,
                 ),
-                send_to_host=False,
+                send_to_host=True,
             )

-            #output0 = torch.tensor(output[0])
-            #output1 = torch.tensor(output[1])
-            #output2 = torch.tensor(output[2])
-            output0 = output[0]
-            output1 = output[1]
-            output2 = output[2]
+            ### send_to_host=True
+            output0 = torch.tensor(output[0])
+            output1 = torch.tensor(output[1])
+            output2 = torch.tensor(output[2])
+            ### send_to_host=False
+            # output0 = output[0]
+            # output1 = output[1]
+            # output2 = output[2]

             return (
                 output0,
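
This hunk is the second half of the commit title. With send_to_host=True the layer outputs come back as host arrays and are rewrapped with torch.tensor, so each shard hands downstream consumers a fresh host copy; the old send_to_host=False path returned device-resident arrays that were shared across devices. Schematically (variable names illustrative):

    # new: host copies, safe to pass between shard devices
    output = self.model("first_vicuna_forward", inputs, send_to_host=True)
    hidden = torch.tensor(output[0])

    # old: device arrays shared across devices, now disabled
    # output = self.model("first_vicuna_forward", inputs, send_to_host=False)
    # hidden = output[0]
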
@@ -218,10 +214,10 @@ class CompiledVicunaLayer(torch.nn.Module):
                 ),
             )
         else:
-            #hidden_states = hidden_states.detach()
-            #attention_mask = attention_mask.detach()
-            #position_ids = position_ids.detach()
-            #pkv0 = past_key_value[0].detach()
+            # hidden_states = hidden_states.detach()
+            # attention_mask = attention_mask.detach()
+            # position_ids = position_ids.detach()
+            # pkv0 = past_key_value[0].detach()
             pkv0 = past_key_value[0]
             pkv1 = past_key_value[1]
             output = self.model(
@@ -233,16 +229,17 @@ class CompiledVicunaLayer(torch.nn.Module):
                     pkv0,
                     pkv1,
                 ),
-                send_to_host=False,
+                send_to_host=True,
            )

-            #output0 = torch.tensor(output[0])
-            #output1 = torch.tensor(output[1])
-            #output2 = torch.tensor(output[2])
-            output0 = output[0]
-            output1 = output[1]
-            output2 = output[2]
+            ### send_to_host=True
+            output0 = torch.tensor(output[0])
+            output1 = torch.tensor(output[1])
+            output2 = torch.tensor(output[2])
+            ### send_to_host=False
+            # output0 = output[0]
+            # output1 = output[1]
+            # output2 = output[2]

             return (
                 output0,
@@ -355,11 +355,15 @@ def get_iree_module(
         device = iree_device_map(device)
         print("registering device id: ", device_idx)
         haldriver = ireert.get_driver(device)
+        hal_device_id = haldriver.query_available_devices()[device_idx][
+            "device_id"
+        ]
         haldevice = haldriver.create_device(
-            haldriver.query_available_devices()[device_idx]["device_id"],
+            hal_device_id,
             allocators=shark_args.device_allocator,
         )
         config = ireert.Config(device=haldevice)
+        config.id = hal_device_id
     else:
         config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_buffer(
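
The change above queries the HAL device list once, reuses the resulting device id for create_device, and stashes it on the config (config.id) so callers can tell which physical device a module was placed on. A standalone sketch of that enumeration pattern; the driver name "vulkan" and index 0 are assumptions:

    import iree.runtime as ireert

    haldriver = ireert.get_driver("vulkan")
    devices = haldriver.query_available_devices()  # one dict per device
    hal_device_id = devices[0]["device_id"]        # device_idx = 0 here
    haldevice = haldriver.create_device(hal_device_id)
    config = ireert.Config(device=haldevice)
    config.id = hal_device_id  # ad-hoc attribute read back by callers
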
@@ -398,15 +402,16 @@ def load_vmfb_using_mmap(
         haldriver = ireert.get_driver(device)
         dl.log(f"ireert.get_driver()")

+        hal_device_id = haldriver.query_available_devices()[device_idx][
+            "device_id"
+        ]
         haldevice = haldriver.create_device(
-            haldriver.query_available_devices()[device_idx]["device_id"],
+            hal_device_id,
             allocators=shark_args.device_allocator,
         )
         dl.log(f"ireert.create_device()")
         config = ireert.Config(device=haldevice)
-        config.id = haldriver.query_available_devices()[device_idx][
-            "device_id"
-        ]
+        config.id = hal_device_id
         dl.log(f"ireert.Config()")
     else:
         config = get_iree_runtime_config(device)
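
load_vmfb_using_mmap gets the same treatment: the device id is looked up once and reused for both create_device and config.id, instead of re-querying the device list. For context, a hedged sketch of what the mmap-based load into a runtime context can look like (file and module names hypothetical):

    vm_module = ireert.VmModule.mmap(config.vm_instance, "0_full.vmfb")
    ctx = ireert.SystemContext(config=config)
    ctx.add_vm_module(vm_module)
    # exported functions are then reachable via ctx.modules.<module_name>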