mmap shards + disable sharing of device arrays across devices

Author: PhaneeshB
Date:   2023-12-05 23:52:24 +05:30
Commit: e5ed167f03 (parent 051ba5de63)

3 changed files with 40 additions and 38 deletions

File 1 of 3

@@ -827,7 +827,7 @@ class ShardedVicuna(VicunaBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
-mmap=False,
+mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -915,7 +915,7 @@ class ShardedVicuna(VicunaBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
-mmap=False,
+mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -996,7 +996,7 @@ class ShardedVicuna(VicunaBase):
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
-mmap=False,
+mmap=True,
)
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
@@ -1201,7 +1201,7 @@ class ShardedVicuna(VicunaBase):
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
-mmap=False,
+mmap=True,
)
module.load_module(vmfb_path)
else:
@@ -1219,7 +1219,7 @@ class ShardedVicuna(VicunaBase):
device=device,
device_idx=device_idx,
mlir_dialect="tm_tensor",
-mmap=False,
+mmap=True,
)
module.save_module(
module_name=f"{self.dir_name}/{idx}_full",
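
Note on this file's change (a minimal sketch, not the repository's code): with mmap=True each shard's compiled .vmfb is memory-mapped instead of being read fully into RAM before it reaches the IREE runtime. The snippet assumes the iree.runtime Python bindings; the shard path and instance setup are illustrative.

import iree.runtime as ireert

instance = ireert.VmInstance()

# mmap=False path: read the whole compiled artifact into memory, then wrap it.
with open("shard_0.vmfb", "rb") as f:  # hypothetical shard path
    module_from_buffer = ireert.VmModule.from_buffer(instance, f.read())

# mmap=True path: map the file instead of copying it, so many shards can be
# resident without holding every binary fully in RAM at once.
module_mmapped = ireert.VmModule.mmap(instance, "shard_0.vmfb")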

File 2 of 3

@@ -113,7 +113,6 @@ class LMHeadCompiled(torch.nn.Module):
def forward(self, hidden_states):
hidden_states_sample = hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
@@ -136,7 +135,6 @@ class VicunaNormCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
try:
hidden_states.detach()
except:
@@ -163,7 +161,6 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, input_ids):
input_ids.detach()
output = self.model("forward", (input_ids,), send_to_host=True)
output = torch.tensor(output)
@@ -185,11 +182,10 @@ class CompiledVicunaLayer(torch.nn.Module):
output_attentions=False,
use_cache=True,
):
if past_key_value is None:
-#hidden_states = hidden_states.detach()
-#attention_mask = attention_mask.detach()
-#position_ids = position_ids.detach()
+# hidden_states = hidden_states.detach()
+# attention_mask = attention_mask.detach()
+# position_ids = position_ids.detach()
output = self.model(
"first_vicuna_forward",
@@ -198,17 +194,17 @@ class CompiledVicunaLayer(torch.nn.Module):
attention_mask,
position_ids,
),
-send_to_host=False,
+send_to_host=True,
)
-#output0 = torch.tensor(output[0])
-#output1 = torch.tensor(output[1])
-#output2 = torch.tensor(output[2])
-output0 = output[0]
-output1 = output[1]
-output2 = output[2]
+### send_to_host=True
+output0 = torch.tensor(output[0])
+output1 = torch.tensor(output[1])
+output2 = torch.tensor(output[2])
+### send_to_host=False
+# output0 = output[0]
+# output1 = output[1]
+# output2 = output[2]
return (
output0,
@@ -218,10 +214,10 @@ class CompiledVicunaLayer(torch.nn.Module):
),
)
else:
-#hidden_states = hidden_states.detach()
-#attention_mask = attention_mask.detach()
-#position_ids = position_ids.detach()
-#pkv0 = past_key_value[0].detach()
+# hidden_states = hidden_states.detach()
+# attention_mask = attention_mask.detach()
+# position_ids = position_ids.detach()
+# pkv0 = past_key_value[0].detach()
pkv0 = past_key_value[0]
pkv1 = past_key_value[1]
output = self.model(
@@ -233,16 +229,17 @@ class CompiledVicunaLayer(torch.nn.Module):
pkv0,
pkv1,
),
-send_to_host=False,
+send_to_host=True,
)
-#output0 = torch.tensor(output[0])
-#output1 = torch.tensor(output[1])
-#output2 = torch.tensor(output[2])
-output0 = output[0]
-output1 = output[1]
-output2 = output[2]
+### send_to_host=True
+output0 = torch.tensor(output[0])
+output1 = torch.tensor(output[1])
+output2 = torch.tensor(output[2])
+### send_to_host=False
+# output0 = output[0]
+# output1 = output[1]
+# output2 = output[2]
return (
output0,
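
Note on this file's change (an illustrative sketch, not the repository's code): each compiled layer now runs with send_to_host=True and wraps the returned host arrays in fresh torch tensors, so a result produced on one device is never handed to another device as a live device array. shark_layer below is a hypothetical stand-in for self.model.

import torch

def run_layer_on_host(shark_layer, hidden_states, attention_mask, position_ids):
    # send_to_host=True copies the results back to the host before returning.
    output = shark_layer(
        "first_vicuna_forward",
        (hidden_states, attention_mask, position_ids),
        send_to_host=True,
    )
    # Each element is now a host array; torch.tensor(...) makes an owning copy
    # that the next shard (possibly on a different device) can consume safely.
    return tuple(torch.tensor(o) for o in output[:3])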

File 3 of 3

@@ -355,11 +355,15 @@ def get_iree_module(
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
+hal_device_id = haldriver.query_available_devices()[device_idx][
+    "device_id"
+]
haldevice = haldriver.create_device(
-haldriver.query_available_devices()[device_idx]["device_id"],
+hal_device_id,
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
+config.id = hal_device_id
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_buffer(
@@ -398,15 +402,16 @@ def load_vmfb_using_mmap(
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
+hal_device_id = haldriver.query_available_devices()[device_idx][
+    "device_id"
+]
haldevice = haldriver.create_device(
-haldriver.query_available_devices()[device_idx]["device_id"],
+hal_device_id,
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
-config.id = haldriver.query_available_devices()[device_idx][
-    "device_id"
-]
+config.id = hal_device_id
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
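
Note on this file's change (a sketch under assumptions, not the exact repository helper): the device_id for the requested index is looked up once, used to create the HAL device, and recorded on the runtime config so later code can tell which physical device a module is bound to. The driver name and default arguments are illustrative; the calls themselves mirror the ones in the diff above.

import iree.runtime as ireert

def make_config_for_device(driver_name, device_idx=0, allocators=None):
    haldriver = ireert.get_driver(driver_name)
    # Look the device id up once instead of querying the driver twice.
    hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"]
    haldevice = haldriver.create_device(hal_device_id, allocators=allocators)
    config = ireert.Config(device=haldevice)
    config.id = hal_device_id  # tag the config with the underlying device id
    return config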